summaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/grpc/rpc_util.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/rpc_util.go')
-rw-r--r--vendor/google.golang.org/grpc/rpc_util.go299
1 files changed, 193 insertions, 106 deletions
diff --git a/vendor/google.golang.org/grpc/rpc_util.go b/vendor/google.golang.org/grpc/rpc_util.go
index fdd49e6e9..db8865ec3 100644
--- a/vendor/google.golang.org/grpc/rpc_util.go
+++ b/vendor/google.golang.org/grpc/rpc_util.go
@@ -19,7 +19,6 @@
package grpc
import (
- "bytes"
"compress/gzip"
"context"
"encoding/binary"
@@ -35,6 +34,7 @@ import (
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/internal/transport"
+ "google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
@@ -271,17 +271,13 @@ func (o PeerCallOption) after(c *callInfo, attempt *csAttempt) {
}
}
-// WaitForReady configures the action to take when an RPC is attempted on broken
-// connections or unreachable servers. If waitForReady is false and the
-// connection is in the TRANSIENT_FAILURE state, the RPC will fail
-// immediately. Otherwise, the RPC client will block the call until a
-// connection is available (or the call is canceled or times out) and will
-// retry the call if it fails due to a transient error. gRPC will not retry if
-// data was written to the wire unless the server indicates it did not process
-// the data. Please refer to
-// https://github.com/grpc/grpc/blob/master/doc/wait-for-ready.md.
+// WaitForReady configures the RPC's behavior when the client is in
+// TRANSIENT_FAILURE, which occurs when all addresses fail to connect. If
+// waitForReady is false, the RPC will fail immediately. Otherwise, the client
+// will wait until a connection becomes available or the RPC's deadline is
+// reached.
//
-// By default, RPCs don't "wait for ready".
+// By default, RPCs do not "wait for ready".
func WaitForReady(waitForReady bool) CallOption {
return FailFastCallOption{FailFast: !waitForReady}
}
@@ -515,11 +511,51 @@ type ForceCodecCallOption struct {
}
func (o ForceCodecCallOption) before(c *callInfo) error {
- c.codec = o.Codec
+ c.codec = newCodecV1Bridge(o.Codec)
return nil
}
func (o ForceCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
+// ForceCodecV2 returns a CallOption that will set codec to be used for all
+// request and response messages for a call. The result of calling Name() will
+// be used as the content-subtype after converting to lowercase, unless
+// CallContentSubtype is also used.
+//
+// See Content-Type on
+// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
+// more details. Also see the documentation on RegisterCodec and
+// CallContentSubtype for more details on the interaction between Codec and
+// content-subtype.
+//
+// This function is provided for advanced users; prefer to use only
+// CallContentSubtype to select a registered codec instead.
+//
+// # Experimental
+//
+// Notice: This API is EXPERIMENTAL and may be changed or removed in a
+// later release.
+func ForceCodecV2(codec encoding.CodecV2) CallOption {
+ return ForceCodecV2CallOption{CodecV2: codec}
+}
+
+// ForceCodecV2CallOption is a CallOption that indicates the codec used for
+// marshaling messages.
+//
+// # Experimental
+//
+// Notice: This type is EXPERIMENTAL and may be changed or removed in a
+// later release.
+type ForceCodecV2CallOption struct {
+ CodecV2 encoding.CodecV2
+}
+
+func (o ForceCodecV2CallOption) before(c *callInfo) error {
+ c.codec = o.CodecV2
+ return nil
+}
+
+func (o ForceCodecV2CallOption) after(c *callInfo, attempt *csAttempt) {}
+
// CallCustomCodec behaves like ForceCodec, but accepts a grpc.Codec instead of
// an encoding.Codec.
//
@@ -540,7 +576,7 @@ type CustomCodecCallOption struct {
}
func (o CustomCodecCallOption) before(c *callInfo) error {
- c.codec = o.Codec
+ c.codec = newCodecV0Bridge(o.Codec)
return nil
}
func (o CustomCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
@@ -581,19 +617,28 @@ const (
compressionMade payloadFormat = 1 // compressed
)
+func (pf payloadFormat) isCompressed() bool {
+ return pf == compressionMade
+}
+
+type streamReader interface {
+ ReadHeader(header []byte) error
+ Read(n int) (mem.BufferSlice, error)
+}
+
// parser reads complete gRPC messages from the underlying reader.
type parser struct {
// r is the underlying reader.
// See the comment on recvMsg for the permissible
// error types.
- r io.Reader
+ r streamReader
// The header of a gRPC message. Find more detail at
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
header [5]byte
- // recvBufferPool is the pool of shared receive buffers.
- recvBufferPool SharedBufferPool
+ // bufferPool is the pool of shared receive buffers.
+ bufferPool mem.BufferPool
}
// recvMsg reads a complete gRPC message from the stream.
@@ -608,14 +653,15 @@ type parser struct {
// - an error from the status package
//
// No other error values or types must be returned, which also means
-// that the underlying io.Reader must not return an incompatible
+// that the underlying streamReader must not return an incompatible
// error.
-func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) {
- if _, err := p.r.Read(p.header[:]); err != nil {
+func (p *parser) recvMsg(maxReceiveMessageSize int) (payloadFormat, mem.BufferSlice, error) {
+ err := p.r.ReadHeader(p.header[:])
+ if err != nil {
return 0, nil, err
}
- pf = payloadFormat(p.header[0])
+ pf := payloadFormat(p.header[0])
length := binary.BigEndian.Uint32(p.header[1:])
if length == 0 {
@@ -627,20 +673,21 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt
if int(length) > maxReceiveMessageSize {
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
}
- msg = p.recvBufferPool.Get(int(length))
- if _, err := p.r.Read(msg); err != nil {
+
+ data, err := p.r.Read(int(length))
+ if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return 0, nil, err
}
- return pf, msg, nil
+ return pf, data, nil
}
// encode serializes msg and returns a buffer containing the message, or an
// error if it is too large to be transmitted by grpc. If msg is nil, it
// generates an empty message.
-func encode(c baseCodec, msg any) ([]byte, error) {
+func encode(c baseCodec, msg any) (mem.BufferSlice, error) {
if msg == nil { // NOTE: typed nils will not be caught by this check
return nil, nil
}
@@ -648,7 +695,8 @@ func encode(c baseCodec, msg any) ([]byte, error) {
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
}
- if uint(len(b)) > math.MaxUint32 {
+ if uint(b.Len()) > math.MaxUint32 {
+ b.Free()
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
}
return b, nil
@@ -659,34 +707,41 @@ func encode(c baseCodec, msg any) ([]byte, error) {
// indicating no compression was done.
//
// TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor.
-func compress(in []byte, cp Compressor, compressor encoding.Compressor) ([]byte, error) {
- if compressor == nil && cp == nil {
- return nil, nil
- }
- if len(in) == 0 {
- return nil, nil
+func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, pool mem.BufferPool) (mem.BufferSlice, payloadFormat, error) {
+ if (compressor == nil && cp == nil) || in.Len() == 0 {
+ return nil, compressionNone, nil
}
+ var out mem.BufferSlice
+ w := mem.NewWriter(&out, pool)
wrapErr := func(err error) error {
+ out.Free()
return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
- cbuf := &bytes.Buffer{}
if compressor != nil {
- z, err := compressor.Compress(cbuf)
+ z, err := compressor.Compress(w)
if err != nil {
- return nil, wrapErr(err)
+ return nil, 0, wrapErr(err)
}
- if _, err := z.Write(in); err != nil {
- return nil, wrapErr(err)
+ for _, b := range in {
+ if _, err := z.Write(b.ReadOnlyData()); err != nil {
+ return nil, 0, wrapErr(err)
+ }
}
if err := z.Close(); err != nil {
- return nil, wrapErr(err)
+ return nil, 0, wrapErr(err)
}
} else {
- if err := cp.Do(cbuf, in); err != nil {
- return nil, wrapErr(err)
+ // This is obviously really inefficient since it fully materializes the data, but
+ // there is no way around this with the old Compressor API. At least it attempts
+ // to return the buffer to the provider, in the hopes it can be reused (maybe
+ // even by a subsequent call to this very function).
+ buf := in.MaterializeToBuffer(pool)
+ defer buf.Free()
+ if err := cp.Do(w, buf.ReadOnlyData()); err != nil {
+ return nil, 0, wrapErr(err)
}
}
- return cbuf.Bytes(), nil
+ return out, compressionMade, nil
}
const (
@@ -697,33 +752,36 @@ const (
// msgHeader returns a 5-byte header for the message being transmitted and the
// payload, which is compData if non-nil or data otherwise.
-func msgHeader(data, compData []byte) (hdr []byte, payload []byte) {
+func msgHeader(data, compData mem.BufferSlice, pf payloadFormat) (hdr []byte, payload mem.BufferSlice) {
hdr = make([]byte, headerLen)
- if compData != nil {
- hdr[0] = byte(compressionMade)
- data = compData
+ hdr[0] = byte(pf)
+
+ var length uint32
+ if pf.isCompressed() {
+ length = uint32(compData.Len())
+ payload = compData
} else {
- hdr[0] = byte(compressionNone)
+ length = uint32(data.Len())
+ payload = data
}
// Write length of payload into buf
- binary.BigEndian.PutUint32(hdr[payloadLen:], uint32(len(data)))
- return hdr, data
+ binary.BigEndian.PutUint32(hdr[payloadLen:], length)
+ return hdr, payload
}
-func outPayload(client bool, msg any, data, payload []byte, t time.Time) *stats.OutPayload {
+func outPayload(client bool, msg any, dataLength, payloadLength int, t time.Time) *stats.OutPayload {
return &stats.OutPayload{
Client: client,
Payload: msg,
- Data: data,
- Length: len(data),
- WireLength: len(payload) + headerLen,
- CompressedLength: len(payload),
+ Length: dataLength,
+ WireLength: payloadLength + headerLen,
+ CompressedLength: payloadLength,
SentTime: t,
}
}
-func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {
+func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status {
switch pf {
case compressionNone:
case compressionMade:
@@ -731,7 +789,11 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
}
if !haveCompressor {
- return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
+ if isServer {
+ return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
+ } else {
+ return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
+ }
}
default:
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
@@ -741,104 +803,129 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
type payloadInfo struct {
compressedLength int // The compressed length got from wire.
- uncompressedBytes []byte
+ uncompressedBytes mem.BufferSlice
+}
+
+func (p *payloadInfo) free() {
+ if p != nil && p.uncompressedBytes != nil {
+ p.uncompressedBytes.Free()
+ }
}
// recvAndDecompress reads a message from the stream, decompressing it if necessary.
//
// Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as
// the buffer is no longer needed.
-func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor,
-) (uncompressedBuf []byte, cancel func(), err error) {
- pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)
+// TODO: Refactor this function to reduce the number of arguments.
+// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
+func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
+) (out mem.BufferSlice, err error) {
+ pf, compressed, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
- return nil, nil, err
+ return nil, err
}
- if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
- return nil, nil, st.Err()
+ compressedLength := compressed.Len()
+
+ if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil {
+ compressed.Free()
+ return nil, st.Err()
}
var size int
- if pf == compressionMade {
+ if pf.isCompressed() {
+ defer compressed.Free()
+
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default.
if dc != nil {
- uncompressedBuf, err = dc.Do(bytes.NewReader(compressedBuf))
+ var uncompressedBuf []byte
+ uncompressedBuf, err = dc.Do(compressed.Reader())
+ if err == nil {
+ out = mem.BufferSlice{mem.NewBuffer(&uncompressedBuf, nil)}
+ }
size = len(uncompressedBuf)
} else {
- uncompressedBuf, size, err = decompress(compressor, compressedBuf, maxReceiveMessageSize)
+ out, size, err = decompress(compressor, compressed, maxReceiveMessageSize, p.bufferPool)
}
if err != nil {
- return nil, nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
+ return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
}
if size > maxReceiveMessageSize {
+ out.Free()
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
- return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
+ return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
}
} else {
- uncompressedBuf = compressedBuf
+ out = compressed
}
if payInfo != nil {
- payInfo.compressedLength = len(compressedBuf)
- payInfo.uncompressedBytes = uncompressedBuf
-
- cancel = func() {}
- } else {
- cancel = func() {
- p.recvBufferPool.Put(&compressedBuf)
- }
+ payInfo.compressedLength = compressedLength
+ out.Ref()
+ payInfo.uncompressedBytes = out
}
- return uncompressedBuf, cancel, nil
+ return out, nil
}
// Using compressor, decompress d, returning data and size.
// Optionally, if data will be over maxReceiveMessageSize, just return the size.
-func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize int) ([]byte, int, error) {
- dcReader, err := compressor.Decompress(bytes.NewReader(d))
+func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, int, error) {
+ dcReader, err := compressor.Decompress(d.Reader())
if err != nil {
return nil, 0, err
}
- if sizer, ok := compressor.(interface {
- DecompressedSize(compressedBytes []byte) int
- }); ok {
- if size := sizer.DecompressedSize(d); size >= 0 {
- if size > maxReceiveMessageSize {
- return nil, size, nil
- }
- // size is used as an estimate to size the buffer, but we
- // will read more data if available.
- // +MinRead so ReadFrom will not reallocate if size is correct.
- //
- // TODO: If we ensure that the buffer size is the same as the DecompressedSize,
- // we can also utilize the recv buffer pool here.
- buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
- bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
- return buf.Bytes(), int(bytesRead), err
- }
+
+ // TODO: Can/should this still be preserved with the new BufferSlice API? Are
+ // there any actual benefits to allocating a single large buffer instead of
+ // multiple smaller ones?
+ //if sizer, ok := compressor.(interface {
+ // DecompressedSize(compressedBytes []byte) int
+ //}); ok {
+ // if size := sizer.DecompressedSize(d); size >= 0 {
+ // if size > maxReceiveMessageSize {
+ // return nil, size, nil
+ // }
+ // // size is used as an estimate to size the buffer, but we
+ // // will read more data if available.
+ // // +MinRead so ReadFrom will not reallocate if size is correct.
+ // //
+ // // TODO: If we ensure that the buffer size is the same as the DecompressedSize,
+ // // we can also utilize the recv buffer pool here.
+ // buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
+ // bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
+ // return buf.Bytes(), int(bytesRead), err
+ // }
+ //}
+
+ var out mem.BufferSlice
+ _, err = io.Copy(mem.NewWriter(&out, pool), io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
+ if err != nil {
+ out.Free()
+ return nil, 0, err
}
- // Read from LimitReader with limit max+1. So if the underlying
- // reader is over limit, the result will be bigger than max.
- d, err = io.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
- return d, len(d), err
+ return out, out.Len(), nil
}
// For the two compressor parameters, both should not be set, but if they are,
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
-func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
- buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
+func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
+ data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
if err != nil {
return err
}
- defer cancel()
- if err := c.Unmarshal(buf, m); err != nil {
+ // If the codec wants its own reference to the data, it can get it. Otherwise, always
+ // free the buffers.
+ defer data.Free()
+
+ if err := c.Unmarshal(data, m); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
}
+
return nil
}
@@ -941,7 +1028,7 @@ func setCallInfoCodec(c *callInfo) error {
// encoding.Codec (Name vs. String method name). We only support
// setting content subtype from encoding.Codec to avoid a behavior
// change with the deprecated version.
- if ec, ok := c.codec.(encoding.Codec); ok {
+ if ec, ok := c.codec.(encoding.CodecV2); ok {
c.contentSubtype = strings.ToLower(ec.Name())
}
}
@@ -950,12 +1037,12 @@ func setCallInfoCodec(c *callInfo) error {
if c.contentSubtype == "" {
// No codec specified in CallOptions; use proto by default.
- c.codec = encoding.GetCodec(proto.Name)
+ c.codec = getCodec(proto.Name)
return nil
}
// c.contentSubtype is already lowercased in CallContentSubtype
- c.codec = encoding.GetCodec(c.contentSubtype)
+ c.codec = getCodec(c.contentSubtype)
if c.codec == nil {
return status.Errorf(codes.Internal, "no codec registered for content-subtype %s", c.contentSubtype)
}