diff options
Diffstat (limited to 'vendor/google.golang.org/protobuf/internal/impl/lazy.go')
-rw-r--r-- | vendor/google.golang.org/protobuf/internal/impl/lazy.go | 433 |
1 files changed, 433 insertions, 0 deletions
diff --git a/vendor/google.golang.org/protobuf/internal/impl/lazy.go b/vendor/google.golang.org/protobuf/internal/impl/lazy.go new file mode 100644 index 000000000..e8fb6c35b --- /dev/null +++ b/vendor/google.golang.org/protobuf/internal/impl/lazy.go @@ -0,0 +1,433 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package impl + +import ( + "fmt" + "math/bits" + "os" + "reflect" + "sort" + "sync/atomic" + + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/internal/errors" + "google.golang.org/protobuf/internal/protolazy" + "google.golang.org/protobuf/reflect/protoreflect" + preg "google.golang.org/protobuf/reflect/protoregistry" + piface "google.golang.org/protobuf/runtime/protoiface" +) + +var enableLazy int32 = func() int32 { + if os.Getenv("GOPROTODEBUG") == "nolazy" { + return 0 + } + return 1 +}() + +// EnableLazyUnmarshal enables lazy unmarshaling. +func EnableLazyUnmarshal(enable bool) { + if enable { + atomic.StoreInt32(&enableLazy, 1) + return + } + atomic.StoreInt32(&enableLazy, 0) +} + +// LazyEnabled reports whether lazy unmarshalling is currently enabled. +func LazyEnabled() bool { + return atomic.LoadInt32(&enableLazy) != 0 +} + +// UnmarshalField unmarshals a field in a message. +func UnmarshalField(m interface{}, num protowire.Number) { + switch m := m.(type) { + case *messageState: + m.messageInfo().lazyUnmarshal(m.pointer(), num) + case *messageReflectWrapper: + m.messageInfo().lazyUnmarshal(m.pointer(), num) + default: + panic(fmt.Sprintf("unsupported wrapper type %T", m)) + } +} + +func (mi *MessageInfo) lazyUnmarshal(p pointer, num protoreflect.FieldNumber) { + var f *coderFieldInfo + if int(num) < len(mi.denseCoderFields) { + f = mi.denseCoderFields[num] + } else { + f = mi.coderFields[num] + } + if f == nil { + panic(fmt.Sprintf("lazyUnmarshal: field info for %v.%v", mi.Desc.FullName(), num)) + } + lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr() + start, end, found, _, multipleEntries := lazy.FindFieldInProto(uint32(num)) + if !found && multipleEntries == nil { + panic(fmt.Sprintf("lazyUnmarshal: can't find field data for %v.%v", mi.Desc.FullName(), num)) + } + // The actual pointer in the message can not be set until the whole struct is filled in, otherwise we will have races. + // Create another pointer and set it atomically, if we won the race and the pointer in the original message is still nil. + fp := pointerOfValue(reflect.New(f.ft)) + if multipleEntries != nil { + for _, entry := range multipleEntries { + mi.unmarshalField(lazy.Buffer()[entry.Start:entry.End], fp, f, lazy, lazy.UnmarshalFlags()) + } + } else { + mi.unmarshalField(lazy.Buffer()[start:end], fp, f, lazy, lazy.UnmarshalFlags()) + } + p.Apply(f.offset).AtomicSetPointerIfNil(fp.Elem()) +} + +func (mi *MessageInfo) unmarshalField(b []byte, p pointer, f *coderFieldInfo, lazyInfo *protolazy.XXX_lazyUnmarshalInfo, flags piface.UnmarshalInputFlags) error { + opts := lazyUnmarshalOptions + opts.flags |= flags + for len(b) > 0 { + // Parse the tag (field number and wire type). + var tag uint64 + if b[0] < 0x80 { + tag = uint64(b[0]) + b = b[1:] + } else if len(b) >= 2 && b[1] < 128 { + tag = uint64(b[0]&0x7f) + uint64(b[1])<<7 + b = b[2:] + } else { + var n int + tag, n = protowire.ConsumeVarint(b) + if n < 0 { + return errors.New("invalid wire data") + } + b = b[n:] + } + var num protowire.Number + if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) { + return errors.New("invalid wire data") + } else { + num = protowire.Number(n) + } + wtyp := protowire.Type(tag & 7) + if num == f.num { + o, err := f.funcs.unmarshal(b, p, wtyp, f, opts) + if err == nil { + b = b[o.n:] + continue + } + if err != errUnknown { + return err + } + } + n := protowire.ConsumeFieldValue(num, wtyp, b) + if n < 0 { + return errors.New("invalid wire data") + } + b = b[n:] + } + return nil +} + +func (mi *MessageInfo) skipField(b []byte, f *coderFieldInfo, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) { + fmi := f.validation.mi + if fmi == nil { + fd := mi.Desc.Fields().ByNumber(f.num) + if fd == nil || !fd.IsWeak() { + return out, ValidationUnknown + } + messageName := fd.Message().FullName() + messageType, err := preg.GlobalTypes.FindMessageByName(messageName) + if err != nil { + return out, ValidationUnknown + } + var ok bool + fmi, ok = messageType.(*MessageInfo) + if !ok { + return out, ValidationUnknown + } + } + fmi.init() + switch f.validation.typ { + case validationTypeMessage: + if wtyp != protowire.BytesType { + return out, ValidationWrongWireType + } + v, n := protowire.ConsumeBytes(b) + if n < 0 { + return out, ValidationInvalid + } + out, st := fmi.validate(v, 0, opts) + out.n = n + return out, st + case validationTypeGroup: + if wtyp != protowire.StartGroupType { + return out, ValidationWrongWireType + } + out, st := fmi.validate(b, f.num, opts) + return out, st + default: + return out, ValidationUnknown + } +} + +// unmarshalPointerLazy is similar to unmarshalPointerEager, but it +// specifically handles lazy unmarshalling. it expects lazyOffset and +// presenceOffset to both be valid. +func (mi *MessageInfo) unmarshalPointerLazy(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) { + initialized := true + var requiredMask uint64 + var lazy **protolazy.XXX_lazyUnmarshalInfo + var presence presence + var lazyIndex []protolazy.IndexEntry + var lastNum protowire.Number + outOfOrder := false + lazyDecode := false + presence = p.Apply(mi.presenceOffset).PresenceInfo() + lazy = p.Apply(mi.lazyOffset).LazyInfoPtr() + if !presence.AnyPresent(mi.presenceSize) { + if opts.CanBeLazy() { + // If the message contains existing data, we need to merge into it. + // Lazy unmarshaling doesn't merge, so only enable it when the + // message is empty (has no presence bitmap). + lazyDecode = true + if *lazy == nil { + *lazy = &protolazy.XXX_lazyUnmarshalInfo{} + } + (*lazy).SetUnmarshalFlags(opts.flags) + if !opts.AliasBuffer() { + // Make a copy of the buffer for lazy unmarshaling. + // Set the AliasBuffer flag so recursive unmarshal + // operations reuse the copy. + b = append([]byte{}, b...) + opts.flags |= piface.UnmarshalAliasBuffer + } + (*lazy).SetBuffer(b) + } + } + // Track special handling of lazy fields. + // + // In the common case, all fields are lazyValidateOnly (and lazyFields remains nil). + // In the event that validation for a field fails, this map tracks handling of the field. + type lazyAction uint8 + const ( + lazyValidateOnly lazyAction = iota // validate the field only + lazyUnmarshalNow // eagerly unmarshal the field + lazyUnmarshalLater // unmarshal the field after the message is fully processed + ) + var lazyFields map[*coderFieldInfo]lazyAction + var exts *map[int32]ExtensionField + start := len(b) + pos := 0 + for len(b) > 0 { + // Parse the tag (field number and wire type). + var tag uint64 + if b[0] < 0x80 { + tag = uint64(b[0]) + b = b[1:] + } else if len(b) >= 2 && b[1] < 128 { + tag = uint64(b[0]&0x7f) + uint64(b[1])<<7 + b = b[2:] + } else { + var n int + tag, n = protowire.ConsumeVarint(b) + if n < 0 { + return out, errDecode + } + b = b[n:] + } + var num protowire.Number + if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) { + return out, errors.New("invalid field number") + } else { + num = protowire.Number(n) + } + wtyp := protowire.Type(tag & 7) + + if wtyp == protowire.EndGroupType { + if num != groupTag { + return out, errors.New("mismatching end group marker") + } + groupTag = 0 + break + } + + var f *coderFieldInfo + if int(num) < len(mi.denseCoderFields) { + f = mi.denseCoderFields[num] + } else { + f = mi.coderFields[num] + } + var n int + err := errUnknown + discardUnknown := false + Field: + switch { + case f != nil: + if f.funcs.unmarshal == nil { + break + } + if f.isLazy && lazyDecode { + switch { + case lazyFields == nil || lazyFields[f] == lazyValidateOnly: + // Attempt to validate this field and leave it for later lazy unmarshaling. + o, valid := mi.skipField(b, f, wtyp, opts) + switch valid { + case ValidationValid: + // Skip over the valid field and continue. + err = nil + presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize) + requiredMask |= f.validation.requiredBit + if !o.initialized { + initialized = false + } + n = o.n + break Field + case ValidationInvalid: + return out, errors.New("invalid proto wire format") + case ValidationWrongWireType: + break Field + case ValidationUnknown: + if lazyFields == nil { + lazyFields = make(map[*coderFieldInfo]lazyAction) + } + if presence.Present(f.presenceIndex) { + // We were unable to determine if the field is valid or not, + // and we've already skipped over at least one instance of this + // field. Clear the presence bit (so if we stop decoding early, + // we don't leave a partially-initialized field around) and flag + // the field for unmarshaling before we return. + presence.ClearPresent(f.presenceIndex) + lazyFields[f] = lazyUnmarshalLater + discardUnknown = true + break Field + } else { + // We were unable to determine if the field is valid or not, + // but this is the first time we've seen it. Flag it as needing + // eager unmarshaling and fall through to the eager unmarshal case below. + lazyFields[f] = lazyUnmarshalNow + } + } + case lazyFields[f] == lazyUnmarshalLater: + // This field will be unmarshaled in a separate pass below. + // Skip over it here. + discardUnknown = true + break Field + default: + // Eagerly unmarshal the field. + } + } + if f.isLazy && !lazyDecode && presence.Present(f.presenceIndex) { + if p.Apply(f.offset).AtomicGetPointer().IsNil() { + mi.lazyUnmarshal(p, f.num) + } + } + var o unmarshalOutput + o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts) + n = o.n + if err != nil { + break + } + requiredMask |= f.validation.requiredBit + if f.funcs.isInit != nil && !o.initialized { + initialized = false + } + if f.presenceIndex != noPresence { + presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize) + } + default: + // Possible extension. + if exts == nil && mi.extensionOffset.IsValid() { + exts = p.Apply(mi.extensionOffset).Extensions() + if *exts == nil { + *exts = make(map[int32]ExtensionField) + } + } + if exts == nil { + break + } + var o unmarshalOutput + o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts) + if err != nil { + break + } + n = o.n + if !o.initialized { + initialized = false + } + } + if err != nil { + if err != errUnknown { + return out, err + } + n = protowire.ConsumeFieldValue(num, wtyp, b) + if n < 0 { + return out, errDecode + } + if !discardUnknown && !opts.DiscardUnknown() && mi.unknownOffset.IsValid() { + u := mi.mutableUnknownBytes(p) + *u = protowire.AppendTag(*u, num, wtyp) + *u = append(*u, b[:n]...) + } + } + b = b[n:] + end := start - len(b) + if lazyDecode && f != nil && f.isLazy { + if num != lastNum { + lazyIndex = append(lazyIndex, protolazy.IndexEntry{ + FieldNum: uint32(num), + Start: uint32(pos), + End: uint32(end), + }) + } else { + i := len(lazyIndex) - 1 + lazyIndex[i].End = uint32(end) + lazyIndex[i].MultipleContiguous = true + } + } + if num < lastNum { + outOfOrder = true + } + pos = end + lastNum = num + } + if groupTag != 0 { + return out, errors.New("missing end group marker") + } + if lazyFields != nil { + // Some fields failed validation, and now need to be unmarshaled. + for f, action := range lazyFields { + if action != lazyUnmarshalLater { + continue + } + initialized = false + if *lazy == nil { + *lazy = &protolazy.XXX_lazyUnmarshalInfo{} + } + if err := mi.unmarshalField((*lazy).Buffer(), p.Apply(f.offset), f, *lazy, opts.flags); err != nil { + return out, err + } + presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize) + } + } + if lazyDecode { + if outOfOrder { + sort.Slice(lazyIndex, func(i, j int) bool { + return lazyIndex[i].FieldNum < lazyIndex[j].FieldNum || + (lazyIndex[i].FieldNum == lazyIndex[j].FieldNum && + lazyIndex[i].Start < lazyIndex[j].Start) + }) + } + if *lazy == nil { + *lazy = &protolazy.XXX_lazyUnmarshalInfo{} + } + + (*lazy).SetIndex(lazyIndex) + } + if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) { + initialized = false + } + if initialized { + out.initialized = true + } + out.n = start - len(b) + return out, nil +} |