summaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/pgproto3/bind.go')
-rw-r--r--vendor/github.com/jackc/pgx/v5/pgproto3/bind.go21
1 files changed, 14 insertions, 7 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go b/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go
index fdd2d3b81..ad6ac48bf 100644
--- a/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go
+++ b/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go
@@ -5,7 +5,9 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
+ "errors"
"fmt"
+ "math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -108,21 +110,25 @@ func (dst *Bind) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *Bind) Encode(dst []byte) []byte {
- dst = append(dst, 'B')
- sp := len(dst)
- dst = pgio.AppendInt32(dst, -1)
+func (src *Bind) Encode(dst []byte) ([]byte, error) {
+ dst, sp := beginMessage(dst, 'B')
dst = append(dst, src.DestinationPortal...)
dst = append(dst, 0)
dst = append(dst, src.PreparedStatement...)
dst = append(dst, 0)
+ if len(src.ParameterFormatCodes) > math.MaxUint16 {
+ return nil, errors.New("too many parameter format codes")
+ }
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
for _, fc := range src.ParameterFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
+ if len(src.Parameters) > math.MaxUint16 {
+ return nil, errors.New("too many parameters")
+ }
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
for _, p := range src.Parameters {
if p == nil {
@@ -134,14 +140,15 @@ func (src *Bind) Encode(dst []byte) []byte {
dst = append(dst, p...)
}
+ if len(src.ResultFormatCodes) > math.MaxUint16 {
+ return nil, errors.New("too many result format codes")
+ }
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
for _, fc := range src.ResultFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
- pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
-
- return dst
+ return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.