diff options
author | 2024-03-06 09:05:45 -0800 | |
---|---|---|
committer | 2024-03-06 18:05:45 +0100 | |
commit | fc3741365c27f1d703e8a736af95b95ff811cc45 (patch) | |
tree | 929f1d5e20d1469d63a3dfe81d38d89f9a073c5a /vendor/go.mongodb.org/mongo-driver/bson/bsonrw | |
parent | [chore/bugfix] Little DB fixes (#2726) (diff) | |
download | gotosocial-fc3741365c27f1d703e8a736af95b95ff811cc45.tar.xz |
[bugfix] Fix Swagger spec and add test script (#2698)
* Add Swagger spec test script
* Fix Swagger spec errors not related to statuses with polls
* Add API tests that post a status with a poll
* Fix creating a status with a poll from form params
* Fix Swagger spec errors related to statuses with polls (this is the last error)
* Fix Swagger spec warnings not related to unused definitions
* Suppress a duplicate list update params definition that was somehow causing wrong param names
* Add Swagger test to CI
- updates Drone config
- vendorizes go-swagger
- fixes a file extension issue that caused the test script to generate JSON instead of YAML with the vendorized version
* Put `Sample: ` on its own line everywhere
* Remove unused id param from emojiCategoriesGet
* Add 5 more pairs of profile fields to account update API Swagger
* Remove Swagger prefix from dummy fields
It makes the generated code look weird
* Manually annotate params for statusCreate operation
* Fix all remaining Swagger spec warnings
- Change some models into operation parameters
- Ignore models that already correspond to manually documented operation parameters but can't be trivially changed (those with file fields)
* Documented that creating a status with scheduled_at isn't implemented yet
* sign drone.yml
* Fix filter API Swagger errors
* fixup! Fix filter API Swagger errors
---------
Co-authored-by: tobi <tobi.smethurst@protonmail.com>
Diffstat (limited to 'vendor/go.mongodb.org/mongo-driver/bson/bsonrw')
13 files changed, 5608 insertions, 0 deletions
diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/copier.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/copier.go new file mode 100644 index 000000000..5cdf6460b --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/copier.go @@ -0,0 +1,445 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsonrw + +import ( + "fmt" + "io" + + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +// Copier is a type that allows copying between ValueReaders, ValueWriters, and +// []byte values. +type Copier struct{} + +// NewCopier creates a new copier with the given registry. If a nil registry is provided +// a default registry is used. +func NewCopier() Copier { + return Copier{} +} + +// CopyDocument handles copying a document from src to dst. +func CopyDocument(dst ValueWriter, src ValueReader) error { + return Copier{}.CopyDocument(dst, src) +} + +// CopyDocument handles copying one document from the src to the dst. +func (c Copier) CopyDocument(dst ValueWriter, src ValueReader) error { + dr, err := src.ReadDocument() + if err != nil { + return err + } + + dw, err := dst.WriteDocument() + if err != nil { + return err + } + + return c.copyDocumentCore(dw, dr) +} + +// CopyArrayFromBytes copies the values from a BSON array represented as a +// []byte to a ValueWriter. +func (c Copier) CopyArrayFromBytes(dst ValueWriter, src []byte) error { + aw, err := dst.WriteArray() + if err != nil { + return err + } + + err = c.CopyBytesToArrayWriter(aw, src) + if err != nil { + return err + } + + return aw.WriteArrayEnd() +} + +// CopyDocumentFromBytes copies the values from a BSON document represented as a +// []byte to a ValueWriter. +func (c Copier) CopyDocumentFromBytes(dst ValueWriter, src []byte) error { + dw, err := dst.WriteDocument() + if err != nil { + return err + } + + err = c.CopyBytesToDocumentWriter(dw, src) + if err != nil { + return err + } + + return dw.WriteDocumentEnd() +} + +type writeElementFn func(key string) (ValueWriter, error) + +// CopyBytesToArrayWriter copies the values from a BSON Array represented as a []byte to an +// ArrayWriter. +func (c Copier) CopyBytesToArrayWriter(dst ArrayWriter, src []byte) error { + wef := func(_ string) (ValueWriter, error) { + return dst.WriteArrayElement() + } + + return c.copyBytesToValueWriter(src, wef) +} + +// CopyBytesToDocumentWriter copies the values from a BSON document represented as a []byte to a +// DocumentWriter. +func (c Copier) CopyBytesToDocumentWriter(dst DocumentWriter, src []byte) error { + wef := func(key string) (ValueWriter, error) { + return dst.WriteDocumentElement(key) + } + + return c.copyBytesToValueWriter(src, wef) +} + +func (c Copier) copyBytesToValueWriter(src []byte, wef writeElementFn) error { + // TODO(skriptble): Create errors types here. Anything thats a tag should be a property. + length, rem, ok := bsoncore.ReadLength(src) + if !ok { + return fmt.Errorf("couldn't read length from src, not enough bytes. length=%d", len(src)) + } + if len(src) < int(length) { + return fmt.Errorf("length read exceeds number of bytes available. length=%d bytes=%d", len(src), length) + } + rem = rem[:length-4] + + var t bsontype.Type + var key string + var val bsoncore.Value + for { + t, rem, ok = bsoncore.ReadType(rem) + if !ok { + return io.EOF + } + if t == bsontype.Type(0) { + if len(rem) != 0 { + return fmt.Errorf("document end byte found before end of document. remaining bytes=%v", rem) + } + break + } + + key, rem, ok = bsoncore.ReadKey(rem) + if !ok { + return fmt.Errorf("invalid key found. remaining bytes=%v", rem) + } + + // write as either array element or document element using writeElementFn + vw, err := wef(key) + if err != nil { + return err + } + + val, rem, ok = bsoncore.ReadValue(rem, t) + if !ok { + return fmt.Errorf("not enough bytes available to read type. bytes=%d type=%s", len(rem), t) + } + err = c.CopyValueFromBytes(vw, t, val.Data) + if err != nil { + return err + } + } + return nil +} + +// CopyDocumentToBytes copies an entire document from the ValueReader and +// returns it as bytes. +func (c Copier) CopyDocumentToBytes(src ValueReader) ([]byte, error) { + return c.AppendDocumentBytes(nil, src) +} + +// AppendDocumentBytes functions the same as CopyDocumentToBytes, but will +// append the result to dst. +func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) { + if br, ok := src.(BytesReader); ok { + _, dst, err := br.ReadValueBytes(dst) + return dst, err + } + + vw := vwPool.Get().(*valueWriter) + defer vwPool.Put(vw) + + vw.reset(dst) + + err := c.CopyDocument(vw, src) + dst = vw.buf + return dst, err +} + +// AppendArrayBytes copies an array from the ValueReader to dst. +func (c Copier) AppendArrayBytes(dst []byte, src ValueReader) ([]byte, error) { + if br, ok := src.(BytesReader); ok { + _, dst, err := br.ReadValueBytes(dst) + return dst, err + } + + vw := vwPool.Get().(*valueWriter) + defer vwPool.Put(vw) + + vw.reset(dst) + + err := c.copyArray(vw, src) + dst = vw.buf + return dst, err +} + +// CopyValueFromBytes will write the value represtend by t and src to dst. +func (c Copier) CopyValueFromBytes(dst ValueWriter, t bsontype.Type, src []byte) error { + if wvb, ok := dst.(BytesWriter); ok { + return wvb.WriteValueBytes(t, src) + } + + vr := vrPool.Get().(*valueReader) + defer vrPool.Put(vr) + + vr.reset(src) + vr.pushElement(t) + + return c.CopyValue(dst, vr) +} + +// CopyValueToBytes copies a value from src and returns it as a bsontype.Type and a +// []byte. +func (c Copier) CopyValueToBytes(src ValueReader) (bsontype.Type, []byte, error) { + return c.AppendValueBytes(nil, src) +} + +// AppendValueBytes functions the same as CopyValueToBytes, but will append the +// result to dst. +func (c Copier) AppendValueBytes(dst []byte, src ValueReader) (bsontype.Type, []byte, error) { + if br, ok := src.(BytesReader); ok { + return br.ReadValueBytes(dst) + } + + vw := vwPool.Get().(*valueWriter) + defer vwPool.Put(vw) + + start := len(dst) + + vw.reset(dst) + vw.push(mElement) + + err := c.CopyValue(vw, src) + if err != nil { + return 0, dst, err + } + + return bsontype.Type(vw.buf[start]), vw.buf[start+2:], nil +} + +// CopyValue will copy a single value from src to dst. +func (c Copier) CopyValue(dst ValueWriter, src ValueReader) error { + var err error + switch src.Type() { + case bsontype.Double: + var f64 float64 + f64, err = src.ReadDouble() + if err != nil { + break + } + err = dst.WriteDouble(f64) + case bsontype.String: + var str string + str, err = src.ReadString() + if err != nil { + return err + } + err = dst.WriteString(str) + case bsontype.EmbeddedDocument: + err = c.CopyDocument(dst, src) + case bsontype.Array: + err = c.copyArray(dst, src) + case bsontype.Binary: + var data []byte + var subtype byte + data, subtype, err = src.ReadBinary() + if err != nil { + break + } + err = dst.WriteBinaryWithSubtype(data, subtype) + case bsontype.Undefined: + err = src.ReadUndefined() + if err != nil { + break + } + err = dst.WriteUndefined() + case bsontype.ObjectID: + var oid primitive.ObjectID + oid, err = src.ReadObjectID() + if err != nil { + break + } + err = dst.WriteObjectID(oid) + case bsontype.Boolean: + var b bool + b, err = src.ReadBoolean() + if err != nil { + break + } + err = dst.WriteBoolean(b) + case bsontype.DateTime: + var dt int64 + dt, err = src.ReadDateTime() + if err != nil { + break + } + err = dst.WriteDateTime(dt) + case bsontype.Null: + err = src.ReadNull() + if err != nil { + break + } + err = dst.WriteNull() + case bsontype.Regex: + var pattern, options string + pattern, options, err = src.ReadRegex() + if err != nil { + break + } + err = dst.WriteRegex(pattern, options) + case bsontype.DBPointer: + var ns string + var pointer primitive.ObjectID + ns, pointer, err = src.ReadDBPointer() + if err != nil { + break + } + err = dst.WriteDBPointer(ns, pointer) + case bsontype.JavaScript: + var js string + js, err = src.ReadJavascript() + if err != nil { + break + } + err = dst.WriteJavascript(js) + case bsontype.Symbol: + var symbol string + symbol, err = src.ReadSymbol() + if err != nil { + break + } + err = dst.WriteSymbol(symbol) + case bsontype.CodeWithScope: + var code string + var srcScope DocumentReader + code, srcScope, err = src.ReadCodeWithScope() + if err != nil { + break + } + + var dstScope DocumentWriter + dstScope, err = dst.WriteCodeWithScope(code) + if err != nil { + break + } + err = c.copyDocumentCore(dstScope, srcScope) + case bsontype.Int32: + var i32 int32 + i32, err = src.ReadInt32() + if err != nil { + break + } + err = dst.WriteInt32(i32) + case bsontype.Timestamp: + var t, i uint32 + t, i, err = src.ReadTimestamp() + if err != nil { + break + } + err = dst.WriteTimestamp(t, i) + case bsontype.Int64: + var i64 int64 + i64, err = src.ReadInt64() + if err != nil { + break + } + err = dst.WriteInt64(i64) + case bsontype.Decimal128: + var d128 primitive.Decimal128 + d128, err = src.ReadDecimal128() + if err != nil { + break + } + err = dst.WriteDecimal128(d128) + case bsontype.MinKey: + err = src.ReadMinKey() + if err != nil { + break + } + err = dst.WriteMinKey() + case bsontype.MaxKey: + err = src.ReadMaxKey() + if err != nil { + break + } + err = dst.WriteMaxKey() + default: + err = fmt.Errorf("Cannot copy unknown BSON type %s", src.Type()) + } + + return err +} + +func (c Copier) copyArray(dst ValueWriter, src ValueReader) error { + ar, err := src.ReadArray() + if err != nil { + return err + } + + aw, err := dst.WriteArray() + if err != nil { + return err + } + + for { + vr, err := ar.ReadValue() + if err == ErrEOA { + break + } + if err != nil { + return err + } + + vw, err := aw.WriteArrayElement() + if err != nil { + return err + } + + err = c.CopyValue(vw, vr) + if err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (c Copier) copyDocumentCore(dw DocumentWriter, dr DocumentReader) error { + for { + key, vr, err := dr.ReadElement() + if err == ErrEOD { + break + } + if err != nil { + return err + } + + vw, err := dw.WriteDocumentElement(key) + if err != nil { + return err + } + + err = c.CopyValue(vw, vr) + if err != nil { + return err + } + } + + return dw.WriteDocumentEnd() +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/doc.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/doc.go new file mode 100644 index 000000000..750b0d2af --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/doc.go @@ -0,0 +1,9 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +// Package bsonrw contains abstractions for reading and writing +// BSON and BSON like types from sources. +package bsonrw // import "go.mongodb.org/mongo-driver/bson/bsonrw" diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_parser.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_parser.go new file mode 100644 index 000000000..54c76bf74 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_parser.go @@ -0,0 +1,806 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsonrw + +import ( + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "io" + "strings" + + "go.mongodb.org/mongo-driver/bson/bsontype" +) + +const maxNestingDepth = 200 + +// ErrInvalidJSON indicates the JSON input is invalid +var ErrInvalidJSON = errors.New("invalid JSON input") + +type jsonParseState byte + +const ( + jpsStartState jsonParseState = iota + jpsSawBeginObject + jpsSawEndObject + jpsSawBeginArray + jpsSawEndArray + jpsSawColon + jpsSawComma + jpsSawKey + jpsSawValue + jpsDoneState + jpsInvalidState +) + +type jsonParseMode byte + +const ( + jpmInvalidMode jsonParseMode = iota + jpmObjectMode + jpmArrayMode +) + +type extJSONValue struct { + t bsontype.Type + v interface{} +} + +type extJSONObject struct { + keys []string + values []*extJSONValue +} + +type extJSONParser struct { + js *jsonScanner + s jsonParseState + m []jsonParseMode + k string + v *extJSONValue + + err error + canonical bool + depth int + maxDepth int + + emptyObject bool + relaxedUUID bool +} + +// newExtJSONParser returns a new extended JSON parser, ready to to begin +// parsing from the first character of the argued json input. It will not +// perform any read-ahead and will therefore not report any errors about +// malformed JSON at this point. +func newExtJSONParser(r io.Reader, canonical bool) *extJSONParser { + return &extJSONParser{ + js: &jsonScanner{r: r}, + s: jpsStartState, + m: []jsonParseMode{}, + canonical: canonical, + maxDepth: maxNestingDepth, + } +} + +// peekType examines the next value and returns its BSON Type +func (ejp *extJSONParser) peekType() (bsontype.Type, error) { + var t bsontype.Type + var err error + initialState := ejp.s + + ejp.advanceState() + switch ejp.s { + case jpsSawValue: + t = ejp.v.t + case jpsSawBeginArray: + t = bsontype.Array + case jpsInvalidState: + err = ejp.err + case jpsSawComma: + // in array mode, seeing a comma means we need to progress again to actually observe a type + if ejp.peekMode() == jpmArrayMode { + return ejp.peekType() + } + case jpsSawEndArray: + // this would only be a valid state if we were in array mode, so return end-of-array error + err = ErrEOA + case jpsSawBeginObject: + // peek key to determine type + ejp.advanceState() + switch ejp.s { + case jpsSawEndObject: // empty embedded document + t = bsontype.EmbeddedDocument + ejp.emptyObject = true + case jpsInvalidState: + err = ejp.err + case jpsSawKey: + if initialState == jpsStartState { + return bsontype.EmbeddedDocument, nil + } + t = wrapperKeyBSONType(ejp.k) + + // if $uuid is encountered, parse as binary subtype 4 + if ejp.k == "$uuid" { + ejp.relaxedUUID = true + t = bsontype.Binary + } + + switch t { + case bsontype.JavaScript: + // just saw $code, need to check for $scope at same level + _, err = ejp.readValue(bsontype.JavaScript) + if err != nil { + break + } + + switch ejp.s { + case jpsSawEndObject: // type is TypeJavaScript + case jpsSawComma: + ejp.advanceState() + + if ejp.s == jpsSawKey && ejp.k == "$scope" { + t = bsontype.CodeWithScope + } else { + err = fmt.Errorf("invalid extended JSON: unexpected key %s in CodeWithScope object", ejp.k) + } + case jpsInvalidState: + err = ejp.err + default: + err = ErrInvalidJSON + } + case bsontype.CodeWithScope: + err = errors.New("invalid extended JSON: code with $scope must contain $code before $scope") + } + } + } + + return t, err +} + +// readKey parses the next key and its type and returns them +func (ejp *extJSONParser) readKey() (string, bsontype.Type, error) { + if ejp.emptyObject { + ejp.emptyObject = false + return "", 0, ErrEOD + } + + // advance to key (or return with error) + switch ejp.s { + case jpsStartState: + ejp.advanceState() + if ejp.s == jpsSawBeginObject { + ejp.advanceState() + } + case jpsSawBeginObject: + ejp.advanceState() + case jpsSawValue, jpsSawEndObject, jpsSawEndArray: + ejp.advanceState() + switch ejp.s { + case jpsSawBeginObject, jpsSawComma: + ejp.advanceState() + case jpsSawEndObject: + return "", 0, ErrEOD + case jpsDoneState: + return "", 0, io.EOF + case jpsInvalidState: + return "", 0, ejp.err + default: + return "", 0, ErrInvalidJSON + } + case jpsSawKey: // do nothing (key was peeked before) + default: + return "", 0, invalidRequestError("key") + } + + // read key + var key string + + switch ejp.s { + case jpsSawKey: + key = ejp.k + case jpsSawEndObject: + return "", 0, ErrEOD + case jpsInvalidState: + return "", 0, ejp.err + default: + return "", 0, invalidRequestError("key") + } + + // check for colon + ejp.advanceState() + if err := ensureColon(ejp.s, key); err != nil { + return "", 0, err + } + + // peek at the value to determine type + t, err := ejp.peekType() + if err != nil { + return "", 0, err + } + + return key, t, nil +} + +// readValue returns the value corresponding to the Type returned by peekType +func (ejp *extJSONParser) readValue(t bsontype.Type) (*extJSONValue, error) { + if ejp.s == jpsInvalidState { + return nil, ejp.err + } + + var v *extJSONValue + + switch t { + case bsontype.Null, bsontype.Boolean, bsontype.String: + if ejp.s != jpsSawValue { + return nil, invalidRequestError(t.String()) + } + v = ejp.v + case bsontype.Int32, bsontype.Int64, bsontype.Double: + // relaxed version allows these to be literal number values + if ejp.s == jpsSawValue { + v = ejp.v + break + } + fallthrough + case bsontype.Decimal128, bsontype.Symbol, bsontype.ObjectID, bsontype.MinKey, bsontype.MaxKey, bsontype.Undefined: + switch ejp.s { + case jpsSawKey: + // read colon + ejp.advanceState() + if err := ensureColon(ejp.s, ejp.k); err != nil { + return nil, err + } + + // read value + ejp.advanceState() + if ejp.s != jpsSawValue || !ejp.ensureExtValueType(t) { + return nil, invalidJSONErrorForType("value", t) + } + + v = ejp.v + + // read end object + ejp.advanceState() + if ejp.s != jpsSawEndObject { + return nil, invalidJSONErrorForType("} after value", t) + } + default: + return nil, invalidRequestError(t.String()) + } + case bsontype.Binary, bsontype.Regex, bsontype.Timestamp, bsontype.DBPointer: + if ejp.s != jpsSawKey { + return nil, invalidRequestError(t.String()) + } + // read colon + ejp.advanceState() + if err := ensureColon(ejp.s, ejp.k); err != nil { + return nil, err + } + + ejp.advanceState() + if t == bsontype.Binary && ejp.s == jpsSawValue { + // convert relaxed $uuid format + if ejp.relaxedUUID { + defer func() { ejp.relaxedUUID = false }() + uuid, err := ejp.v.parseSymbol() + if err != nil { + return nil, err + } + + // RFC 4122 defines the length of a UUID as 36 and the hyphens in a UUID as appearing + // in the 8th, 13th, 18th, and 23rd characters. + // + // See https://tools.ietf.org/html/rfc4122#section-3 + valid := len(uuid) == 36 && + string(uuid[8]) == "-" && + string(uuid[13]) == "-" && + string(uuid[18]) == "-" && + string(uuid[23]) == "-" + if !valid { + return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens") + } + + // remove hyphens + uuidNoHyphens := strings.Replace(uuid, "-", "", -1) + if len(uuidNoHyphens) != 32 { + return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens") + } + + // convert hex to bytes + bytes, err := hex.DecodeString(uuidNoHyphens) + if err != nil { + return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %v", err) + } + + ejp.advanceState() + if ejp.s != jpsSawEndObject { + return nil, invalidJSONErrorForType("$uuid and value and then }", bsontype.Binary) + } + + base64 := &extJSONValue{ + t: bsontype.String, + v: base64.StdEncoding.EncodeToString(bytes), + } + subType := &extJSONValue{ + t: bsontype.String, + v: "04", + } + + v = &extJSONValue{ + t: bsontype.EmbeddedDocument, + v: &extJSONObject{ + keys: []string{"base64", "subType"}, + values: []*extJSONValue{base64, subType}, + }, + } + + break + } + + // convert legacy $binary format + base64 := ejp.v + + ejp.advanceState() + if ejp.s != jpsSawComma { + return nil, invalidJSONErrorForType(",", bsontype.Binary) + } + + ejp.advanceState() + key, t, err := ejp.readKey() + if err != nil { + return nil, err + } + if key != "$type" { + return nil, invalidJSONErrorForType("$type", bsontype.Binary) + } + + subType, err := ejp.readValue(t) + if err != nil { + return nil, err + } + + ejp.advanceState() + if ejp.s != jpsSawEndObject { + return nil, invalidJSONErrorForType("2 key-value pairs and then }", bsontype.Binary) + } + + v = &extJSONValue{ + t: bsontype.EmbeddedDocument, + v: &extJSONObject{ + keys: []string{"base64", "subType"}, + values: []*extJSONValue{base64, subType}, + }, + } + break + } + + // read KV pairs + if ejp.s != jpsSawBeginObject { + return nil, invalidJSONErrorForType("{", t) + } + + keys, vals, err := ejp.readObject(2, true) + if err != nil { + return nil, err + } + + ejp.advanceState() + if ejp.s != jpsSawEndObject { + return nil, invalidJSONErrorForType("2 key-value pairs and then }", t) + } + + v = &extJSONValue{t: bsontype.EmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}} + + case bsontype.DateTime: + switch ejp.s { + case jpsSawValue: + v = ejp.v + case jpsSawKey: + // read colon + ejp.advanceState() + if err := ensureColon(ejp.s, ejp.k); err != nil { + return nil, err + } + + ejp.advanceState() + switch ejp.s { + case jpsSawBeginObject: + keys, vals, err := ejp.readObject(1, true) + if err != nil { + return nil, err + } + v = &extJSONValue{t: bsontype.EmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}} + case jpsSawValue: + if ejp.canonical { + return nil, invalidJSONError("{") + } + v = ejp.v + default: + if ejp.canonical { + return nil, invalidJSONErrorForType("object", t) + } + return nil, invalidJSONErrorForType("ISO-8601 Internet Date/Time Format as described in RFC-3339", t) + } + + ejp.advanceState() + if ejp.s != jpsSawEndObject { + return nil, invalidJSONErrorForType("value and then }", t) + } + default: + return nil, invalidRequestError(t.String()) + } + case bsontype.JavaScript: + switch ejp.s { + case jpsSawKey: + // read colon + ejp.advanceState() + if err := ensureColon(ejp.s, ejp.k); err != nil { + return nil, err + } + + // read value + ejp.advanceState() + if ejp.s != jpsSawValue { + return nil, invalidJSONErrorForType("value", t) + } + v = ejp.v + + // read end object or comma and just return + ejp.advanceState() + case jpsSawEndObject: + v = ejp.v + default: + return nil, invalidRequestError(t.String()) + } + case bsontype.CodeWithScope: + if ejp.s == jpsSawKey && ejp.k == "$scope" { + v = ejp.v // this is the $code string from earlier + + // read colon + ejp.advanceState() + if err := ensureColon(ejp.s, ejp.k); err != nil { + return nil, err + } + + // read { + ejp.advanceState() + if ejp.s != jpsSawBeginObject { + return nil, invalidJSONError("$scope to be embedded document") + } + } else { + return nil, invalidRequestError(t.String()) + } + case bsontype.EmbeddedDocument, bsontype.Array: + return nil, invalidRequestError(t.String()) + } + + return v, nil +} + +// readObject is a utility method for reading full objects of known (or expected) size +// it is useful for extended JSON types such as binary, datetime, regex, and timestamp +func (ejp *extJSONParser) readObject(numKeys int, started bool) ([]string, []*extJSONValue, error) { + keys := make([]string, numKeys) + vals := make([]*extJSONValue, numKeys) + + if !started { + ejp.advanceState() + if ejp.s != jpsSawBeginObject { + return nil, nil, invalidJSONError("{") + } + } + + for i := 0; i < numKeys; i++ { + key, t, err := ejp.readKey() + if err != nil { + return nil, nil, err + } + + switch ejp.s { + case jpsSawKey: + v, err := ejp.readValue(t) + if err != nil { + return nil, nil, err + } + + keys[i] = key + vals[i] = v + case jpsSawValue: + keys[i] = key + vals[i] = ejp.v + default: + return nil, nil, invalidJSONError("value") + } + } + + ejp.advanceState() + if ejp.s != jpsSawEndObject { + return nil, nil, invalidJSONError("}") + } + + return keys, vals, nil +} + +// advanceState reads the next JSON token from the scanner and transitions +// from the current state based on that token's type +func (ejp *extJSONParser) advanceState() { + if ejp.s == jpsDoneState || ejp.s == jpsInvalidState { + return + } + + jt, err := ejp.js.nextToken() + + if err != nil { + ejp.err = err + ejp.s = jpsInvalidState + return + } + + valid := ejp.validateToken(jt.t) + if !valid { + ejp.err = unexpectedTokenError(jt) + ejp.s = jpsInvalidState + return + } + + switch jt.t { + case jttBeginObject: + ejp.s = jpsSawBeginObject + ejp.pushMode(jpmObjectMode) + ejp.depth++ + + if ejp.depth > ejp.maxDepth { + ejp.err = nestingDepthError(jt.p, ejp.depth) + ejp.s = jpsInvalidState + } + case jttEndObject: + ejp.s = jpsSawEndObject + ejp.depth-- + + if ejp.popMode() != jpmObjectMode { + ejp.err = unexpectedTokenError(jt) + ejp.s = jpsInvalidState + } + case jttBeginArray: + ejp.s = jpsSawBeginArray + ejp.pushMode(jpmArrayMode) + case jttEndArray: + ejp.s = jpsSawEndArray + + if ejp.popMode() != jpmArrayMode { + ejp.err = unexpectedTokenError(jt) + ejp.s = jpsInvalidState + } + case jttColon: + ejp.s = jpsSawColon + case jttComma: + ejp.s = jpsSawComma + case jttEOF: + ejp.s = jpsDoneState + if len(ejp.m) != 0 { + ejp.err = unexpectedTokenError(jt) + ejp.s = jpsInvalidState + } + case jttString: + switch ejp.s { + case jpsSawComma: + if ejp.peekMode() == jpmArrayMode { + ejp.s = jpsSawValue + ejp.v = extendJSONToken(jt) + return + } + fallthrough + case jpsSawBeginObject: + ejp.s = jpsSawKey + ejp.k = jt.v.(string) + return + } + fallthrough + default: + ejp.s = jpsSawValue + ejp.v = extendJSONToken(jt) + } +} + +var jpsValidTransitionTokens = map[jsonParseState]map[jsonTokenType]bool{ + jpsStartState: { + jttBeginObject: true, + jttBeginArray: true, + jttInt32: true, + jttInt64: true, + jttDouble: true, + jttString: true, + jttBool: true, + jttNull: true, + jttEOF: true, + }, + jpsSawBeginObject: { + jttEndObject: true, + jttString: true, + }, + jpsSawEndObject: { + jttEndObject: true, + jttEndArray: true, + jttComma: true, + jttEOF: true, + }, + jpsSawBeginArray: { + jttBeginObject: true, + jttBeginArray: true, + jttEndArray: true, + jttInt32: true, + jttInt64: true, + jttDouble: true, + jttString: true, + jttBool: true, + jttNull: true, + }, + jpsSawEndArray: { + jttEndObject: true, + jttEndArray: true, + jttComma: true, + jttEOF: true, + }, + jpsSawColon: { + jttBeginObject: true, + jttBeginArray: true, + jttInt32: true, + jttInt64: true, + jttDouble: true, + jttString: true, + jttBool: true, + jttNull: true, + }, + jpsSawComma: { + jttBeginObject: true, + jttBeginArray: true, + jttInt32: true, + jttInt64: true, + jttDouble: true, + jttString: true, + jttBool: true, + jttNull: true, + }, + jpsSawKey: { + jttColon: true, + }, + jpsSawValue: { + jttEndObject: true, + jttEndArray: true, + jttComma: true, + jttEOF: true, + }, + jpsDoneState: {}, + jpsInvalidState: {}, +} + +func (ejp *extJSONParser) validateToken(jtt jsonTokenType) bool { + switch ejp.s { + case jpsSawEndObject: + // if we are at depth zero and the next token is a '{', + // we can consider it valid only if we are not in array mode. + if jtt == jttBeginObject && ejp.depth == 0 { + return ejp.peekMode() != jpmArrayMode + } + case jpsSawComma: + switch ejp.peekMode() { + // the only valid next token after a comma inside a document is a string (a key) + case jpmObjectMode: + return jtt == jttString + case jpmInvalidMode: + return false + } + } + + _, ok := jpsValidTransitionTokens[ejp.s][jtt] + return ok +} + +// ensureExtValueType returns true if the current value has the expected +// value type for single-key extended JSON types. For example, +// {"$numberInt": v} v must be TypeString +func (ejp *extJSONParser) ensureExtValueType(t bsontype.Type) bool { + switch t { + case bsontype.MinKey, bsontype.MaxKey: + return ejp.v.t == bsontype.Int32 + case bsontype.Undefined: + return ejp.v.t == bsontype.Boolean + case bsontype.Int32, bsontype.Int64, bsontype.Double, bsontype.Decimal128, bsontype.Symbol, bsontype.ObjectID: + return ejp.v.t == bsontype.String + default: + return false + } +} + +func (ejp *extJSONParser) pushMode(m jsonParseMode) { + ejp.m = append(ejp.m, m) +} + +func (ejp *extJSONParser) popMode() jsonParseMode { + l := len(ejp.m) + if l == 0 { + return jpmInvalidMode + } + + m := ejp.m[l-1] + ejp.m = ejp.m[:l-1] + + return m +} + +func (ejp *extJSONParser) peekMode() jsonParseMode { + l := len(ejp.m) + if l == 0 { + return jpmInvalidMode + } + + return ejp.m[l-1] +} + +func extendJSONToken(jt *jsonToken) *extJSONValue { + var t bsontype.Type + + switch jt.t { + case jttInt32: + t = bsontype.Int32 + case jttInt64: + t = bsontype.Int64 + case jttDouble: + t = bsontype.Double + case jttString: + t = bsontype.String + case jttBool: + t = bsontype.Boolean + case jttNull: + t = bsontype.Null + default: + return nil + } + + return &extJSONValue{t: t, v: jt.v} +} + +func ensureColon(s jsonParseState, key string) error { + if s != jpsSawColon { + return fmt.Errorf("invalid JSON input: missing colon after key \"%s\"", key) + } + + return nil +} + +func invalidRequestError(s string) error { + return fmt.Errorf("invalid request to read %s", s) +} + +func invalidJSONError(expected string) error { + return fmt.Errorf("invalid JSON input; expected %s", expected) +} + +func invalidJSONErrorForType(expected string, t bsontype.Type) error { + return fmt.Errorf("invalid JSON input; expected %s for %s", expected, t) +} + +func unexpectedTokenError(jt *jsonToken) error { + switch jt.t { + case jttInt32, jttInt64, jttDouble: + return fmt.Errorf("invalid JSON input; unexpected number (%v) at position %d", jt.v, jt.p) + case jttString: + return fmt.Errorf("invalid JSON input; unexpected string (\"%v\") at position %d", jt.v, jt.p) + case jttBool: + return fmt.Errorf("invalid JSON input; unexpected boolean literal (%v) at position %d", jt.v, jt.p) + case jttNull: + return fmt.Errorf("invalid JSON input; unexpected null literal at position %d", jt.p) + case jttEOF: + return fmt.Errorf("invalid JSON input; unexpected end of input at position %d", jt.p) + default: + return fmt.Errorf("invalid JSON input; unexpected %c at position %d", jt.v.(byte), jt.p) + } +} + +func nestingDepthError(p, depth int) error { + return fmt.Errorf("invalid JSON input; nesting too deep (%d levels) at position %d", depth, p) +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_reader.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_reader.go new file mode 100644 index 000000000..35832d73a --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_reader.go @@ -0,0 +1,644 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsonrw + +import ( + "fmt" + "io" + "sync" + + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +// ExtJSONValueReaderPool is a pool for ValueReaders that read ExtJSON. +type ExtJSONValueReaderPool struct { + pool sync.Pool +} + +// NewExtJSONValueReaderPool instantiates a new ExtJSONValueReaderPool. +func NewExtJSONValueReaderPool() *ExtJSONValueReaderPool { + return &ExtJSONValueReaderPool{ + pool: sync.Pool{ + New: func() interface{} { + return new(extJSONValueReader) + }, + }, + } +} + +// Get retrieves a ValueReader from the pool and uses src as the underlying ExtJSON. +func (bvrp *ExtJSONValueReaderPool) Get(r io.Reader, canonical bool) (ValueReader, error) { + vr := bvrp.pool.Get().(*extJSONValueReader) + return vr.reset(r, canonical) +} + +// Put inserts a ValueReader into the pool. If the ValueReader is not a ExtJSON ValueReader nothing +// is inserted into the pool and ok will be false. +func (bvrp *ExtJSONValueReaderPool) Put(vr ValueReader) (ok bool) { + bvr, ok := vr.(*extJSONValueReader) + if !ok { + return false + } + + bvr, _ = bvr.reset(nil, false) + bvrp.pool.Put(bvr) + return true +} + +type ejvrState struct { + mode mode + vType bsontype.Type + depth int +} + +// extJSONValueReader is for reading extended JSON. +type extJSONValueReader struct { + p *extJSONParser + + stack []ejvrState + frame int +} + +// NewExtJSONValueReader creates a new ValueReader from a given io.Reader +// It will interpret the JSON of r as canonical or relaxed according to the +// given canonical flag +func NewExtJSONValueReader(r io.Reader, canonical bool) (ValueReader, error) { + return newExtJSONValueReader(r, canonical) +} + +func newExtJSONValueReader(r io.Reader, canonical bool) (*extJSONValueReader, error) { + ejvr := new(extJSONValueReader) + return ejvr.reset(r, canonical) +} + +func (ejvr *extJSONValueReader) reset(r io.Reader, canonical bool) (*extJSONValueReader, error) { + p := newExtJSONParser(r, canonical) + typ, err := p.peekType() + + if err != nil { + return nil, ErrInvalidJSON + } + + var m mode + switch typ { + case bsontype.EmbeddedDocument: + m = mTopLevel + case bsontype.Array: + m = mArray + default: + m = mValue + } + + stack := make([]ejvrState, 1, 5) + stack[0] = ejvrState{ + mode: m, + vType: typ, + } + return &extJSONValueReader{ + p: p, + stack: stack, + }, nil +} + +func (ejvr *extJSONValueReader) advanceFrame() { + if ejvr.frame+1 >= len(ejvr.stack) { // We need to grow the stack + length := len(ejvr.stack) + if length+1 >= cap(ejvr.stack) { + // double it + buf := make([]ejvrState, 2*cap(ejvr.stack)+1) + copy(buf, ejvr.stack) + ejvr.stack = buf + } + ejvr.stack = ejvr.stack[:length+1] + } + ejvr.frame++ + + // Clean the stack + ejvr.stack[ejvr.frame].mode = 0 + ejvr.stack[ejvr.frame].vType = 0 + ejvr.stack[ejvr.frame].depth = 0 +} + +func (ejvr *extJSONValueReader) pushDocument() { + ejvr.advanceFrame() + + ejvr.stack[ejvr.frame].mode = mDocument + ejvr.stack[ejvr.frame].depth = ejvr.p.depth +} + +func (ejvr *extJSONValueReader) pushCodeWithScope() { + ejvr.advanceFrame() + + ejvr.stack[ejvr.frame].mode = mCodeWithScope +} + +func (ejvr *extJSONValueReader) pushArray() { + ejvr.advanceFrame() + + ejvr.stack[ejvr.frame].mode = mArray +} + +func (ejvr *extJSONValueReader) push(m mode, t bsontype.Type) { + ejvr.advanceFrame() + + ejvr.stack[ejvr.frame].mode = m + ejvr.stack[ejvr.frame].vType = t +} + +func (ejvr *extJSONValueReader) pop() { + switch ejvr.stack[ejvr.frame].mode { + case mElement, mValue: + ejvr.frame-- + case mDocument, mArray, mCodeWithScope: + ejvr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc... + } +} + +func (ejvr *extJSONValueReader) skipObject() { + // read entire object until depth returns to 0 (last ending } or ] seen) + depth := 1 + for depth > 0 { + ejvr.p.advanceState() + + // If object is empty, raise depth and continue. When emptyObject is true, the + // parser has already read both the opening and closing brackets of an empty + // object ("{}"), so the next valid token will be part of the parent document, + // not part of the nested document. + // + // If there is a comma, there are remaining fields, emptyObject must be set back + // to false, and comma must be skipped with advanceState(). + if ejvr.p.emptyObject { + if ejvr.p.s == jpsSawComma { + ejvr.p.emptyObject = false + ejvr.p.advanceState() + } + depth-- + continue + } + + switch ejvr.p.s { + case jpsSawBeginObject, jpsSawBeginArray: + depth++ + case jpsSawEndObject, jpsSawEndArray: + depth-- + } + } +} + +func (ejvr *extJSONValueReader) invalidTransitionErr(destination mode, name string, modes []mode) error { + te := TransitionError{ + name: name, + current: ejvr.stack[ejvr.frame].mode, + destination: destination, + modes: modes, + action: "read", + } + if ejvr.frame != 0 { + te.parent = ejvr.stack[ejvr.frame-1].mode + } + return te +} + +func (ejvr *extJSONValueReader) typeError(t bsontype.Type) error { + return fmt.Errorf("positioned on %s, but attempted to read %s", ejvr.stack[ejvr.frame].vType, t) +} + +func (ejvr *extJSONValueReader) ensureElementValue(t bsontype.Type, destination mode, callerName string, addModes ...mode) error { + switch ejvr.stack[ejvr.frame].mode { + case mElement, mValue: + if ejvr.stack[ejvr.frame].vType != t { + return ejvr.typeError(t) + } + default: + modes := []mode{mElement, mValue} + if addModes != nil { + modes = append(modes, addModes...) + } + return ejvr.invalidTransitionErr(destination, callerName, modes) + } + + return nil +} + +func (ejvr *extJSONValueReader) Type() bsontype.Type { + return ejvr.stack[ejvr.frame].vType +} + +func (ejvr *extJSONValueReader) Skip() error { + switch ejvr.stack[ejvr.frame].mode { + case mElement, mValue: + default: + return ejvr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue}) + } + + defer ejvr.pop() + + t := ejvr.stack[ejvr.frame].vType + switch t { + case bsontype.Array, bsontype.EmbeddedDocument, bsontype.CodeWithScope: + // read entire array, doc or CodeWithScope + ejvr.skipObject() + default: + _, err := ejvr.p.readValue(t) + if err != nil { + return err + } + } + + return nil +} + +func (ejvr *extJSONValueReader) ReadArray() (ArrayReader, error) { + switch ejvr.stack[ejvr.frame].mode { + case mTopLevel: // allow reading array from top level + case mArray: + return ejvr, nil + default: + if err := ejvr.ensureElementValue(bsontype.Array, mArray, "ReadArray", mTopLevel, mArray); err != nil { + return nil, err + } + } + + ejvr.pushArray() + + return ejvr, nil +} + +func (ejvr *extJSONValueReader) ReadBinary() (b []byte, btype byte, err error) { + if err := ejvr.ensureElementValue(bsontype.Binary, 0, "ReadBinary"); err != nil { + return nil, 0, err + } + + v, err := ejvr.p.readValue(bsontype.Binary) + if err != nil { + return nil, 0, err + } + + b, btype, err = v.parseBinary() + + ejvr.pop() + return b, btype, err +} + +func (ejvr *extJSONValueReader) ReadBoolean() (bool, error) { + if err := ejvr.ensureElementValue(bsontype.Boolean, 0, "ReadBoolean"); err != nil { + return false, err + } + + v, err := ejvr.p.readValue(bsontype.Boolean) + if err != nil { + return false, err + } + + if v.t != bsontype.Boolean { + return false, fmt.Errorf("expected type bool, but got type %s", v.t) + } + + ejvr.pop() + return v.v.(bool), nil +} + +func (ejvr *extJSONValueReader) ReadDocument() (DocumentReader, error) { + switch ejvr.stack[ejvr.frame].mode { + case mTopLevel: + return ejvr, nil + case mElement, mValue: + if ejvr.stack[ejvr.frame].vType != bsontype.EmbeddedDocument { + return nil, ejvr.typeError(bsontype.EmbeddedDocument) + } + + ejvr.pushDocument() + return ejvr, nil + default: + return nil, ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue}) + } +} + +func (ejvr *extJSONValueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) { + if err = ejvr.ensureElementValue(bsontype.CodeWithScope, 0, "ReadCodeWithScope"); err != nil { + return "", nil, err + } + + v, err := ejvr.p.readValue(bsontype.CodeWithScope) + if err != nil { + return "", nil, err + } + + code, err = v.parseJavascript() + + ejvr.pushCodeWithScope() + return code, ejvr, err +} + +func (ejvr *extJSONValueReader) ReadDBPointer() (ns string, oid primitive.ObjectID, err error) { + if err = ejvr.ensureElementValue(bsontype.DBPointer, 0, "ReadDBPointer"); err != nil { + return "", primitive.NilObjectID, err + } + + v, err := ejvr.p.readValue(bsontype.DBPointer) + if err != nil { + return "", primitive.NilObjectID, err + } + + ns, oid, err = v.parseDBPointer() + + ejvr.pop() + return ns, oid, err +} + +func (ejvr *extJSONValueReader) ReadDateTime() (int64, error) { + if err := ejvr.ensureElementValue(bsontype.DateTime, 0, "ReadDateTime"); err != nil { + return 0, err + } + + v, err := ejvr.p.readValue(bsontype.DateTime) + if err != nil { + return 0, err + } + + d, err := v.parseDateTime() + + ejvr.pop() + return d, err +} + +func (ejvr *extJSONValueReader) ReadDecimal128() (primitive.Decimal128, error) { + if err := ejvr.ensureElementValue(bsontype.Decimal128, 0, "ReadDecimal128"); err != nil { + return primitive.Decimal128{}, err + } + + v, err := ejvr.p.readValue(bsontype.Decimal128) + if err != nil { + return primitive.Decimal128{}, err + } + + d, err := v.parseDecimal128() + + ejvr.pop() + return d, err +} + +func (ejvr *extJSONValueReader) ReadDouble() (float64, error) { + if err := ejvr.ensureElementValue(bsontype.Double, 0, "ReadDouble"); err != nil { + return 0, err + } + + v, err := ejvr.p.readValue(bsontype.Double) + if err != nil { + return 0, err + } + + d, err := v.parseDouble() + + ejvr.pop() + return d, err +} + +func (ejvr *extJSONValueReader) ReadInt32() (int32, error) { + if err := ejvr.ensureElementValue(bsontype.Int32, 0, "ReadInt32"); err != nil { + return 0, err + } + + v, err := ejvr.p.readValue(bsontype.Int32) + if err != nil { + return 0, err + } + + i, err := v.parseInt32() + + ejvr.pop() + return i, err +} + +func (ejvr *extJSONValueReader) ReadInt64() (int64, error) { + if err := ejvr.ensureElementValue(bsontype.Int64, 0, "ReadInt64"); err != nil { + return 0, err + } + + v, err := ejvr.p.readValue(bsontype.Int64) + if err != nil { + return 0, err + } + + i, err := v.parseInt64() + + ejvr.pop() + return i, err +} + +func (ejvr *extJSONValueReader) ReadJavascript() (code string, err error) { + if err = ejvr.ensureElementValue(bsontype.JavaScript, 0, "ReadJavascript"); err != nil { + return "", err + } + + v, err := ejvr.p.readValue(bsontype.JavaScript) + if err != nil { + return "", err + } + + code, err = v.parseJavascript() + + ejvr.pop() + return code, err +} + +func (ejvr *extJSONValueReader) ReadMaxKey() error { + if err := ejvr.ensureElementValue(bsontype.MaxKey, 0, "ReadMaxKey"); err != nil { + return err + } + + v, err := ejvr.p.readValue(bsontype.MaxKey) + if err != nil { + return err + } + + err = v.parseMinMaxKey("max") + + ejvr.pop() + return err +} + +func (ejvr *extJSONValueReader) ReadMinKey() error { + if err := ejvr.ensureElementValue(bsontype.MinKey, 0, "ReadMinKey"); err != nil { + return err + } + + v, err := ejvr.p.readValue(bsontype.MinKey) + if err != nil { + return err + } + + err = v.parseMinMaxKey("min") + + ejvr.pop() + return err +} + +func (ejvr *extJSONValueReader) ReadNull() error { + if err := ejvr.ensureElementValue(bsontype.Null, 0, "ReadNull"); err != nil { + return err + } + + v, err := ejvr.p.readValue(bsontype.Null) + if err != nil { + return err + } + + if v.t != bsontype.Null { + return fmt.Errorf("expected type null but got type %s", v.t) + } + + ejvr.pop() + return nil +} + +func (ejvr *extJSONValueReader) ReadObjectID() (primitive.ObjectID, error) { + if err := ejvr.ensureElementValue(bsontype.ObjectID, 0, "ReadObjectID"); err != nil { + return primitive.ObjectID{}, err + } + + v, err := ejvr.p.readValue(bsontype.ObjectID) + if err != nil { + return primitive.ObjectID{}, err + } + + oid, err := v.parseObjectID() + + ejvr.pop() + return oid, err +} + +func (ejvr *extJSONValueReader) ReadRegex() (pattern string, options string, err error) { + if err = ejvr.ensureElementValue(bsontype.Regex, 0, "ReadRegex"); err != nil { + return "", "", err + } + + v, err := ejvr.p.readValue(bsontype.Regex) + if err != nil { + return "", "", err + } + + pattern, options, err = v.parseRegex() + + ejvr.pop() + return pattern, options, err +} + +func (ejvr *extJSONValueReader) ReadString() (string, error) { + if err := ejvr.ensureElementValue(bsontype.String, 0, "ReadString"); err != nil { + return "", err + } + + v, err := ejvr.p.readValue(bsontype.String) + if err != nil { + return "", err + } + + if v.t != bsontype.String { + return "", fmt.Errorf("expected type string but got type %s", v.t) + } + + ejvr.pop() + return v.v.(string), nil +} + +func (ejvr *extJSONValueReader) ReadSymbol() (symbol string, err error) { + if err = ejvr.ensureElementValue(bsontype.Symbol, 0, "ReadSymbol"); err != nil { + return "", err + } + + v, err := ejvr.p.readValue(bsontype.Symbol) + if err != nil { + return "", err + } + + symbol, err = v.parseSymbol() + + ejvr.pop() + return symbol, err +} + +func (ejvr *extJSONValueReader) ReadTimestamp() (t uint32, i uint32, err error) { + if err = ejvr.ensureElementValue(bsontype.Timestamp, 0, "ReadTimestamp"); err != nil { + return 0, 0, err + } + + v, err := ejvr.p.readValue(bsontype.Timestamp) + if err != nil { + return 0, 0, err + } + + t, i, err = v.parseTimestamp() + + ejvr.pop() + return t, i, err +} + +func (ejvr *extJSONValueReader) ReadUndefined() error { + if err := ejvr.ensureElementValue(bsontype.Undefined, 0, "ReadUndefined"); err != nil { + return err + } + + v, err := ejvr.p.readValue(bsontype.Undefined) + if err != nil { + return err + } + + err = v.parseUndefined() + + ejvr.pop() + return err +} + +func (ejvr *extJSONValueReader) ReadElement() (string, ValueReader, error) { + switch ejvr.stack[ejvr.frame].mode { + case mTopLevel, mDocument, mCodeWithScope: + default: + return "", nil, ejvr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope}) + } + + name, t, err := ejvr.p.readKey() + + if err != nil { + if err == ErrEOD { + if ejvr.stack[ejvr.frame].mode == mCodeWithScope { + _, err := ejvr.p.peekType() + if err != nil { + return "", nil, err + } + } + + ejvr.pop() + } + + return "", nil, err + } + + ejvr.push(mElement, t) + return name, ejvr, nil +} + +func (ejvr *extJSONValueReader) ReadValue() (ValueReader, error) { + switch ejvr.stack[ejvr.frame].mode { + case mArray: + default: + return nil, ejvr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray}) + } + + t, err := ejvr.p.peekType() + if err != nil { + if err == ErrEOA { + ejvr.pop() + } + + return nil, err + } + + ejvr.push(mValue, t) + return ejvr, nil +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_tables.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_tables.go new file mode 100644 index 000000000..ba39c9601 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_tables.go @@ -0,0 +1,223 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +// +// Based on github.com/golang/go by The Go Authors +// See THIRD-PARTY-NOTICES for original license terms. + +package bsonrw + +import "unicode/utf8" + +// safeSet holds the value true if the ASCII character with the given array +// position can be represented inside a JSON string without any further +// escaping. +// +// All values are true except for the ASCII control characters (0-31), the +// double quote ("), and the backslash character ("\"). +var safeSet = [utf8.RuneSelf]bool{ + ' ': true, + '!': true, + '"': false, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '(': true, + ')': true, + '*': true, + '+': true, + ',': true, + '-': true, + '.': true, + '/': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + ':': true, + ';': true, + '<': true, + '=': true, + '>': true, + '?': true, + '@': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'V': true, + 'W': true, + 'X': true, + 'Y': true, + 'Z': true, + '[': true, + '\\': false, + ']': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '{': true, + '|': true, + '}': true, + '~': true, + '\u007f': true, +} + +// htmlSafeSet holds the value true if the ASCII character with the given +// array position can be safely represented inside a JSON string, embedded +// inside of HTML <script> tags, without any additional escaping. +// +// All values are true except for the ASCII control characters (0-31), the +// double quote ("), the backslash character ("\"), HTML opening and closing +// tags ("<" and ">"), and the ampersand ("&"). +var htmlSafeSet = [utf8.RuneSelf]bool{ + ' ': true, + '!': true, + '"': false, + '#': true, + '$': true, + '%': true, + '&': false, + '\'': true, + '(': true, + ')': true, + '*': true, + '+': true, + ',': true, + '-': true, + '.': true, + '/': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + ':': true, + ';': true, + '<': false, + '=': true, + '>': false, + '?': true, + '@': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'V': true, + 'W': true, + 'X': true, + 'Y': true, + 'Z': true, + '[': true, + '\\': false, + ']': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '{': true, + '|': true, + '}': true, + '~': true, + '\u007f': true, +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_wrappers.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_wrappers.go new file mode 100644 index 000000000..969570424 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_wrappers.go @@ -0,0 +1,492 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsonrw + +import ( + "encoding/base64" + "errors" + "fmt" + "math" + "strconv" + "time" + + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +func wrapperKeyBSONType(key string) bsontype.Type { + switch key { + case "$numberInt": + return bsontype.Int32 + case "$numberLong": + return bsontype.Int64 + case "$oid": + return bsontype.ObjectID + case "$symbol": + return bsontype.Symbol + case "$numberDouble": + return bsontype.Double + case "$numberDecimal": + return bsontype.Decimal128 + case "$binary": + return bsontype.Binary + case "$code": + return bsontype.JavaScript + case "$scope": + return bsontype.CodeWithScope + case "$timestamp": + return bsontype.Timestamp + case "$regularExpression": + return bsontype.Regex + case "$dbPointer": + return bsontype.DBPointer + case "$date": + return bsontype.DateTime + case "$minKey": + return bsontype.MinKey + case "$maxKey": + return bsontype.MaxKey + case "$undefined": + return bsontype.Undefined + } + + return bsontype.EmbeddedDocument +} + +func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) { + if ejv.t != bsontype.EmbeddedDocument { + return nil, 0, fmt.Errorf("$binary value should be object, but instead is %s", ejv.t) + } + + binObj := ejv.v.(*extJSONObject) + bFound := false + stFound := false + + for i, key := range binObj.keys { + val := binObj.values[i] + + switch key { + case "base64": + if bFound { + return nil, 0, errors.New("duplicate base64 key in $binary") + } + + if val.t != bsontype.String { + return nil, 0, fmt.Errorf("$binary base64 value should be string, but instead is %s", val.t) + } + + base64Bytes, err := base64.StdEncoding.DecodeString(val.v.(string)) + if err != nil { + return nil, 0, fmt.Errorf("invalid $binary base64 string: %s", val.v.(string)) + } + + b = base64Bytes + bFound = true + case "subType": + if stFound { + return nil, 0, errors.New("duplicate subType key in $binary") + } + + if val.t != bsontype.String { + return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t) + } + + i, err := strconv.ParseInt(val.v.(string), 16, 64) + if err != nil { + return nil, 0, fmt.Errorf("invalid $binary subType string: %s", val.v.(string)) + } + + subType = byte(i) + stFound = true + default: + return nil, 0, fmt.Errorf("invalid key in $binary object: %s", key) + } + } + + if !bFound { + return nil, 0, errors.New("missing base64 field in $binary object") + } + + if !stFound { + return nil, 0, errors.New("missing subType field in $binary object") + + } + + return b, subType, nil +} + +func (ejv *extJSONValue) parseDBPointer() (ns string, oid primitive.ObjectID, err error) { + if ejv.t != bsontype.EmbeddedDocument { + return "", primitive.NilObjectID, fmt.Errorf("$dbPointer value should be object, but instead is %s", ejv.t) + } + + dbpObj := ejv.v.(*extJSONObject) + oidFound := false + nsFound := false + + for i, key := range dbpObj.keys { + val := dbpObj.values[i] + + switch key { + case "$ref": + if nsFound { + return "", primitive.NilObjectID, errors.New("duplicate $ref key in $dbPointer") + } + + if val.t != bsontype.String { + return "", primitive.NilObjectID, fmt.Errorf("$dbPointer $ref value should be string, but instead is %s", val.t) + } + + ns = val.v.(string) + nsFound = true + case "$id": + if oidFound { + return "", primitive.NilObjectID, errors.New("duplicate $id key in $dbPointer") + } + + if val.t != bsontype.String { + return "", primitive.NilObjectID, fmt.Errorf("$dbPointer $id value should be string, but instead is %s", val.t) + } + + oid, err = primitive.ObjectIDFromHex(val.v.(string)) + if err != nil { + return "", primitive.NilObjectID, err + } + + oidFound = true + default: + return "", primitive.NilObjectID, fmt.Errorf("invalid key in $dbPointer object: %s", key) + } + } + + if !nsFound { + return "", oid, errors.New("missing $ref field in $dbPointer object") + } + + if !oidFound { + return "", oid, errors.New("missing $id field in $dbPointer object") + } + + return ns, oid, nil +} + +const ( + rfc3339Milli = "2006-01-02T15:04:05.999Z07:00" +) + +var ( + timeFormats = []string{rfc3339Milli, "2006-01-02T15:04:05.999Z0700"} +) + +func (ejv *extJSONValue) parseDateTime() (int64, error) { + switch ejv.t { + case bsontype.Int32: + return int64(ejv.v.(int32)), nil + case bsontype.Int64: + return ejv.v.(int64), nil + case bsontype.String: + return parseDatetimeString(ejv.v.(string)) + case bsontype.EmbeddedDocument: + return parseDatetimeObject(ejv.v.(*extJSONObject)) + default: + return 0, fmt.Errorf("$date value should be string or object, but instead is %s", ejv.t) + } +} + +func parseDatetimeString(data string) (int64, error) { + var t time.Time + var err error + // try acceptable time formats until one matches + for _, format := range timeFormats { + t, err = time.Parse(format, data) + if err == nil { + break + } + } + if err != nil { + return 0, fmt.Errorf("invalid $date value string: %s", data) + } + + return int64(primitive.NewDateTimeFromTime(t)), nil +} + +func parseDatetimeObject(data *extJSONObject) (d int64, err error) { + dFound := false + + for i, key := range data.keys { + val := data.values[i] + + switch key { + case "$numberLong": + if dFound { + return 0, errors.New("duplicate $numberLong key in $date") + } + + if val.t != bsontype.String { + return 0, fmt.Errorf("$date $numberLong field should be string, but instead is %s", val.t) + } + + d, err = val.parseInt64() + if err != nil { + return 0, err + } + dFound = true + default: + return 0, fmt.Errorf("invalid key in $date object: %s", key) + } + } + + if !dFound { + return 0, errors.New("missing $numberLong field in $date object") + } + + return d, nil +} + +func (ejv *extJSONValue) parseDecimal128() (primitive.Decimal128, error) { + if ejv.t != bsontype.String { + return primitive.Decimal128{}, fmt.Errorf("$numberDecimal value should be string, but instead is %s", ejv.t) + } + + d, err := primitive.ParseDecimal128(ejv.v.(string)) + if err != nil { + return primitive.Decimal128{}, fmt.Errorf("$invalid $numberDecimal string: %s", ejv.v.(string)) + } + + return d, nil +} + +func (ejv *extJSONValue) parseDouble() (float64, error) { + if ejv.t == bsontype.Double { + return ejv.v.(float64), nil + } + + if ejv.t != bsontype.String { + return 0, fmt.Errorf("$numberDouble value should be string, but instead is %s", ejv.t) + } + + switch ejv.v.(string) { + case "Infinity": + return math.Inf(1), nil + case "-Infinity": + return math.Inf(-1), nil + case "NaN": + return math.NaN(), nil + } + + f, err := strconv.ParseFloat(ejv.v.(string), 64) + if err != nil { + return 0, err + } + + return f, nil +} + +func (ejv *extJSONValue) parseInt32() (int32, error) { + if ejv.t == bsontype.Int32 { + return ejv.v.(int32), nil + } + + if ejv.t != bsontype.String { + return 0, fmt.Errorf("$numberInt value should be string, but instead is %s", ejv.t) + } + + i, err := strconv.ParseInt(ejv.v.(string), 10, 64) + if err != nil { + return 0, err + } + + if i < math.MinInt32 || i > math.MaxInt32 { + return 0, fmt.Errorf("$numberInt value should be int32 but instead is int64: %d", i) + } + + return int32(i), nil +} + +func (ejv *extJSONValue) parseInt64() (int64, error) { + if ejv.t == bsontype.Int64 { + return ejv.v.(int64), nil + } + + if ejv.t != bsontype.String { + return 0, fmt.Errorf("$numberLong value should be string, but instead is %s", ejv.t) + } + + i, err := strconv.ParseInt(ejv.v.(string), 10, 64) + if err != nil { + return 0, err + } + + return i, nil +} + +func (ejv *extJSONValue) parseJavascript() (code string, err error) { + if ejv.t != bsontype.String { + return "", fmt.Errorf("$code value should be string, but instead is %s", ejv.t) + } + + return ejv.v.(string), nil +} + +func (ejv *extJSONValue) parseMinMaxKey(minmax string) error { + if ejv.t != bsontype.Int32 { + return fmt.Errorf("$%sKey value should be int32, but instead is %s", minmax, ejv.t) + } + + if ejv.v.(int32) != 1 { + return fmt.Errorf("$%sKey value must be 1, but instead is %d", minmax, ejv.v.(int32)) + } + + return nil +} + +func (ejv *extJSONValue) parseObjectID() (primitive.ObjectID, error) { + if ejv.t != bsontype.String { + return primitive.NilObjectID, fmt.Errorf("$oid value should be string, but instead is %s", ejv.t) + } + + return primitive.ObjectIDFromHex(ejv.v.(string)) +} + +func (ejv *extJSONValue) parseRegex() (pattern, options string, err error) { + if ejv.t != bsontype.EmbeddedDocument { + return "", "", fmt.Errorf("$regularExpression value should be object, but instead is %s", ejv.t) + } + + regexObj := ejv.v.(*extJSONObject) + patFound := false + optFound := false + + for i, key := range regexObj.keys { + val := regexObj.values[i] + + switch key { + case "pattern": + if patFound { + return "", "", errors.New("duplicate pattern key in $regularExpression") + } + + if val.t != bsontype.String { + return "", "", fmt.Errorf("$regularExpression pattern value should be string, but instead is %s", val.t) + } + + pattern = val.v.(string) + patFound = true + case "options": + if optFound { + return "", "", errors.New("duplicate options key in $regularExpression") + } + + if val.t != bsontype.String { + return "", "", fmt.Errorf("$regularExpression options value should be string, but instead is %s", val.t) + } + + options = val.v.(string) + optFound = true + default: + return "", "", fmt.Errorf("invalid key in $regularExpression object: %s", key) + } + } + + if !patFound { + return "", "", errors.New("missing pattern field in $regularExpression object") + } + + if !optFound { + return "", "", errors.New("missing options field in $regularExpression object") + + } + + return pattern, options, nil +} + +func (ejv *extJSONValue) parseSymbol() (string, error) { + if ejv.t != bsontype.String { + return "", fmt.Errorf("$symbol value should be string, but instead is %s", ejv.t) + } + + return ejv.v.(string), nil +} + +func (ejv *extJSONValue) parseTimestamp() (t, i uint32, err error) { + if ejv.t != bsontype.EmbeddedDocument { + return 0, 0, fmt.Errorf("$timestamp value should be object, but instead is %s", ejv.t) + } + + handleKey := func(key string, val *extJSONValue, flag bool) (uint32, error) { + if flag { + return 0, fmt.Errorf("duplicate %s key in $timestamp", key) + } + + switch val.t { + case bsontype.Int32: + value := val.v.(int32) + + if value < 0 { + return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value) + } + + return uint32(value), nil + case bsontype.Int64: + value := val.v.(int64) + if value < 0 || value > int64(math.MaxUint32) { + return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value) + } + + return uint32(value), nil + default: + return 0, fmt.Errorf("$timestamp %s value should be uint32, but instead is %s", key, val.t) + } + } + + tsObj := ejv.v.(*extJSONObject) + tFound := false + iFound := false + + for j, key := range tsObj.keys { + val := tsObj.values[j] + + switch key { + case "t": + if t, err = handleKey(key, val, tFound); err != nil { + return 0, 0, err + } + + tFound = true + case "i": + if i, err = handleKey(key, val, iFound); err != nil { + return 0, 0, err + } + + iFound = true + default: + return 0, 0, fmt.Errorf("invalid key in $timestamp object: %s", key) + } + } + + if !tFound { + return 0, 0, errors.New("missing t field in $timestamp object") + } + + if !iFound { + return 0, 0, errors.New("missing i field in $timestamp object") + } + + return t, i, nil +} + +func (ejv *extJSONValue) parseUndefined() error { + if ejv.t != bsontype.Boolean { + return fmt.Errorf("undefined value should be boolean, but instead is %s", ejv.t) + } + + if !ejv.v.(bool) { + return fmt.Errorf("$undefined balue boolean should be true, but instead is %v", ejv.v.(bool)) + } + + return nil +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_writer.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_writer.go new file mode 100644 index 000000000..99ed524b7 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_writer.go @@ -0,0 +1,732 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsonrw + +import ( + "bytes" + "encoding/base64" + "fmt" + "io" + "math" + "sort" + "strconv" + "strings" + "sync" + "time" + "unicode/utf8" + + "go.mongodb.org/mongo-driver/bson/primitive" +) + +// ExtJSONValueWriterPool is a pool for ExtJSON ValueWriters. +type ExtJSONValueWriterPool struct { + pool sync.Pool +} + +// NewExtJSONValueWriterPool creates a new pool for ValueWriter instances that write to ExtJSON. +func NewExtJSONValueWriterPool() *ExtJSONValueWriterPool { + return &ExtJSONValueWriterPool{ + pool: sync.Pool{ + New: func() interface{} { + return new(extJSONValueWriter) + }, + }, + } +} + +// Get retrieves a ExtJSON ValueWriter from the pool and resets it to use w as the destination. +func (bvwp *ExtJSONValueWriterPool) Get(w io.Writer, canonical, escapeHTML bool) ValueWriter { + vw := bvwp.pool.Get().(*extJSONValueWriter) + if writer, ok := w.(*SliceWriter); ok { + vw.reset(*writer, canonical, escapeHTML) + vw.w = writer + return vw + } + vw.buf = vw.buf[:0] + vw.w = w + return vw +} + +// Put inserts a ValueWriter into the pool. If the ValueWriter is not a ExtJSON ValueWriter, nothing +// happens and ok will be false. +func (bvwp *ExtJSONValueWriterPool) Put(vw ValueWriter) (ok bool) { + bvw, ok := vw.(*extJSONValueWriter) + if !ok { + return false + } + + if _, ok := bvw.w.(*SliceWriter); ok { + bvw.buf = nil + } + bvw.w = nil + + bvwp.pool.Put(bvw) + return true +} + +type ejvwState struct { + mode mode +} + +type extJSONValueWriter struct { + w io.Writer + buf []byte + + stack []ejvwState + frame int64 + canonical bool + escapeHTML bool +} + +// NewExtJSONValueWriter creates a ValueWriter that writes Extended JSON to w. +func NewExtJSONValueWriter(w io.Writer, canonical, escapeHTML bool) (ValueWriter, error) { + if w == nil { + return nil, errNilWriter + } + + return newExtJSONWriter(w, canonical, escapeHTML), nil +} + +func newExtJSONWriter(w io.Writer, canonical, escapeHTML bool) *extJSONValueWriter { + stack := make([]ejvwState, 1, 5) + stack[0] = ejvwState{mode: mTopLevel} + + return &extJSONValueWriter{ + w: w, + buf: []byte{}, + stack: stack, + canonical: canonical, + escapeHTML: escapeHTML, + } +} + +func newExtJSONWriterFromSlice(buf []byte, canonical, escapeHTML bool) *extJSONValueWriter { + stack := make([]ejvwState, 1, 5) + stack[0] = ejvwState{mode: mTopLevel} + + return &extJSONValueWriter{ + buf: buf, + stack: stack, + canonical: canonical, + escapeHTML: escapeHTML, + } +} + +func (ejvw *extJSONValueWriter) reset(buf []byte, canonical, escapeHTML bool) { + if ejvw.stack == nil { + ejvw.stack = make([]ejvwState, 1, 5) + } + + ejvw.stack = ejvw.stack[:1] + ejvw.stack[0] = ejvwState{mode: mTopLevel} + ejvw.canonical = canonical + ejvw.escapeHTML = escapeHTML + ejvw.frame = 0 + ejvw.buf = buf + ejvw.w = nil +} + +func (ejvw *extJSONValueWriter) advanceFrame() { + if ejvw.frame+1 >= int64(len(ejvw.stack)) { // We need to grow the stack + length := len(ejvw.stack) + if length+1 >= cap(ejvw.stack) { + // double it + buf := make([]ejvwState, 2*cap(ejvw.stack)+1) + copy(buf, ejvw.stack) + ejvw.stack = buf + } + ejvw.stack = ejvw.stack[:length+1] + } + ejvw.frame++ +} + +func (ejvw *extJSONValueWriter) push(m mode) { + ejvw.advanceFrame() + + ejvw.stack[ejvw.frame].mode = m +} + +func (ejvw *extJSONValueWriter) pop() { + switch ejvw.stack[ejvw.frame].mode { + case mElement, mValue: + ejvw.frame-- + case mDocument, mArray, mCodeWithScope: + ejvw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc... + } +} + +func (ejvw *extJSONValueWriter) invalidTransitionErr(destination mode, name string, modes []mode) error { + te := TransitionError{ + name: name, + current: ejvw.stack[ejvw.frame].mode, + destination: destination, + modes: modes, + action: "write", + } + if ejvw.frame != 0 { + te.parent = ejvw.stack[ejvw.frame-1].mode + } + return te +} + +func (ejvw *extJSONValueWriter) ensureElementValue(destination mode, callerName string, addmodes ...mode) error { + switch ejvw.stack[ejvw.frame].mode { + case mElement, mValue: + default: + modes := []mode{mElement, mValue} + if addmodes != nil { + modes = append(modes, addmodes...) + } + return ejvw.invalidTransitionErr(destination, callerName, modes) + } + + return nil +} + +func (ejvw *extJSONValueWriter) writeExtendedSingleValue(key string, value string, quotes bool) { + var s string + if quotes { + s = fmt.Sprintf(`{"$%s":"%s"}`, key, value) + } else { + s = fmt.Sprintf(`{"$%s":%s}`, key, value) + } + + ejvw.buf = append(ejvw.buf, []byte(s)...) +} + +func (ejvw *extJSONValueWriter) WriteArray() (ArrayWriter, error) { + if err := ejvw.ensureElementValue(mArray, "WriteArray"); err != nil { + return nil, err + } + + ejvw.buf = append(ejvw.buf, '[') + + ejvw.push(mArray) + return ejvw, nil +} + +func (ejvw *extJSONValueWriter) WriteBinary(b []byte) error { + return ejvw.WriteBinaryWithSubtype(b, 0x00) +} + +func (ejvw *extJSONValueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error { + if err := ejvw.ensureElementValue(mode(0), "WriteBinaryWithSubtype"); err != nil { + return err + } + + var buf bytes.Buffer + buf.WriteString(`{"$binary":{"base64":"`) + buf.WriteString(base64.StdEncoding.EncodeToString(b)) + buf.WriteString(fmt.Sprintf(`","subType":"%02x"}},`, btype)) + + ejvw.buf = append(ejvw.buf, buf.Bytes()...) + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteBoolean(b bool) error { + if err := ejvw.ensureElementValue(mode(0), "WriteBoolean"); err != nil { + return err + } + + ejvw.buf = append(ejvw.buf, []byte(strconv.FormatBool(b))...) + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) { + if err := ejvw.ensureElementValue(mCodeWithScope, "WriteCodeWithScope"); err != nil { + return nil, err + } + + var buf bytes.Buffer + buf.WriteString(`{"$code":`) + writeStringWithEscapes(code, &buf, ejvw.escapeHTML) + buf.WriteString(`,"$scope":{`) + + ejvw.buf = append(ejvw.buf, buf.Bytes()...) + + ejvw.push(mCodeWithScope) + return ejvw, nil +} + +func (ejvw *extJSONValueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error { + if err := ejvw.ensureElementValue(mode(0), "WriteDBPointer"); err != nil { + return err + } + + var buf bytes.Buffer + buf.WriteString(`{"$dbPointer":{"$ref":"`) + buf.WriteString(ns) + buf.WriteString(`","$id":{"$oid":"`) + buf.WriteString(oid.Hex()) + buf.WriteString(`"}}},`) + + ejvw.buf = append(ejvw.buf, buf.Bytes()...) + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteDateTime(dt int64) error { + if err := ejvw.ensureElementValue(mode(0), "WriteDateTime"); err != nil { + return err + } + + t := time.Unix(dt/1e3, dt%1e3*1e6).UTC() + + if ejvw.canonical || t.Year() < 1970 || t.Year() > 9999 { + s := fmt.Sprintf(`{"$numberLong":"%d"}`, dt) + ejvw.writeExtendedSingleValue("date", s, false) + } else { + ejvw.writeExtendedSingleValue("date", t.Format(rfc3339Milli), true) + } + + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteDecimal128(d primitive.Decimal128) error { + if err := ejvw.ensureElementValue(mode(0), "WriteDecimal128"); err != nil { + return err + } + + ejvw.writeExtendedSingleValue("numberDecimal", d.String(), true) + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteDocument() (DocumentWriter, error) { + if ejvw.stack[ejvw.frame].mode == mTopLevel { + ejvw.buf = append(ejvw.buf, '{') + return ejvw, nil + } + + if err := ejvw.ensureElementValue(mDocument, "WriteDocument", mTopLevel); err != nil { + return nil, err + } + + ejvw.buf = append(ejvw.buf, '{') + ejvw.push(mDocument) + return ejvw, nil +} + +func (ejvw *extJSONValueWriter) WriteDouble(f float64) error { + if err := ejvw.ensureElementValue(mode(0), "WriteDouble"); err != nil { + return err + } + + s := formatDouble(f) + + if ejvw.canonical { + ejvw.writeExtendedSingleValue("numberDouble", s, true) + } else { + switch s { + case "Infinity": + fallthrough + case "-Infinity": + fallthrough + case "NaN": + s = fmt.Sprintf(`{"$numberDouble":"%s"}`, s) + } + ejvw.buf = append(ejvw.buf, []byte(s)...) + } + + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteInt32(i int32) error { + if err := ejvw.ensureElementValue(mode(0), "WriteInt32"); err != nil { + return err + } + + s := strconv.FormatInt(int64(i), 10) + + if ejvw.canonical { + ejvw.writeExtendedSingleValue("numberInt", s, true) + } else { + ejvw.buf = append(ejvw.buf, []byte(s)...) + } + + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteInt64(i int64) error { + if err := ejvw.ensureElementValue(mode(0), "WriteInt64"); err != nil { + return err + } + + s := strconv.FormatInt(i, 10) + + if ejvw.canonical { + ejvw.writeExtendedSingleValue("numberLong", s, true) + } else { + ejvw.buf = append(ejvw.buf, []byte(s)...) + } + + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteJavascript(code string) error { + if err := ejvw.ensureElementValue(mode(0), "WriteJavascript"); err != nil { + return err + } + + var buf bytes.Buffer + writeStringWithEscapes(code, &buf, ejvw.escapeHTML) + + ejvw.writeExtendedSingleValue("code", buf.String(), false) + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteMaxKey() error { + if err := ejvw.ensureElementValue(mode(0), "WriteMaxKey"); err != nil { + return err + } + + ejvw.writeExtendedSingleValue("maxKey", "1", false) + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteMinKey() error { + if err := ejvw.ensureElementValue(mode(0), "WriteMinKey"); err != nil { + return err + } + + ejvw.writeExtendedSingleValue("minKey", "1", false) + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteNull() error { + if err := ejvw.ensureElementValue(mode(0), "WriteNull"); err != nil { + return err + } + + ejvw.buf = append(ejvw.buf, []byte("null")...) + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteObjectID(oid primitive.ObjectID) error { + if err := ejvw.ensureElementValue(mode(0), "WriteObjectID"); err != nil { + return err + } + + ejvw.writeExtendedSingleValue("oid", oid.Hex(), true) + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteRegex(pattern string, options string) error { + if err := ejvw.ensureElementValue(mode(0), "WriteRegex"); err != nil { + return err + } + + var buf bytes.Buffer + buf.WriteString(`{"$regularExpression":{"pattern":`) + writeStringWithEscapes(pattern, &buf, ejvw.escapeHTML) + buf.WriteString(`,"options":"`) + buf.WriteString(sortStringAlphebeticAscending(options)) + buf.WriteString(`"}},`) + + ejvw.buf = append(ejvw.buf, buf.Bytes()...) + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteString(s string) error { + if err := ejvw.ensureElementValue(mode(0), "WriteString"); err != nil { + return err + } + + var buf bytes.Buffer + writeStringWithEscapes(s, &buf, ejvw.escapeHTML) + + ejvw.buf = append(ejvw.buf, buf.Bytes()...) + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteSymbol(symbol string) error { + if err := ejvw.ensureElementValue(mode(0), "WriteSymbol"); err != nil { + return err + } + + var buf bytes.Buffer + writeStringWithEscapes(symbol, &buf, ejvw.escapeHTML) + + ejvw.writeExtendedSingleValue("symbol", buf.String(), false) + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteTimestamp(t uint32, i uint32) error { + if err := ejvw.ensureElementValue(mode(0), "WriteTimestamp"); err != nil { + return err + } + + var buf bytes.Buffer + buf.WriteString(`{"$timestamp":{"t":`) + buf.WriteString(strconv.FormatUint(uint64(t), 10)) + buf.WriteString(`,"i":`) + buf.WriteString(strconv.FormatUint(uint64(i), 10)) + buf.WriteString(`}},`) + + ejvw.buf = append(ejvw.buf, buf.Bytes()...) + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteUndefined() error { + if err := ejvw.ensureElementValue(mode(0), "WriteUndefined"); err != nil { + return err + } + + ejvw.writeExtendedSingleValue("undefined", "true", false) + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteDocumentElement(key string) (ValueWriter, error) { + switch ejvw.stack[ejvw.frame].mode { + case mDocument, mTopLevel, mCodeWithScope: + var buf bytes.Buffer + writeStringWithEscapes(key, &buf, ejvw.escapeHTML) + + ejvw.buf = append(ejvw.buf, []byte(fmt.Sprintf(`%s:`, buf.String()))...) + ejvw.push(mElement) + default: + return nil, ejvw.invalidTransitionErr(mElement, "WriteDocumentElement", []mode{mDocument, mTopLevel, mCodeWithScope}) + } + + return ejvw, nil +} + +func (ejvw *extJSONValueWriter) WriteDocumentEnd() error { + switch ejvw.stack[ejvw.frame].mode { + case mDocument, mTopLevel, mCodeWithScope: + default: + return fmt.Errorf("incorrect mode to end document: %s", ejvw.stack[ejvw.frame].mode) + } + + // close the document + if ejvw.buf[len(ejvw.buf)-1] == ',' { + ejvw.buf[len(ejvw.buf)-1] = '}' + } else { + ejvw.buf = append(ejvw.buf, '}') + } + + switch ejvw.stack[ejvw.frame].mode { + case mCodeWithScope: + ejvw.buf = append(ejvw.buf, '}') + fallthrough + case mDocument: + ejvw.buf = append(ejvw.buf, ',') + case mTopLevel: + if ejvw.w != nil { + if _, err := ejvw.w.Write(ejvw.buf); err != nil { + return err + } + ejvw.buf = ejvw.buf[:0] + } + } + + ejvw.pop() + return nil +} + +func (ejvw *extJSONValueWriter) WriteArrayElement() (ValueWriter, error) { + switch ejvw.stack[ejvw.frame].mode { + case mArray: + ejvw.push(mValue) + default: + return nil, ejvw.invalidTransitionErr(mValue, "WriteArrayElement", []mode{mArray}) + } + + return ejvw, nil +} + +func (ejvw *extJSONValueWriter) WriteArrayEnd() error { + switch ejvw.stack[ejvw.frame].mode { + case mArray: + // close the array + if ejvw.buf[len(ejvw.buf)-1] == ',' { + ejvw.buf[len(ejvw.buf)-1] = ']' + } else { + ejvw.buf = append(ejvw.buf, ']') + } + + ejvw.buf = append(ejvw.buf, ',') + + ejvw.pop() + default: + return fmt.Errorf("incorrect mode to end array: %s", ejvw.stack[ejvw.frame].mode) + } + + return nil +} + +func formatDouble(f float64) string { + var s string + if math.IsInf(f, 1) { + s = "Infinity" + } else if math.IsInf(f, -1) { + s = "-Infinity" + } else if math.IsNaN(f) { + s = "NaN" + } else { + // Print exactly one decimalType place for integers; otherwise, print as many are necessary to + // perfectly represent it. + s = strconv.FormatFloat(f, 'G', -1, 64) + if !strings.ContainsRune(s, 'E') && !strings.ContainsRune(s, '.') { + s += ".0" + } + } + + return s +} + +var hexChars = "0123456789abcdef" + +func writeStringWithEscapes(s string, buf *bytes.Buffer, escapeHTML bool) { + buf.WriteByte('"') + start := 0 + for i := 0; i < len(s); { + if b := s[i]; b < utf8.RuneSelf { + if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) { + i++ + continue + } + if start < i { + buf.WriteString(s[start:i]) + } + switch b { + case '\\', '"': + buf.WriteByte('\\') + buf.WriteByte(b) + case '\n': + buf.WriteByte('\\') + buf.WriteByte('n') + case '\r': + buf.WriteByte('\\') + buf.WriteByte('r') + case '\t': + buf.WriteByte('\\') + buf.WriteByte('t') + case '\b': + buf.WriteByte('\\') + buf.WriteByte('b') + case '\f': + buf.WriteByte('\\') + buf.WriteByte('f') + default: + // This encodes bytes < 0x20 except for \t, \n and \r. + // If escapeHTML is set, it also escapes <, >, and & + // because they can lead to security holes when + // user-controlled strings are rendered into JSON + // and served to some browsers. + buf.WriteString(`\u00`) + buf.WriteByte(hexChars[b>>4]) + buf.WriteByte(hexChars[b&0xF]) + } + i++ + start = i + continue + } + c, size := utf8.DecodeRuneInString(s[i:]) + if c == utf8.RuneError && size == 1 { + if start < i { + buf.WriteString(s[start:i]) + } + buf.WriteString(`\ufffd`) + i += size + start = i + continue + } + // U+2028 is LINE SEPARATOR. + // U+2029 is PARAGRAPH SEPARATOR. + // They are both technically valid characters in JSON strings, + // but don't work in JSONP, which has to be evaluated as JavaScript, + // and can lead to security holes there. It is valid JSON to + // escape them, so we do so unconditionally. + // See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion. + if c == '\u2028' || c == '\u2029' { + if start < i { + buf.WriteString(s[start:i]) + } + buf.WriteString(`\u202`) + buf.WriteByte(hexChars[c&0xF]) + i += size + start = i + continue + } + i += size + } + if start < len(s) { + buf.WriteString(s[start:]) + } + buf.WriteByte('"') +} + +type sortableString []rune + +func (ss sortableString) Len() int { + return len(ss) +} + +func (ss sortableString) Less(i, j int) bool { + return ss[i] < ss[j] +} + +func (ss sortableString) Swap(i, j int) { + oldI := ss[i] + ss[i] = ss[j] + ss[j] = oldI +} + +func sortStringAlphebeticAscending(s string) string { + ss := sortableString([]rune(s)) + sort.Sort(ss) + return string([]rune(ss)) +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/json_scanner.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/json_scanner.go new file mode 100644 index 000000000..cd4843a3a --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/json_scanner.go @@ -0,0 +1,528 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsonrw + +import ( + "bytes" + "errors" + "fmt" + "io" + "math" + "strconv" + "unicode" + "unicode/utf16" +) + +type jsonTokenType byte + +const ( + jttBeginObject jsonTokenType = iota + jttEndObject + jttBeginArray + jttEndArray + jttColon + jttComma + jttInt32 + jttInt64 + jttDouble + jttString + jttBool + jttNull + jttEOF +) + +type jsonToken struct { + t jsonTokenType + v interface{} + p int +} + +type jsonScanner struct { + r io.Reader + buf []byte + pos int + lastReadErr error +} + +// nextToken returns the next JSON token if one exists. A token is a character +// of the JSON grammar, a number, a string, or a literal. +func (js *jsonScanner) nextToken() (*jsonToken, error) { + c, err := js.readNextByte() + + // keep reading until a non-space is encountered (break on read error or EOF) + for isWhiteSpace(c) && err == nil { + c, err = js.readNextByte() + } + + if err == io.EOF { + return &jsonToken{t: jttEOF}, nil + } else if err != nil { + return nil, err + } + + // switch on the character + switch c { + case '{': + return &jsonToken{t: jttBeginObject, v: byte('{'), p: js.pos - 1}, nil + case '}': + return &jsonToken{t: jttEndObject, v: byte('}'), p: js.pos - 1}, nil + case '[': + return &jsonToken{t: jttBeginArray, v: byte('['), p: js.pos - 1}, nil + case ']': + return &jsonToken{t: jttEndArray, v: byte(']'), p: js.pos - 1}, nil + case ':': + return &jsonToken{t: jttColon, v: byte(':'), p: js.pos - 1}, nil + case ',': + return &jsonToken{t: jttComma, v: byte(','), p: js.pos - 1}, nil + case '"': // RFC-8259 only allows for double quotes (") not single (') + return js.scanString() + default: + // check if it's a number + if c == '-' || isDigit(c) { + return js.scanNumber(c) + } else if c == 't' || c == 'f' || c == 'n' { + // maybe a literal + return js.scanLiteral(c) + } else { + return nil, fmt.Errorf("invalid JSON input. Position: %d. Character: %c", js.pos-1, c) + } + } +} + +// readNextByte attempts to read the next byte from the buffer. If the buffer +// has been exhausted, this function calls readIntoBuf, thus refilling the +// buffer and resetting the read position to 0 +func (js *jsonScanner) readNextByte() (byte, error) { + if js.pos >= len(js.buf) { + err := js.readIntoBuf() + + if err != nil { + return 0, err + } + } + + b := js.buf[js.pos] + js.pos++ + + return b, nil +} + +// readNNextBytes reads n bytes into dst, starting at offset +func (js *jsonScanner) readNNextBytes(dst []byte, n, offset int) error { + var err error + + for i := 0; i < n; i++ { + dst[i+offset], err = js.readNextByte() + if err != nil { + return err + } + } + + return nil +} + +// readIntoBuf reads up to 512 bytes from the scanner's io.Reader into the buffer +func (js *jsonScanner) readIntoBuf() error { + if js.lastReadErr != nil { + js.buf = js.buf[:0] + js.pos = 0 + return js.lastReadErr + } + + if cap(js.buf) == 0 { + js.buf = make([]byte, 0, 512) + } + + n, err := js.r.Read(js.buf[:cap(js.buf)]) + if err != nil { + js.lastReadErr = err + if n > 0 { + err = nil + } + } + js.buf = js.buf[:n] + js.pos = 0 + + return err +} + +func isWhiteSpace(c byte) bool { + return c == ' ' || c == '\t' || c == '\r' || c == '\n' +} + +func isDigit(c byte) bool { + return unicode.IsDigit(rune(c)) +} + +func isValueTerminator(c byte) bool { + return c == ',' || c == '}' || c == ']' || isWhiteSpace(c) +} + +// getu4 decodes the 4-byte hex sequence from the beginning of s, returning the hex value as a rune, +// or it returns -1. Note that the "\u" from the unicode escape sequence should not be present. +// It is copied and lightly modified from the Go JSON decode function at +// https://github.com/golang/go/blob/1b0a0316802b8048d69da49dc23c5a5ab08e8ae8/src/encoding/json/decode.go#L1169-L1188 +func getu4(s []byte) rune { + if len(s) < 4 { + return -1 + } + var r rune + for _, c := range s[:4] { + switch { + case '0' <= c && c <= '9': + c = c - '0' + case 'a' <= c && c <= 'f': + c = c - 'a' + 10 + case 'A' <= c && c <= 'F': + c = c - 'A' + 10 + default: + return -1 + } + r = r*16 + rune(c) + } + return r +} + +// scanString reads from an opening '"' to a closing '"' and handles escaped characters +func (js *jsonScanner) scanString() (*jsonToken, error) { + var b bytes.Buffer + var c byte + var err error + + p := js.pos - 1 + + for { + c, err = js.readNextByte() + if err != nil { + if err == io.EOF { + return nil, errors.New("end of input in JSON string") + } + return nil, err + } + + evalNextChar: + switch c { + case '\\': + c, err = js.readNextByte() + if err != nil { + if err == io.EOF { + return nil, errors.New("end of input in JSON string") + } + return nil, err + } + + evalNextEscapeChar: + switch c { + case '"', '\\', '/': + b.WriteByte(c) + case 'b': + b.WriteByte('\b') + case 'f': + b.WriteByte('\f') + case 'n': + b.WriteByte('\n') + case 'r': + b.WriteByte('\r') + case 't': + b.WriteByte('\t') + case 'u': + us := make([]byte, 4) + err = js.readNNextBytes(us, 4, 0) + if err != nil { + return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us) + } + + rn := getu4(us) + + // If the rune we just decoded is the high or low value of a possible surrogate pair, + // try to decode the next sequence as the low value of a surrogate pair. We're + // expecting the next sequence to be another Unicode escape sequence (e.g. "\uDD1E"), + // but need to handle cases where the input is not a valid surrogate pair. + // For more context on unicode surrogate pairs, see: + // https://www.christianfscott.com/rust-chars-vs-go-runes/ + // https://www.unicode.org/glossary/#high_surrogate_code_point + if utf16.IsSurrogate(rn) { + c, err = js.readNextByte() + if err != nil { + if err == io.EOF { + return nil, errors.New("end of input in JSON string") + } + return nil, err + } + + // If the next value isn't the beginning of a backslash escape sequence, write + // the Unicode replacement character for the surrogate value and goto the + // beginning of the next char eval block. + if c != '\\' { + b.WriteRune(unicode.ReplacementChar) + goto evalNextChar + } + + c, err = js.readNextByte() + if err != nil { + if err == io.EOF { + return nil, errors.New("end of input in JSON string") + } + return nil, err + } + + // If the next value isn't the beginning of a unicode escape sequence, write the + // Unicode replacement character for the surrogate value and goto the beginning + // of the next escape char eval block. + if c != 'u' { + b.WriteRune(unicode.ReplacementChar) + goto evalNextEscapeChar + } + + err = js.readNNextBytes(us, 4, 0) + if err != nil { + return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us) + } + + rn2 := getu4(us) + + // Try to decode the pair of runes as a utf16 surrogate pair. If that fails, write + // the Unicode replacement character for the surrogate value and the 2nd decoded rune. + if rnPair := utf16.DecodeRune(rn, rn2); rnPair != unicode.ReplacementChar { + b.WriteRune(rnPair) + } else { + b.WriteRune(unicode.ReplacementChar) + b.WriteRune(rn2) + } + + break + } + + b.WriteRune(rn) + default: + return nil, fmt.Errorf("invalid escape sequence in JSON string '\\%c'", c) + } + case '"': + return &jsonToken{t: jttString, v: b.String(), p: p}, nil + default: + b.WriteByte(c) + } + } +} + +// scanLiteral reads an unquoted sequence of characters and determines if it is one of +// three valid JSON literals (true, false, null); if so, it returns the appropriate +// jsonToken; otherwise, it returns an error +func (js *jsonScanner) scanLiteral(first byte) (*jsonToken, error) { + p := js.pos - 1 + + lit := make([]byte, 4) + lit[0] = first + + err := js.readNNextBytes(lit, 3, 1) + if err != nil { + return nil, err + } + + c5, err := js.readNextByte() + + if bytes.Equal([]byte("true"), lit) && (isValueTerminator(c5) || err == io.EOF) { + js.pos = int(math.Max(0, float64(js.pos-1))) + return &jsonToken{t: jttBool, v: true, p: p}, nil + } else if bytes.Equal([]byte("null"), lit) && (isValueTerminator(c5) || err == io.EOF) { + js.pos = int(math.Max(0, float64(js.pos-1))) + return &jsonToken{t: jttNull, v: nil, p: p}, nil + } else if bytes.Equal([]byte("fals"), lit) { + if c5 == 'e' { + c5, err = js.readNextByte() + + if isValueTerminator(c5) || err == io.EOF { + js.pos = int(math.Max(0, float64(js.pos-1))) + return &jsonToken{t: jttBool, v: false, p: p}, nil + } + } + } + + return nil, fmt.Errorf("invalid JSON literal. Position: %d, literal: %s", p, lit) +} + +type numberScanState byte + +const ( + nssSawLeadingMinus numberScanState = iota + nssSawLeadingZero + nssSawIntegerDigits + nssSawDecimalPoint + nssSawFractionDigits + nssSawExponentLetter + nssSawExponentSign + nssSawExponentDigits + nssDone + nssInvalid +) + +// scanNumber reads a JSON number (according to RFC-8259) +func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) { + var b bytes.Buffer + var s numberScanState + var c byte + var err error + + t := jttInt64 // assume it's an int64 until the type can be determined + start := js.pos - 1 + + b.WriteByte(first) + + switch first { + case '-': + s = nssSawLeadingMinus + case '0': + s = nssSawLeadingZero + default: + s = nssSawIntegerDigits + } + + for { + c, err = js.readNextByte() + + if err != nil && err != io.EOF { + return nil, err + } + + switch s { + case nssSawLeadingMinus: + switch c { + case '0': + s = nssSawLeadingZero + b.WriteByte(c) + default: + if isDigit(c) { + s = nssSawIntegerDigits + b.WriteByte(c) + } else { + s = nssInvalid + } + } + case nssSawLeadingZero: + switch c { + case '.': + s = nssSawDecimalPoint + b.WriteByte(c) + case 'e', 'E': + s = nssSawExponentLetter + b.WriteByte(c) + case '}', ']', ',': + s = nssDone + default: + if isWhiteSpace(c) || err == io.EOF { + s = nssDone + } else { + s = nssInvalid + } + } + case nssSawIntegerDigits: + switch c { + case '.': + s = nssSawDecimalPoint + b.WriteByte(c) + case 'e', 'E': + s = nssSawExponentLetter + b.WriteByte(c) + case '}', ']', ',': + s = nssDone + default: + if isWhiteSpace(c) || err == io.EOF { + s = nssDone + } else if isDigit(c) { + s = nssSawIntegerDigits + b.WriteByte(c) + } else { + s = nssInvalid + } + } + case nssSawDecimalPoint: + t = jttDouble + if isDigit(c) { + s = nssSawFractionDigits + b.WriteByte(c) + } else { + s = nssInvalid + } + case nssSawFractionDigits: + switch c { + case 'e', 'E': + s = nssSawExponentLetter + b.WriteByte(c) + case '}', ']', ',': + s = nssDone + default: + if isWhiteSpace(c) || err == io.EOF { + s = nssDone + } else if isDigit(c) { + s = nssSawFractionDigits + b.WriteByte(c) + } else { + s = nssInvalid + } + } + case nssSawExponentLetter: + t = jttDouble + switch c { + case '+', '-': + s = nssSawExponentSign + b.WriteByte(c) + default: + if isDigit(c) { + s = nssSawExponentDigits + b.WriteByte(c) + } else { + s = nssInvalid + } + } + case nssSawExponentSign: + if isDigit(c) { + s = nssSawExponentDigits + b.WriteByte(c) + } else { + s = nssInvalid + } + case nssSawExponentDigits: + switch c { + case '}', ']', ',': + s = nssDone + default: + if isWhiteSpace(c) || err == io.EOF { + s = nssDone + } else if isDigit(c) { + s = nssSawExponentDigits + b.WriteByte(c) + } else { + s = nssInvalid + } + } + } + + switch s { + case nssInvalid: + return nil, fmt.Errorf("invalid JSON number. Position: %d", start) + case nssDone: + js.pos = int(math.Max(0, float64(js.pos-1))) + if t != jttDouble { + v, err := strconv.ParseInt(b.String(), 10, 64) + if err == nil { + if v < math.MinInt32 || v > math.MaxInt32 { + return &jsonToken{t: jttInt64, v: v, p: start}, nil + } + + return &jsonToken{t: jttInt32, v: int32(v), p: start}, nil + } + } + + v, err := strconv.ParseFloat(b.String(), 64) + if err != nil { + return nil, err + } + + return &jsonToken{t: jttDouble, v: v, p: start}, nil + } + } +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/mode.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/mode.go new file mode 100644 index 000000000..617b5e221 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/mode.go @@ -0,0 +1,108 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsonrw + +import ( + "fmt" +) + +type mode int + +const ( + _ mode = iota + mTopLevel + mDocument + mArray + mValue + mElement + mCodeWithScope + mSpacer +) + +func (m mode) String() string { + var str string + + switch m { + case mTopLevel: + str = "TopLevel" + case mDocument: + str = "DocumentMode" + case mArray: + str = "ArrayMode" + case mValue: + str = "ValueMode" + case mElement: + str = "ElementMode" + case mCodeWithScope: + str = "CodeWithScopeMode" + case mSpacer: + str = "CodeWithScopeSpacerFrame" + default: + str = "UnknownMode" + } + + return str +} + +func (m mode) TypeString() string { + var str string + + switch m { + case mTopLevel: + str = "TopLevel" + case mDocument: + str = "Document" + case mArray: + str = "Array" + case mValue: + str = "Value" + case mElement: + str = "Element" + case mCodeWithScope: + str = "CodeWithScope" + case mSpacer: + str = "CodeWithScopeSpacer" + default: + str = "Unknown" + } + + return str +} + +// TransitionError is an error returned when an invalid progressing a +// ValueReader or ValueWriter state machine occurs. +// If read is false, the error is for writing +type TransitionError struct { + name string + parent mode + current mode + destination mode + modes []mode + action string +} + +func (te TransitionError) Error() string { + errString := fmt.Sprintf("%s can only %s", te.name, te.action) + if te.destination != mode(0) { + errString = fmt.Sprintf("%s a %s", errString, te.destination.TypeString()) + } + errString = fmt.Sprintf("%s while positioned on a", errString) + for ind, m := range te.modes { + if ind != 0 && len(te.modes) > 2 { + errString = fmt.Sprintf("%s,", errString) + } + if ind == len(te.modes)-1 && len(te.modes) > 1 { + errString = fmt.Sprintf("%s or", errString) + } + errString = fmt.Sprintf("%s %s", errString, m.TypeString()) + } + errString = fmt.Sprintf("%s but is positioned on a %s", errString, te.current.TypeString()) + if te.parent != mode(0) { + errString = fmt.Sprintf("%s with parent %s", errString, te.parent.TypeString()) + } + return errString +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/reader.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/reader.go new file mode 100644 index 000000000..0b8fa28d5 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/reader.go @@ -0,0 +1,63 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsonrw + +import ( + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +// ArrayReader is implemented by types that allow reading values from a BSON +// array. +type ArrayReader interface { + ReadValue() (ValueReader, error) +} + +// DocumentReader is implemented by types that allow reading elements from a +// BSON document. +type DocumentReader interface { + ReadElement() (string, ValueReader, error) +} + +// ValueReader is a generic interface used to read values from BSON. This type +// is implemented by several types with different underlying representations of +// BSON, such as a bson.Document, raw BSON bytes, or extended JSON. +type ValueReader interface { + Type() bsontype.Type + Skip() error + + ReadArray() (ArrayReader, error) + ReadBinary() (b []byte, btype byte, err error) + ReadBoolean() (bool, error) + ReadDocument() (DocumentReader, error) + ReadCodeWithScope() (code string, dr DocumentReader, err error) + ReadDBPointer() (ns string, oid primitive.ObjectID, err error) + ReadDateTime() (int64, error) + ReadDecimal128() (primitive.Decimal128, error) + ReadDouble() (float64, error) + ReadInt32() (int32, error) + ReadInt64() (int64, error) + ReadJavascript() (code string, err error) + ReadMaxKey() error + ReadMinKey() error + ReadNull() error + ReadObjectID() (primitive.ObjectID, error) + ReadRegex() (pattern, options string, err error) + ReadString() (string, error) + ReadSymbol() (symbol string, err error) + ReadTimestamp() (t, i uint32, err error) + ReadUndefined() error +} + +// BytesReader is a generic interface used to read BSON bytes from a +// ValueReader. This imterface is meant to be a superset of ValueReader, so that +// types that implement ValueReader may also implement this interface. +// +// The bytes of the value will be appended to dst. +type BytesReader interface { + ReadValueBytes(dst []byte) (bsontype.Type, []byte, error) +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/value_reader.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/value_reader.go new file mode 100644 index 000000000..ef5d837c2 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/value_reader.go @@ -0,0 +1,874 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsonrw + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "sync" + + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +var _ ValueReader = (*valueReader)(nil) + +var vrPool = sync.Pool{ + New: func() interface{} { + return new(valueReader) + }, +} + +// BSONValueReaderPool is a pool for ValueReaders that read BSON. +type BSONValueReaderPool struct { + pool sync.Pool +} + +// NewBSONValueReaderPool instantiates a new BSONValueReaderPool. +func NewBSONValueReaderPool() *BSONValueReaderPool { + return &BSONValueReaderPool{ + pool: sync.Pool{ + New: func() interface{} { + return new(valueReader) + }, + }, + } +} + +// Get retrieves a ValueReader from the pool and uses src as the underlying BSON. +func (bvrp *BSONValueReaderPool) Get(src []byte) ValueReader { + vr := bvrp.pool.Get().(*valueReader) + vr.reset(src) + return vr +} + +// Put inserts a ValueReader into the pool. If the ValueReader is not a BSON ValueReader nothing +// is inserted into the pool and ok will be false. +func (bvrp *BSONValueReaderPool) Put(vr ValueReader) (ok bool) { + bvr, ok := vr.(*valueReader) + if !ok { + return false + } + + bvr.reset(nil) + bvrp.pool.Put(bvr) + return true +} + +// ErrEOA is the error returned when the end of a BSON array has been reached. +var ErrEOA = errors.New("end of array") + +// ErrEOD is the error returned when the end of a BSON document has been reached. +var ErrEOD = errors.New("end of document") + +type vrState struct { + mode mode + vType bsontype.Type + end int64 +} + +// valueReader is for reading BSON values. +type valueReader struct { + offset int64 + d []byte + + stack []vrState + frame int64 +} + +// NewBSONDocumentReader returns a ValueReader using b for the underlying BSON +// representation. Parameter b must be a BSON Document. +func NewBSONDocumentReader(b []byte) ValueReader { + // TODO(skriptble): There's a lack of symmetry between the reader and writer, since the reader takes a []byte while the + // TODO writer takes an io.Writer. We should have two versions of each, one that takes a []byte and one that takes an + // TODO io.Reader or io.Writer. The []byte version will need to return a thing that can return the finished []byte since + // TODO it might be reallocated when appended to. + return newValueReader(b) +} + +// NewBSONValueReader returns a ValueReader that starts in the Value mode instead of in top +// level document mode. This enables the creation of a ValueReader for a single BSON value. +func NewBSONValueReader(t bsontype.Type, val []byte) ValueReader { + stack := make([]vrState, 1, 5) + stack[0] = vrState{ + mode: mValue, + vType: t, + } + return &valueReader{ + d: val, + stack: stack, + } +} + +func newValueReader(b []byte) *valueReader { + stack := make([]vrState, 1, 5) + stack[0] = vrState{ + mode: mTopLevel, + } + return &valueReader{ + d: b, + stack: stack, + } +} + +func (vr *valueReader) reset(b []byte) { + if vr.stack == nil { + vr.stack = make([]vrState, 1, 5) + } + vr.stack = vr.stack[:1] + vr.stack[0] = vrState{mode: mTopLevel} + vr.d = b + vr.offset = 0 + vr.frame = 0 +} + +func (vr *valueReader) advanceFrame() { + if vr.frame+1 >= int64(len(vr.stack)) { // We need to grow the stack + length := len(vr.stack) + if length+1 >= cap(vr.stack) { + // double it + buf := make([]vrState, 2*cap(vr.stack)+1) + copy(buf, vr.stack) + vr.stack = buf + } + vr.stack = vr.stack[:length+1] + } + vr.frame++ + + // Clean the stack + vr.stack[vr.frame].mode = 0 + vr.stack[vr.frame].vType = 0 + vr.stack[vr.frame].end = 0 +} + +func (vr *valueReader) pushDocument() error { + vr.advanceFrame() + + vr.stack[vr.frame].mode = mDocument + + size, err := vr.readLength() + if err != nil { + return err + } + vr.stack[vr.frame].end = int64(size) + vr.offset - 4 + + return nil +} + +func (vr *valueReader) pushArray() error { + vr.advanceFrame() + + vr.stack[vr.frame].mode = mArray + + size, err := vr.readLength() + if err != nil { + return err + } + vr.stack[vr.frame].end = int64(size) + vr.offset - 4 + + return nil +} + +func (vr *valueReader) pushElement(t bsontype.Type) { + vr.advanceFrame() + + vr.stack[vr.frame].mode = mElement + vr.stack[vr.frame].vType = t +} + +func (vr *valueReader) pushValue(t bsontype.Type) { + vr.advanceFrame() + + vr.stack[vr.frame].mode = mValue + vr.stack[vr.frame].vType = t +} + +func (vr *valueReader) pushCodeWithScope() (int64, error) { + vr.advanceFrame() + + vr.stack[vr.frame].mode = mCodeWithScope + + size, err := vr.readLength() + if err != nil { + return 0, err + } + vr.stack[vr.frame].end = int64(size) + vr.offset - 4 + + return int64(size), nil +} + +func (vr *valueReader) pop() { + switch vr.stack[vr.frame].mode { + case mElement, mValue: + vr.frame-- + case mDocument, mArray, mCodeWithScope: + vr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc... + } +} + +func (vr *valueReader) invalidTransitionErr(destination mode, name string, modes []mode) error { + te := TransitionError{ + name: name, + current: vr.stack[vr.frame].mode, + destination: destination, + modes: modes, + action: "read", + } + if vr.frame != 0 { + te.parent = vr.stack[vr.frame-1].mode + } + return te +} + +func (vr *valueReader) typeError(t bsontype.Type) error { + return fmt.Errorf("positioned on %s, but attempted to read %s", vr.stack[vr.frame].vType, t) +} + +func (vr *valueReader) invalidDocumentLengthError() error { + return fmt.Errorf("document is invalid, end byte is at %d, but null byte found at %d", vr.stack[vr.frame].end, vr.offset) +} + +func (vr *valueReader) ensureElementValue(t bsontype.Type, destination mode, callerName string) error { + switch vr.stack[vr.frame].mode { + case mElement, mValue: + if vr.stack[vr.frame].vType != t { + return vr.typeError(t) + } + default: + return vr.invalidTransitionErr(destination, callerName, []mode{mElement, mValue}) + } + + return nil +} + +func (vr *valueReader) Type() bsontype.Type { + return vr.stack[vr.frame].vType +} + +func (vr *valueReader) nextElementLength() (int32, error) { + var length int32 + var err error + switch vr.stack[vr.frame].vType { + case bsontype.Array, bsontype.EmbeddedDocument, bsontype.CodeWithScope: + length, err = vr.peekLength() + case bsontype.Binary: + length, err = vr.peekLength() + length += 4 + 1 // binary length + subtype byte + case bsontype.Boolean: + length = 1 + case bsontype.DBPointer: + length, err = vr.peekLength() + length += 4 + 12 // string length + ObjectID length + case bsontype.DateTime, bsontype.Double, bsontype.Int64, bsontype.Timestamp: + length = 8 + case bsontype.Decimal128: + length = 16 + case bsontype.Int32: + length = 4 + case bsontype.JavaScript, bsontype.String, bsontype.Symbol: + length, err = vr.peekLength() + length += 4 + case bsontype.MaxKey, bsontype.MinKey, bsontype.Null, bsontype.Undefined: + length = 0 + case bsontype.ObjectID: + length = 12 + case bsontype.Regex: + regex := bytes.IndexByte(vr.d[vr.offset:], 0x00) + if regex < 0 { + err = io.EOF + break + } + pattern := bytes.IndexByte(vr.d[vr.offset+int64(regex)+1:], 0x00) + if pattern < 0 { + err = io.EOF + break + } + length = int32(int64(regex) + 1 + int64(pattern) + 1) + default: + return 0, fmt.Errorf("attempted to read bytes of unknown BSON type %v", vr.stack[vr.frame].vType) + } + + return length, err +} + +func (vr *valueReader) ReadValueBytes(dst []byte) (bsontype.Type, []byte, error) { + switch vr.stack[vr.frame].mode { + case mTopLevel: + length, err := vr.peekLength() + if err != nil { + return bsontype.Type(0), nil, err + } + dst, err = vr.appendBytes(dst, length) + if err != nil { + return bsontype.Type(0), nil, err + } + return bsontype.Type(0), dst, nil + case mElement, mValue: + length, err := vr.nextElementLength() + if err != nil { + return bsontype.Type(0), dst, err + } + + dst, err = vr.appendBytes(dst, length) + t := vr.stack[vr.frame].vType + vr.pop() + return t, dst, err + default: + return bsontype.Type(0), nil, vr.invalidTransitionErr(0, "ReadValueBytes", []mode{mElement, mValue}) + } +} + +func (vr *valueReader) Skip() error { + switch vr.stack[vr.frame].mode { + case mElement, mValue: + default: + return vr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue}) + } + + length, err := vr.nextElementLength() + if err != nil { + return err + } + + err = vr.skipBytes(length) + vr.pop() + return err +} + +func (vr *valueReader) ReadArray() (ArrayReader, error) { + if err := vr.ensureElementValue(bsontype.Array, mArray, "ReadArray"); err != nil { + return nil, err + } + + err := vr.pushArray() + if err != nil { + return nil, err + } + + return vr, nil +} + +func (vr *valueReader) ReadBinary() (b []byte, btype byte, err error) { + if err := vr.ensureElementValue(bsontype.Binary, 0, "ReadBinary"); err != nil { + return nil, 0, err + } + + length, err := vr.readLength() + if err != nil { + return nil, 0, err + } + + btype, err = vr.readByte() + if err != nil { + return nil, 0, err + } + + // Check length in case it is an old binary without a length. + if btype == 0x02 && length > 4 { + length, err = vr.readLength() + if err != nil { + return nil, 0, err + } + } + + b, err = vr.readBytes(length) + if err != nil { + return nil, 0, err + } + // Make a copy of the returned byte slice because it's just a subslice from the valueReader's + // buffer and is not safe to return in the unmarshaled value. + cp := make([]byte, len(b)) + copy(cp, b) + + vr.pop() + return cp, btype, nil +} + +func (vr *valueReader) ReadBoolean() (bool, error) { + if err := vr.ensureElementValue(bsontype.Boolean, 0, "ReadBoolean"); err != nil { + return false, err + } + + b, err := vr.readByte() + if err != nil { + return false, err + } + + if b > 1 { + return false, fmt.Errorf("invalid byte for boolean, %b", b) + } + + vr.pop() + return b == 1, nil +} + +func (vr *valueReader) ReadDocument() (DocumentReader, error) { + switch vr.stack[vr.frame].mode { + case mTopLevel: + // read size + size, err := vr.readLength() + if err != nil { + return nil, err + } + if int(size) != len(vr.d) { + return nil, fmt.Errorf("invalid document length") + } + vr.stack[vr.frame].end = int64(size) + vr.offset - 4 + return vr, nil + case mElement, mValue: + if vr.stack[vr.frame].vType != bsontype.EmbeddedDocument { + return nil, vr.typeError(bsontype.EmbeddedDocument) + } + default: + return nil, vr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue}) + } + + err := vr.pushDocument() + if err != nil { + return nil, err + } + + return vr, nil +} + +func (vr *valueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) { + if err := vr.ensureElementValue(bsontype.CodeWithScope, 0, "ReadCodeWithScope"); err != nil { + return "", nil, err + } + + totalLength, err := vr.readLength() + if err != nil { + return "", nil, err + } + strLength, err := vr.readLength() + if err != nil { + return "", nil, err + } + if strLength <= 0 { + return "", nil, fmt.Errorf("invalid string length: %d", strLength) + } + strBytes, err := vr.readBytes(strLength) + if err != nil { + return "", nil, err + } + code = string(strBytes[:len(strBytes)-1]) + + size, err := vr.pushCodeWithScope() + if err != nil { + return "", nil, err + } + + // The total length should equal: + // 4 (total length) + strLength + 4 (the length of str itself) + (document length) + componentsLength := int64(4+strLength+4) + size + if int64(totalLength) != componentsLength { + return "", nil, fmt.Errorf( + "length of CodeWithScope does not match lengths of components; total: %d; components: %d", + totalLength, componentsLength, + ) + } + return code, vr, nil +} + +func (vr *valueReader) ReadDBPointer() (ns string, oid primitive.ObjectID, err error) { + if err := vr.ensureElementValue(bsontype.DBPointer, 0, "ReadDBPointer"); err != nil { + return "", oid, err + } + + ns, err = vr.readString() + if err != nil { + return "", oid, err + } + + oidbytes, err := vr.readBytes(12) + if err != nil { + return "", oid, err + } + + copy(oid[:], oidbytes) + + vr.pop() + return ns, oid, nil +} + +func (vr *valueReader) ReadDateTime() (int64, error) { + if err := vr.ensureElementValue(bsontype.DateTime, 0, "ReadDateTime"); err != nil { + return 0, err + } + + i, err := vr.readi64() + if err != nil { + return 0, err + } + + vr.pop() + return i, nil +} + +func (vr *valueReader) ReadDecimal128() (primitive.Decimal128, error) { + if err := vr.ensureElementValue(bsontype.Decimal128, 0, "ReadDecimal128"); err != nil { + return primitive.Decimal128{}, err + } + + b, err := vr.readBytes(16) + if err != nil { + return primitive.Decimal128{}, err + } + + l := binary.LittleEndian.Uint64(b[0:8]) + h := binary.LittleEndian.Uint64(b[8:16]) + + vr.pop() + return primitive.NewDecimal128(h, l), nil +} + +func (vr *valueReader) ReadDouble() (float64, error) { + if err := vr.ensureElementValue(bsontype.Double, 0, "ReadDouble"); err != nil { + return 0, err + } + + u, err := vr.readu64() + if err != nil { + return 0, err + } + + vr.pop() + return math.Float64frombits(u), nil +} + +func (vr *valueReader) ReadInt32() (int32, error) { + if err := vr.ensureElementValue(bsontype.Int32, 0, "ReadInt32"); err != nil { + return 0, err + } + + vr.pop() + return vr.readi32() +} + +func (vr *valueReader) ReadInt64() (int64, error) { + if err := vr.ensureElementValue(bsontype.Int64, 0, "ReadInt64"); err != nil { + return 0, err + } + + vr.pop() + return vr.readi64() +} + +func (vr *valueReader) ReadJavascript() (code string, err error) { + if err := vr.ensureElementValue(bsontype.JavaScript, 0, "ReadJavascript"); err != nil { + return "", err + } + + vr.pop() + return vr.readString() +} + +func (vr *valueReader) ReadMaxKey() error { + if err := vr.ensureElementValue(bsontype.MaxKey, 0, "ReadMaxKey"); err != nil { + return err + } + + vr.pop() + return nil +} + +func (vr *valueReader) ReadMinKey() error { + if err := vr.ensureElementValue(bsontype.MinKey, 0, "ReadMinKey"); err != nil { + return err + } + + vr.pop() + return nil +} + +func (vr *valueReader) ReadNull() error { + if err := vr.ensureElementValue(bsontype.Null, 0, "ReadNull"); err != nil { + return err + } + + vr.pop() + return nil +} + +func (vr *valueReader) ReadObjectID() (primitive.ObjectID, error) { + if err := vr.ensureElementValue(bsontype.ObjectID, 0, "ReadObjectID"); err != nil { + return primitive.ObjectID{}, err + } + + oidbytes, err := vr.readBytes(12) + if err != nil { + return primitive.ObjectID{}, err + } + + var oid primitive.ObjectID + copy(oid[:], oidbytes) + + vr.pop() + return oid, nil +} + +func (vr *valueReader) ReadRegex() (string, string, error) { + if err := vr.ensureElementValue(bsontype.Regex, 0, "ReadRegex"); err != nil { + return "", "", err + } + + pattern, err := vr.readCString() + if err != nil { + return "", "", err + } + + options, err := vr.readCString() + if err != nil { + return "", "", err + } + + vr.pop() + return pattern, options, nil +} + +func (vr *valueReader) ReadString() (string, error) { + if err := vr.ensureElementValue(bsontype.String, 0, "ReadString"); err != nil { + return "", err + } + + vr.pop() + return vr.readString() +} + +func (vr *valueReader) ReadSymbol() (symbol string, err error) { + if err := vr.ensureElementValue(bsontype.Symbol, 0, "ReadSymbol"); err != nil { + return "", err + } + + vr.pop() + return vr.readString() +} + +func (vr *valueReader) ReadTimestamp() (t uint32, i uint32, err error) { + if err := vr.ensureElementValue(bsontype.Timestamp, 0, "ReadTimestamp"); err != nil { + return 0, 0, err + } + + i, err = vr.readu32() + if err != nil { + return 0, 0, err + } + + t, err = vr.readu32() + if err != nil { + return 0, 0, err + } + + vr.pop() + return t, i, nil +} + +func (vr *valueReader) ReadUndefined() error { + if err := vr.ensureElementValue(bsontype.Undefined, 0, "ReadUndefined"); err != nil { + return err + } + + vr.pop() + return nil +} + +func (vr *valueReader) ReadElement() (string, ValueReader, error) { + switch vr.stack[vr.frame].mode { + case mTopLevel, mDocument, mCodeWithScope: + default: + return "", nil, vr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope}) + } + + t, err := vr.readByte() + if err != nil { + return "", nil, err + } + + if t == 0 { + if vr.offset != vr.stack[vr.frame].end { + return "", nil, vr.invalidDocumentLengthError() + } + + vr.pop() + return "", nil, ErrEOD + } + + name, err := vr.readCString() + if err != nil { + return "", nil, err + } + + vr.pushElement(bsontype.Type(t)) + return name, vr, nil +} + +func (vr *valueReader) ReadValue() (ValueReader, error) { + switch vr.stack[vr.frame].mode { + case mArray: + default: + return nil, vr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray}) + } + + t, err := vr.readByte() + if err != nil { + return nil, err + } + + if t == 0 { + if vr.offset != vr.stack[vr.frame].end { + return nil, vr.invalidDocumentLengthError() + } + + vr.pop() + return nil, ErrEOA + } + + _, err = vr.readCString() + if err != nil { + return nil, err + } + + vr.pushValue(bsontype.Type(t)) + return vr, nil +} + +// readBytes reads length bytes from the valueReader starting at the current offset. Note that the +// returned byte slice is a subslice from the valueReader buffer and must be converted or copied +// before returning in an unmarshaled value. +func (vr *valueReader) readBytes(length int32) ([]byte, error) { + if length < 0 { + return nil, fmt.Errorf("invalid length: %d", length) + } + + if vr.offset+int64(length) > int64(len(vr.d)) { + return nil, io.EOF + } + + start := vr.offset + vr.offset += int64(length) + + return vr.d[start : start+int64(length)], nil +} + +func (vr *valueReader) appendBytes(dst []byte, length int32) ([]byte, error) { + if vr.offset+int64(length) > int64(len(vr.d)) { + return nil, io.EOF + } + + start := vr.offset + vr.offset += int64(length) + return append(dst, vr.d[start:start+int64(length)]...), nil +} + +func (vr *valueReader) skipBytes(length int32) error { + if vr.offset+int64(length) > int64(len(vr.d)) { + return io.EOF + } + + vr.offset += int64(length) + return nil +} + +func (vr *valueReader) readByte() (byte, error) { + if vr.offset+1 > int64(len(vr.d)) { + return 0x0, io.EOF + } + + vr.offset++ + return vr.d[vr.offset-1], nil +} + +func (vr *valueReader) readCString() (string, error) { + idx := bytes.IndexByte(vr.d[vr.offset:], 0x00) + if idx < 0 { + return "", io.EOF + } + start := vr.offset + // idx does not include the null byte + vr.offset += int64(idx) + 1 + return string(vr.d[start : start+int64(idx)]), nil +} + +func (vr *valueReader) readString() (string, error) { + length, err := vr.readLength() + if err != nil { + return "", err + } + + if int64(length)+vr.offset > int64(len(vr.d)) { + return "", io.EOF + } + + if length <= 0 { + return "", fmt.Errorf("invalid string length: %d", length) + } + + if vr.d[vr.offset+int64(length)-1] != 0x00 { + return "", fmt.Errorf("string does not end with null byte, but with %v", vr.d[vr.offset+int64(length)-1]) + } + + start := vr.offset + vr.offset += int64(length) + return string(vr.d[start : start+int64(length)-1]), nil +} + +func (vr *valueReader) peekLength() (int32, error) { + if vr.offset+4 > int64(len(vr.d)) { + return 0, io.EOF + } + + idx := vr.offset + return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil +} + +func (vr *valueReader) readLength() (int32, error) { return vr.readi32() } + +func (vr *valueReader) readi32() (int32, error) { + if vr.offset+4 > int64(len(vr.d)) { + return 0, io.EOF + } + + idx := vr.offset + vr.offset += 4 + return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil +} + +func (vr *valueReader) readu32() (uint32, error) { + if vr.offset+4 > int64(len(vr.d)) { + return 0, io.EOF + } + + idx := vr.offset + vr.offset += 4 + return (uint32(vr.d[idx]) | uint32(vr.d[idx+1])<<8 | uint32(vr.d[idx+2])<<16 | uint32(vr.d[idx+3])<<24), nil +} + +func (vr *valueReader) readi64() (int64, error) { + if vr.offset+8 > int64(len(vr.d)) { + return 0, io.EOF + } + + idx := vr.offset + vr.offset += 8 + return int64(vr.d[idx]) | int64(vr.d[idx+1])<<8 | int64(vr.d[idx+2])<<16 | int64(vr.d[idx+3])<<24 | + int64(vr.d[idx+4])<<32 | int64(vr.d[idx+5])<<40 | int64(vr.d[idx+6])<<48 | int64(vr.d[idx+7])<<56, nil +} + +func (vr *valueReader) readu64() (uint64, error) { + if vr.offset+8 > int64(len(vr.d)) { + return 0, io.EOF + } + + idx := vr.offset + vr.offset += 8 + return uint64(vr.d[idx]) | uint64(vr.d[idx+1])<<8 | uint64(vr.d[idx+2])<<16 | uint64(vr.d[idx+3])<<24 | + uint64(vr.d[idx+4])<<32 | uint64(vr.d[idx+5])<<40 | uint64(vr.d[idx+6])<<48 | uint64(vr.d[idx+7])<<56, nil +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/value_writer.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/value_writer.go new file mode 100644 index 000000000..f95a08afd --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/value_writer.go @@ -0,0 +1,606 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsonrw + +import ( + "errors" + "fmt" + "io" + "math" + "strconv" + "strings" + "sync" + + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +var _ ValueWriter = (*valueWriter)(nil) + +var vwPool = sync.Pool{ + New: func() interface{} { + return new(valueWriter) + }, +} + +// BSONValueWriterPool is a pool for BSON ValueWriters. +type BSONValueWriterPool struct { + pool sync.Pool +} + +// NewBSONValueWriterPool creates a new pool for ValueWriter instances that write to BSON. +func NewBSONValueWriterPool() *BSONValueWriterPool { + return &BSONValueWriterPool{ + pool: sync.Pool{ + New: func() interface{} { + return new(valueWriter) + }, + }, + } +} + +// Get retrieves a BSON ValueWriter from the pool and resets it to use w as the destination. +func (bvwp *BSONValueWriterPool) Get(w io.Writer) ValueWriter { + vw := bvwp.pool.Get().(*valueWriter) + + // TODO: Having to call reset here with the same buffer doesn't really make sense. + vw.reset(vw.buf) + vw.buf = vw.buf[:0] + vw.w = w + return vw +} + +// GetAtModeElement retrieves a ValueWriterFlusher from the pool and resets it to use w as the destination. +func (bvwp *BSONValueWriterPool) GetAtModeElement(w io.Writer) ValueWriterFlusher { + vw := bvwp.Get(w).(*valueWriter) + vw.push(mElement) + return vw +} + +// Put inserts a ValueWriter into the pool. If the ValueWriter is not a BSON ValueWriter, nothing +// happens and ok will be false. +func (bvwp *BSONValueWriterPool) Put(vw ValueWriter) (ok bool) { + bvw, ok := vw.(*valueWriter) + if !ok { + return false + } + + bvwp.pool.Put(bvw) + return true +} + +// This is here so that during testing we can change it and not require +// allocating a 4GB slice. +var maxSize = math.MaxInt32 + +var errNilWriter = errors.New("cannot create a ValueWriter from a nil io.Writer") + +type errMaxDocumentSizeExceeded struct { + size int64 +} + +func (mdse errMaxDocumentSizeExceeded) Error() string { + return fmt.Sprintf("document size (%d) is larger than the max int32", mdse.size) +} + +type vwMode int + +const ( + _ vwMode = iota + vwTopLevel + vwDocument + vwArray + vwValue + vwElement + vwCodeWithScope +) + +func (vm vwMode) String() string { + var str string + + switch vm { + case vwTopLevel: + str = "TopLevel" + case vwDocument: + str = "DocumentMode" + case vwArray: + str = "ArrayMode" + case vwValue: + str = "ValueMode" + case vwElement: + str = "ElementMode" + case vwCodeWithScope: + str = "CodeWithScopeMode" + default: + str = "UnknownMode" + } + + return str +} + +type vwState struct { + mode mode + key string + arrkey int + start int32 +} + +type valueWriter struct { + w io.Writer + buf []byte + + stack []vwState + frame int64 +} + +func (vw *valueWriter) advanceFrame() { + if vw.frame+1 >= int64(len(vw.stack)) { // We need to grow the stack + length := len(vw.stack) + if length+1 >= cap(vw.stack) { + // double it + buf := make([]vwState, 2*cap(vw.stack)+1) + copy(buf, vw.stack) + vw.stack = buf + } + vw.stack = vw.stack[:length+1] + } + vw.frame++ +} + +func (vw *valueWriter) push(m mode) { + vw.advanceFrame() + + // Clean the stack + vw.stack[vw.frame].mode = m + vw.stack[vw.frame].key = "" + vw.stack[vw.frame].arrkey = 0 + vw.stack[vw.frame].start = 0 + + vw.stack[vw.frame].mode = m + switch m { + case mDocument, mArray, mCodeWithScope: + vw.reserveLength() + } +} + +func (vw *valueWriter) reserveLength() { + vw.stack[vw.frame].start = int32(len(vw.buf)) + vw.buf = append(vw.buf, 0x00, 0x00, 0x00, 0x00) +} + +func (vw *valueWriter) pop() { + switch vw.stack[vw.frame].mode { + case mElement, mValue: + vw.frame-- + case mDocument, mArray, mCodeWithScope: + vw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc... + } +} + +// NewBSONValueWriter creates a ValueWriter that writes BSON to w. +// +// This ValueWriter will only write entire documents to the io.Writer and it +// will buffer the document as it is built. +func NewBSONValueWriter(w io.Writer) (ValueWriter, error) { + if w == nil { + return nil, errNilWriter + } + return newValueWriter(w), nil +} + +func newValueWriter(w io.Writer) *valueWriter { + vw := new(valueWriter) + stack := make([]vwState, 1, 5) + stack[0] = vwState{mode: mTopLevel} + vw.w = w + vw.stack = stack + + return vw +} + +func newValueWriterFromSlice(buf []byte) *valueWriter { + vw := new(valueWriter) + stack := make([]vwState, 1, 5) + stack[0] = vwState{mode: mTopLevel} + vw.stack = stack + vw.buf = buf + + return vw +} + +func (vw *valueWriter) reset(buf []byte) { + if vw.stack == nil { + vw.stack = make([]vwState, 1, 5) + } + vw.stack = vw.stack[:1] + vw.stack[0] = vwState{mode: mTopLevel} + vw.buf = buf + vw.frame = 0 + vw.w = nil +} + +func (vw *valueWriter) invalidTransitionError(destination mode, name string, modes []mode) error { + te := TransitionError{ + name: name, + current: vw.stack[vw.frame].mode, + destination: destination, + modes: modes, + action: "write", + } + if vw.frame != 0 { + te.parent = vw.stack[vw.frame-1].mode + } + return te +} + +func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error { + switch vw.stack[vw.frame].mode { + case mElement: + key := vw.stack[vw.frame].key + if !isValidCString(key) { + return errors.New("BSON element key cannot contain null bytes") + } + + vw.buf = bsoncore.AppendHeader(vw.buf, t, key) + case mValue: + // TODO: Do this with a cache of the first 1000 or so array keys. + vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey)) + default: + modes := []mode{mElement, mValue} + if addmodes != nil { + modes = append(modes, addmodes...) + } + return vw.invalidTransitionError(destination, callerName, modes) + } + + return nil +} + +func (vw *valueWriter) WriteValueBytes(t bsontype.Type, b []byte) error { + if err := vw.writeElementHeader(t, mode(0), "WriteValueBytes"); err != nil { + return err + } + vw.buf = append(vw.buf, b...) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteArray() (ArrayWriter, error) { + if err := vw.writeElementHeader(bsontype.Array, mArray, "WriteArray"); err != nil { + return nil, err + } + + vw.push(mArray) + + return vw, nil +} + +func (vw *valueWriter) WriteBinary(b []byte) error { + return vw.WriteBinaryWithSubtype(b, 0x00) +} + +func (vw *valueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error { + if err := vw.writeElementHeader(bsontype.Binary, mode(0), "WriteBinaryWithSubtype"); err != nil { + return err + } + + vw.buf = bsoncore.AppendBinary(vw.buf, btype, b) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteBoolean(b bool) error { + if err := vw.writeElementHeader(bsontype.Boolean, mode(0), "WriteBoolean"); err != nil { + return err + } + + vw.buf = bsoncore.AppendBoolean(vw.buf, b) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) { + if err := vw.writeElementHeader(bsontype.CodeWithScope, mCodeWithScope, "WriteCodeWithScope"); err != nil { + return nil, err + } + + // CodeWithScope is a different than other types because we need an extra + // frame on the stack. In the EndDocument code, we write the document + // length, pop, write the code with scope length, and pop. To simplify the + // pop code, we push a spacer frame that we'll always jump over. + vw.push(mCodeWithScope) + vw.buf = bsoncore.AppendString(vw.buf, code) + vw.push(mSpacer) + vw.push(mDocument) + + return vw, nil +} + +func (vw *valueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error { + if err := vw.writeElementHeader(bsontype.DBPointer, mode(0), "WriteDBPointer"); err != nil { + return err + } + + vw.buf = bsoncore.AppendDBPointer(vw.buf, ns, oid) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteDateTime(dt int64) error { + if err := vw.writeElementHeader(bsontype.DateTime, mode(0), "WriteDateTime"); err != nil { + return err + } + + vw.buf = bsoncore.AppendDateTime(vw.buf, dt) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteDecimal128(d128 primitive.Decimal128) error { + if err := vw.writeElementHeader(bsontype.Decimal128, mode(0), "WriteDecimal128"); err != nil { + return err + } + + vw.buf = bsoncore.AppendDecimal128(vw.buf, d128) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteDouble(f float64) error { + if err := vw.writeElementHeader(bsontype.Double, mode(0), "WriteDouble"); err != nil { + return err + } + + vw.buf = bsoncore.AppendDouble(vw.buf, f) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteInt32(i32 int32) error { + if err := vw.writeElementHeader(bsontype.Int32, mode(0), "WriteInt32"); err != nil { + return err + } + + vw.buf = bsoncore.AppendInt32(vw.buf, i32) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteInt64(i64 int64) error { + if err := vw.writeElementHeader(bsontype.Int64, mode(0), "WriteInt64"); err != nil { + return err + } + + vw.buf = bsoncore.AppendInt64(vw.buf, i64) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteJavascript(code string) error { + if err := vw.writeElementHeader(bsontype.JavaScript, mode(0), "WriteJavascript"); err != nil { + return err + } + + vw.buf = bsoncore.AppendJavaScript(vw.buf, code) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteMaxKey() error { + if err := vw.writeElementHeader(bsontype.MaxKey, mode(0), "WriteMaxKey"); err != nil { + return err + } + + vw.pop() + return nil +} + +func (vw *valueWriter) WriteMinKey() error { + if err := vw.writeElementHeader(bsontype.MinKey, mode(0), "WriteMinKey"); err != nil { + return err + } + + vw.pop() + return nil +} + +func (vw *valueWriter) WriteNull() error { + if err := vw.writeElementHeader(bsontype.Null, mode(0), "WriteNull"); err != nil { + return err + } + + vw.pop() + return nil +} + +func (vw *valueWriter) WriteObjectID(oid primitive.ObjectID) error { + if err := vw.writeElementHeader(bsontype.ObjectID, mode(0), "WriteObjectID"); err != nil { + return err + } + + vw.buf = bsoncore.AppendObjectID(vw.buf, oid) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteRegex(pattern string, options string) error { + if !isValidCString(pattern) || !isValidCString(options) { + return errors.New("BSON regex values cannot contain null bytes") + } + if err := vw.writeElementHeader(bsontype.Regex, mode(0), "WriteRegex"); err != nil { + return err + } + + vw.buf = bsoncore.AppendRegex(vw.buf, pattern, sortStringAlphebeticAscending(options)) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteString(s string) error { + if err := vw.writeElementHeader(bsontype.String, mode(0), "WriteString"); err != nil { + return err + } + + vw.buf = bsoncore.AppendString(vw.buf, s) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteDocument() (DocumentWriter, error) { + if vw.stack[vw.frame].mode == mTopLevel { + vw.reserveLength() + return vw, nil + } + if err := vw.writeElementHeader(bsontype.EmbeddedDocument, mDocument, "WriteDocument", mTopLevel); err != nil { + return nil, err + } + + vw.push(mDocument) + return vw, nil +} + +func (vw *valueWriter) WriteSymbol(symbol string) error { + if err := vw.writeElementHeader(bsontype.Symbol, mode(0), "WriteSymbol"); err != nil { + return err + } + + vw.buf = bsoncore.AppendSymbol(vw.buf, symbol) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteTimestamp(t uint32, i uint32) error { + if err := vw.writeElementHeader(bsontype.Timestamp, mode(0), "WriteTimestamp"); err != nil { + return err + } + + vw.buf = bsoncore.AppendTimestamp(vw.buf, t, i) + vw.pop() + return nil +} + +func (vw *valueWriter) WriteUndefined() error { + if err := vw.writeElementHeader(bsontype.Undefined, mode(0), "WriteUndefined"); err != nil { + return err + } + + vw.pop() + return nil +} + +func (vw *valueWriter) WriteDocumentElement(key string) (ValueWriter, error) { + switch vw.stack[vw.frame].mode { + case mTopLevel, mDocument: + default: + return nil, vw.invalidTransitionError(mElement, "WriteDocumentElement", []mode{mTopLevel, mDocument}) + } + + vw.push(mElement) + vw.stack[vw.frame].key = key + + return vw, nil +} + +func (vw *valueWriter) WriteDocumentEnd() error { + switch vw.stack[vw.frame].mode { + case mTopLevel, mDocument: + default: + return fmt.Errorf("incorrect mode to end document: %s", vw.stack[vw.frame].mode) + } + + vw.buf = append(vw.buf, 0x00) + + err := vw.writeLength() + if err != nil { + return err + } + + if vw.stack[vw.frame].mode == mTopLevel { + if err = vw.Flush(); err != nil { + return err + } + } + + vw.pop() + + if vw.stack[vw.frame].mode == mCodeWithScope { + // We ignore the error here because of the guarantee of writeLength. + // See the docs for writeLength for more info. + _ = vw.writeLength() + vw.pop() + } + return nil +} + +func (vw *valueWriter) Flush() error { + if vw.w == nil { + return nil + } + + if _, err := vw.w.Write(vw.buf); err != nil { + return err + } + // reset buffer + vw.buf = vw.buf[:0] + return nil +} + +func (vw *valueWriter) WriteArrayElement() (ValueWriter, error) { + if vw.stack[vw.frame].mode != mArray { + return nil, vw.invalidTransitionError(mValue, "WriteArrayElement", []mode{mArray}) + } + + arrkey := vw.stack[vw.frame].arrkey + vw.stack[vw.frame].arrkey++ + + vw.push(mValue) + vw.stack[vw.frame].arrkey = arrkey + + return vw, nil +} + +func (vw *valueWriter) WriteArrayEnd() error { + if vw.stack[vw.frame].mode != mArray { + return fmt.Errorf("incorrect mode to end array: %s", vw.stack[vw.frame].mode) + } + + vw.buf = append(vw.buf, 0x00) + + err := vw.writeLength() + if err != nil { + return err + } + + vw.pop() + return nil +} + +// NOTE: We assume that if we call writeLength more than once the same function +// within the same function without altering the vw.buf that this method will +// not return an error. If this changes ensure that the following methods are +// updated: +// +// - WriteDocumentEnd +func (vw *valueWriter) writeLength() error { + length := len(vw.buf) + if length > maxSize { + return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))} + } + length = length - int(vw.stack[vw.frame].start) + start := vw.stack[vw.frame].start + + vw.buf[start+0] = byte(length) + vw.buf[start+1] = byte(length >> 8) + vw.buf[start+2] = byte(length >> 16) + vw.buf[start+3] = byte(length >> 24) + return nil +} + +func isValidCString(cs string) bool { + return !strings.ContainsRune(cs, '\x00') +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/writer.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/writer.go new file mode 100644 index 000000000..dff65f87f --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/writer.go @@ -0,0 +1,78 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsonrw + +import ( + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +// ArrayWriter is the interface used to create a BSON or BSON adjacent array. +// Callers must ensure they call WriteArrayEnd when they have finished creating +// the array. +type ArrayWriter interface { + WriteArrayElement() (ValueWriter, error) + WriteArrayEnd() error +} + +// DocumentWriter is the interface used to create a BSON or BSON adjacent +// document. Callers must ensure they call WriteDocumentEnd when they have +// finished creating the document. +type DocumentWriter interface { + WriteDocumentElement(string) (ValueWriter, error) + WriteDocumentEnd() error +} + +// ValueWriter is the interface used to write BSON values. Implementations of +// this interface handle creating BSON or BSON adjacent representations of the +// values. +type ValueWriter interface { + WriteArray() (ArrayWriter, error) + WriteBinary(b []byte) error + WriteBinaryWithSubtype(b []byte, btype byte) error + WriteBoolean(bool) error + WriteCodeWithScope(code string) (DocumentWriter, error) + WriteDBPointer(ns string, oid primitive.ObjectID) error + WriteDateTime(dt int64) error + WriteDecimal128(primitive.Decimal128) error + WriteDouble(float64) error + WriteInt32(int32) error + WriteInt64(int64) error + WriteJavascript(code string) error + WriteMaxKey() error + WriteMinKey() error + WriteNull() error + WriteObjectID(primitive.ObjectID) error + WriteRegex(pattern, options string) error + WriteString(string) error + WriteDocument() (DocumentWriter, error) + WriteSymbol(symbol string) error + WriteTimestamp(t, i uint32) error + WriteUndefined() error +} + +// ValueWriterFlusher is a superset of ValueWriter that exposes functionality to flush to the underlying buffer. +type ValueWriterFlusher interface { + ValueWriter + Flush() error +} + +// BytesWriter is the interface used to write BSON bytes to a ValueWriter. +// This interface is meant to be a superset of ValueWriter, so that types that +// implement ValueWriter may also implement this interface. +type BytesWriter interface { + WriteValueBytes(t bsontype.Type, b []byte) error +} + +// SliceWriter allows a pointer to a slice of bytes to be used as an io.Writer. +type SliceWriter []byte + +func (sw *SliceWriter) Write(p []byte) (int, error) { + written := len(p) + *sw = append(*sw, p...) + return written, nil +} |