diff options
Diffstat (limited to 'vendor/github.com/jackc/pgx/v4/values.go')
-rw-r--r-- | vendor/github.com/jackc/pgx/v4/values.go | 280 |
1 files changed, 280 insertions, 0 deletions
diff --git a/vendor/github.com/jackc/pgx/v4/values.go b/vendor/github.com/jackc/pgx/v4/values.go new file mode 100644 index 000000000..1a9454753 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/values.go @@ -0,0 +1,280 @@ +package pgx + +import ( + "database/sql/driver" + "fmt" + "math" + "reflect" + "time" + + "github.com/jackc/pgio" + "github.com/jackc/pgtype" +) + +// PostgreSQL format codes +const ( + TextFormatCode = 0 + BinaryFormatCode = 1 +) + +// SerializationError occurs on failure to encode or decode a value +type SerializationError string + +func (e SerializationError) Error() string { + return string(e) +} + +func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) { + if arg == nil { + return nil, nil + } + + refVal := reflect.ValueOf(arg) + if refVal.Kind() == reflect.Ptr && refVal.IsNil() { + return nil, nil + } + + switch arg := arg.(type) { + + // https://github.com/jackc/pgx/issues/409 Changed JSON and JSONB to surface + // []byte to database/sql instead of string. But that caused problems with the + // simple protocol because the driver.Valuer case got taken before the + // pgtype.TextEncoder case. And driver.Valuer needed to be first in the usual + // case because of https://github.com/jackc/pgx/issues/339. So instead we + // special case JSON and JSONB. + case *pgtype.JSON: + buf, err := arg.EncodeText(ci, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), nil + case *pgtype.JSONB: + buf, err := arg.EncodeText(ci, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), nil + + case driver.Valuer: + return callValuerValue(arg) + case pgtype.TextEncoder: + buf, err := arg.EncodeText(ci, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), nil + case float32: + return float64(arg), nil + case float64: + return arg, nil + case bool: + return arg, nil + case time.Duration: + return fmt.Sprintf("%d microsecond", int64(arg)/1000), nil + case time.Time: + return arg, nil + case string: + return arg, nil + case []byte: + return arg, nil + case int8: + return int64(arg), nil + case int16: + return int64(arg), nil + case int32: + return int64(arg), nil + case int64: + return arg, nil + case int: + return int64(arg), nil + case uint8: + return int64(arg), nil + case uint16: + return int64(arg), nil + case uint32: + return int64(arg), nil + case uint64: + if arg > math.MaxInt64 { + return nil, fmt.Errorf("arg too big for int64: %v", arg) + } + return int64(arg), nil + case uint: + if uint64(arg) > math.MaxInt64 { + return nil, fmt.Errorf("arg too big for int64: %v", arg) + } + return int64(arg), nil + } + + if dt, found := ci.DataTypeForValue(arg); found { + v := dt.Value + err := v.Set(arg) + if err != nil { + return nil, err + } + buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), nil + } + + if refVal.Kind() == reflect.Ptr { + arg = refVal.Elem().Interface() + return convertSimpleArgument(ci, arg) + } + + if strippedArg, ok := stripNamedType(&refVal); ok { + return convertSimpleArgument(ci, strippedArg) + } + return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) +} + +func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32, arg interface{}) ([]byte, error) { + if arg == nil { + return pgio.AppendInt32(buf, -1), nil + } + + switch arg := arg.(type) { + case pgtype.BinaryEncoder: + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := arg.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + return buf, nil + case pgtype.TextEncoder: + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := arg.EncodeText(ci, buf) + if err != nil { + return nil, err + } + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + return buf, nil + case string: + buf = pgio.AppendInt32(buf, int32(len(arg))) + buf = append(buf, arg...) + return buf, nil + } + + refVal := reflect.ValueOf(arg) + + if refVal.Kind() == reflect.Ptr { + if refVal.IsNil() { + return pgio.AppendInt32(buf, -1), nil + } + arg = refVal.Elem().Interface() + return encodePreparedStatementArgument(ci, buf, oid, arg) + } + + if dt, ok := ci.DataTypeForOID(oid); ok { + value := dt.Value + err := value.Set(arg) + if err != nil { + { + if arg, ok := arg.(driver.Valuer); ok { + v, err := callValuerValue(arg) + if err != nil { + return nil, err + } + return encodePreparedStatementArgument(ci, buf, oid, v) + } + } + + return nil, err + } + + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + return buf, nil + } + + if strippedArg, ok := stripNamedType(&refVal); ok { + return encodePreparedStatementArgument(ci, buf, oid, strippedArg) + } + return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) +} + +// chooseParameterFormatCode determines the correct format code for an +// argument to a prepared statement. It defaults to TextFormatCode if no +// determination can be made. +func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid uint32, arg interface{}) int16 { + switch arg := arg.(type) { + case pgtype.ParamFormatPreferrer: + return arg.PreferredParamFormat() + case pgtype.BinaryEncoder: + return BinaryFormatCode + case string, *string, pgtype.TextEncoder: + return TextFormatCode + } + + return ci.ParamFormatCodeForOID(oid) +} + +func stripNamedType(val *reflect.Value) (interface{}, bool) { + switch val.Kind() { + case reflect.Int: + convVal := int(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Int8: + convVal := int8(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Int16: + convVal := int16(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Int32: + convVal := int32(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Int64: + convVal := int64(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Uint: + convVal := uint(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Uint8: + convVal := uint8(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Uint16: + convVal := uint16(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Uint32: + convVal := uint32(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Uint64: + convVal := uint64(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.String: + convVal := val.String() + return convVal, reflect.TypeOf(convVal) != val.Type() + } + + return nil, false +} |