diff options
Diffstat (limited to 'vendor/github.com/jackc/pgx/v4/extended_query_builder.go')
-rw-r--r-- | vendor/github.com/jackc/pgx/v4/extended_query_builder.go | 168 |
1 files changed, 168 insertions, 0 deletions
diff --git a/vendor/github.com/jackc/pgx/v4/extended_query_builder.go b/vendor/github.com/jackc/pgx/v4/extended_query_builder.go new file mode 100644 index 000000000..09419f0d0 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/extended_query_builder.go @@ -0,0 +1,168 @@ +package pgx + +import ( + "database/sql/driver" + "fmt" + "reflect" + + "github.com/jackc/pgtype" +) + +type extendedQueryBuilder struct { + paramValues [][]byte + paramValueBytes []byte + paramFormats []int16 + resultFormats []int16 + + resetCount int +} + +func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, arg interface{}) error { + f := chooseParameterFormatCode(ci, oid, arg) + eqb.paramFormats = append(eqb.paramFormats, f) + + v, err := eqb.encodeExtendedParamValue(ci, oid, f, arg) + if err != nil { + return err + } + eqb.paramValues = append(eqb.paramValues, v) + + return nil +} + +func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) { + eqb.resultFormats = append(eqb.resultFormats, f) +} + +func (eqb *extendedQueryBuilder) Reset() { + eqb.paramValues = eqb.paramValues[0:0] + eqb.paramValueBytes = eqb.paramValueBytes[0:0] + eqb.paramFormats = eqb.paramFormats[0:0] + eqb.resultFormats = eqb.resultFormats[0:0] + + eqb.resetCount++ + + // Every so often shrink our reserved memory if it is abnormally high + if eqb.resetCount%128 == 0 { + if cap(eqb.paramValues) > 64 { + eqb.paramValues = make([][]byte, 0, cap(eqb.paramValues)/2) + } + + if cap(eqb.paramValueBytes) > 256 { + eqb.paramValueBytes = make([]byte, 0, cap(eqb.paramValueBytes)/2) + } + + if cap(eqb.paramFormats) > 64 { + eqb.paramFormats = make([]int16, 0, cap(eqb.paramFormats)/2) + } + if cap(eqb.resultFormats) > 64 { + eqb.resultFormats = make([]int16, 0, cap(eqb.resultFormats)/2) + } + } + +} + +func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) { + if arg == nil { + return nil, nil + } + + refVal := reflect.ValueOf(arg) + argIsPtr := refVal.Kind() == reflect.Ptr + + if argIsPtr && refVal.IsNil() { + return nil, nil + } + + if eqb.paramValueBytes == nil { + eqb.paramValueBytes = make([]byte, 0, 128) + } + + var err error + var buf []byte + pos := len(eqb.paramValueBytes) + + if arg, ok := arg.(string); ok { + return []byte(arg), nil + } + + if formatCode == TextFormatCode { + if arg, ok := arg.(pgtype.TextEncoder); ok { + buf, err = arg.EncodeText(ci, eqb.paramValueBytes) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil + } + } else if formatCode == BinaryFormatCode { + if arg, ok := arg.(pgtype.BinaryEncoder); ok { + buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil + } + } + + if argIsPtr { + // We have already checked that arg is not pointing to nil, + // so it is safe to dereference here. + arg = refVal.Elem().Interface() + return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg) + } + + if dt, ok := ci.DataTypeForOID(oid); ok { + value := dt.Value + err := value.Set(arg) + if err != nil { + { + if arg, ok := arg.(driver.Valuer); ok { + v, err := callValuerValue(arg) + if err != nil { + return nil, err + } + return eqb.encodeExtendedParamValue(ci, oid, formatCode, v) + } + } + + return nil, err + } + + return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) + } + + // There is no data type registered for the destination OID, but maybe there is data type registered for the arg + // type. If so use it's text encoder (if available). + if dt, ok := ci.DataTypeForValue(arg); ok { + value := dt.Value + if textEncoder, ok := value.(pgtype.TextEncoder); ok { + err := value.Set(arg) + if err != nil { + return nil, err + } + + buf, err = textEncoder.EncodeText(ci, eqb.paramValueBytes) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil + } + } + + if strippedArg, ok := stripNamedType(&refVal); ok { + return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg) + } + return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) +} |