diff options
author | 2021-09-10 14:42:14 +0200 | |
---|---|---|
committer | 2021-09-10 14:42:14 +0200 | |
commit | f2e5bedea6fb93fbbf68ed8f7153c353cc57a9f0 (patch) | |
tree | 475ae9e7470d0df670ab2a59dce351cd1d07498a /vendor/github.com/goccy/go-json/internal/encoder/compiler.go | |
parent | fixes + db changes (#204) (diff) | |
download | gotosocial-f2e5bedea6fb93fbbf68ed8f7153c353cc57a9f0.tar.xz |
migrate go version to 1.17 (#203)
* migrate go version to 1.17
* update contributing
Diffstat (limited to 'vendor/github.com/goccy/go-json/internal/encoder/compiler.go')
-rw-r--r-- | vendor/github.com/goccy/go-json/internal/encoder/compiler.go | 276 |
1 files changed, 168 insertions, 108 deletions
diff --git a/vendor/github.com/goccy/go-json/internal/encoder/compiler.go b/vendor/github.com/goccy/go-json/internal/encoder/compiler.go index 486c80f18..c627ed307 100644 --- a/vendor/github.com/goccy/go-json/internal/encoder/compiler.go +++ b/vendor/github.com/goccy/go-json/internal/encoder/compiler.go @@ -1,6 +1,7 @@ package encoder import ( + "context" "encoding" "encoding/json" "fmt" @@ -13,13 +14,18 @@ import ( "github.com/goccy/go-json/internal/runtime" ) +type marshalerContext interface { + MarshalJSON(context.Context) ([]byte, error) +} + var ( - marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() - marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() - jsonNumberType = reflect.TypeOf(json.Number("")) - cachedOpcodeSets []*OpcodeSet - cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet - typeAddr *runtime.TypeAddr + marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() + marshalJSONContextType = reflect.TypeOf((*marshalerContext)(nil)).Elem() + marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + jsonNumberType = reflect.TypeOf(json.Number("")) + cachedOpcodeSets []*OpcodeSet + cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet + typeAddr *runtime.TypeAddr ) func init() { @@ -55,19 +61,36 @@ func compileToGetCodeSetSlowPath(typeptr uintptr) (*OpcodeSet, error) { // noescape trick for header.typ ( reflect.*rtype ) copiedType := *(**runtime.Type)(unsafe.Pointer(&typeptr)) - code, err := compileHead(&compileContext{ + noescapeKeyCode, err := compileHead(&compileContext{ + typ: copiedType, + structTypeToCompiledCode: map[uintptr]*CompiledCode{}, + }) + if err != nil { + return nil, err + } + escapeKeyCode, err := compileHead(&compileContext{ typ: copiedType, structTypeToCompiledCode: map[uintptr]*CompiledCode{}, + escapeKey: true, }) if err != nil { return nil, err } - code = copyOpcode(code) - codeLength := code.TotalLength() + noescapeKeyCode = copyOpcode(noescapeKeyCode) + escapeKeyCode = copyOpcode(escapeKeyCode) + setTotalLengthToInterfaceOp(noescapeKeyCode) + setTotalLengthToInterfaceOp(escapeKeyCode) + interfaceNoescapeKeyCode := copyToInterfaceOpcode(noescapeKeyCode) + interfaceEscapeKeyCode := copyToInterfaceOpcode(escapeKeyCode) + codeLength := noescapeKeyCode.TotalLength() codeSet := &OpcodeSet{ - Type: copiedType, - Code: code, - CodeLength: codeLength, + Type: copiedType, + NoescapeKeyCode: noescapeKeyCode, + EscapeKeyCode: escapeKeyCode, + InterfaceNoescapeKeyCode: interfaceNoescapeKeyCode, + InterfaceEscapeKeyCode: interfaceEscapeKeyCode, + CodeLength: codeLength, + EndCode: ToEndCode(interfaceNoescapeKeyCode), } storeOpcodeSet(typeptr, codeSet, opcodeMap) return codeSet, nil @@ -100,7 +123,7 @@ func compileHead(ctx *compileContext) (*Opcode, error) { elem := typ.Elem() if elem.Kind() == reflect.Uint8 { p := runtime.PtrTo(elem) - if !p.Implements(marshalJSONType) && !p.Implements(marshalTextType) { + if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) { if isPtr { return compileBytesPtr(ctx) } @@ -246,6 +269,7 @@ func linkRecursiveCode(c *Opcode) { continue } code.Jmp.Code = copyOpcode(code.Jmp.Code) + c := code.Jmp.Code c.End.Next = newEndOp(&compileContext{}) c.Op = c.Op.PtrHeadToHead() @@ -258,8 +282,8 @@ func linkRecursiveCode(c *Opcode) { lastCode.Length = lastCode.Idx + 2*uintptrSize // extend length to alloc slot for elemIdx + length - totalLength := uintptr(code.TotalLength() + 2) - nextTotalLength := uintptr(c.TotalLength() + 2) + totalLength := uintptr(code.TotalLength() + 3) + nextTotalLength := uintptr(c.TotalLength() + 3) c.End.Next.Op = OpRecursiveEnd @@ -268,6 +292,7 @@ func linkRecursiveCode(c *Opcode) { code.Jmp.Linked = true linkRecursiveCode(code.Jmp.Code) + code = code.Next continue } @@ -328,14 +353,14 @@ func optimizeStructEnd(c *Opcode) { } func implementsMarshalJSON(typ *runtime.Type) bool { - if !typ.Implements(marshalJSONType) { + if !implementsMarshalJSONType(typ) { return false } if typ.Kind() != reflect.Ptr { return true } // type kind is reflect.Ptr - if !typ.Elem().Implements(marshalJSONType) { + if !implementsMarshalJSONType(typ.Elem()) { return true } // needs to dereference @@ -372,7 +397,7 @@ func compile(ctx *compileContext, isPtr bool) (*Opcode, error) { elem := typ.Elem() if elem.Kind() == reflect.Uint8 { p := runtime.PtrTo(elem) - if !p.Implements(marshalJSONType) && !p.Implements(marshalTextType) { + if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) { return compileBytes(ctx) } } @@ -474,8 +499,6 @@ func compileKey(ctx *compileContext) (*Opcode, error) { switch typ.Kind() { case reflect.Ptr: return compilePtr(ctx) - case reflect.Interface: - return compileInterface(ctx) case reflect.String: return compileString(ctx) case reflect.Int: @@ -517,10 +540,17 @@ func compilePtr(ctx *compileContext) (*Opcode, error) { func compileMarshalJSON(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpMarshalJSON) typ := ctx.typ - if !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType) { - code.AddrForMarshaler = true + if isPtrMarshalJSONType(typ) { + code.Flags |= AddrForMarshalerFlags + } + if typ.Implements(marshalJSONContextType) || runtime.PtrTo(typ).Implements(marshalJSONContextType) { + code.Flags |= MarshalerContextFlags + } + if isNilableType(typ) { + code.Flags |= IsNilableTypeFlags + } else { + code.Flags &= ^IsNilableTypeFlags } - code.IsNilableType = isNilableType(typ) ctx.incIndex() return code, nil } @@ -529,9 +559,13 @@ func compileMarshalText(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpMarshalText) typ := ctx.typ if !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType) { - code.AddrForMarshaler = true + code.Flags |= AddrForMarshalerFlags + } + if isNilableType(typ) { + code.Flags |= IsNilableTypeFlags + } else { + code.Flags &= ^IsNilableTypeFlags } - code.IsNilableType = isNilableType(typ) ctx.incIndex() return code, nil } @@ -540,7 +574,7 @@ const intSize = 32 << (^uint(0) >> 63) func compileInt(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpInt) - code.setMaskAndRshiftNum(intSize) + code.NumBitSize = intSize ctx.incIndex() return code, nil } @@ -556,7 +590,7 @@ func compileIntPtr(ctx *compileContext) (*Opcode, error) { func compileInt8(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpInt) - code.setMaskAndRshiftNum(8) + code.NumBitSize = 8 ctx.incIndex() return code, nil } @@ -572,7 +606,7 @@ func compileInt8Ptr(ctx *compileContext) (*Opcode, error) { func compileInt16(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpInt) - code.setMaskAndRshiftNum(16) + code.NumBitSize = 16 ctx.incIndex() return code, nil } @@ -588,7 +622,7 @@ func compileInt16Ptr(ctx *compileContext) (*Opcode, error) { func compileInt32(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpInt) - code.setMaskAndRshiftNum(32) + code.NumBitSize = 32 ctx.incIndex() return code, nil } @@ -604,7 +638,7 @@ func compileInt32Ptr(ctx *compileContext) (*Opcode, error) { func compileInt64(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpInt) - code.setMaskAndRshiftNum(64) + code.NumBitSize = 64 ctx.incIndex() return code, nil } @@ -620,7 +654,7 @@ func compileInt64Ptr(ctx *compileContext) (*Opcode, error) { func compileUint(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpUint) - code.setMaskAndRshiftNum(intSize) + code.NumBitSize = intSize ctx.incIndex() return code, nil } @@ -636,7 +670,7 @@ func compileUintPtr(ctx *compileContext) (*Opcode, error) { func compileUint8(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpUint) - code.setMaskAndRshiftNum(8) + code.NumBitSize = 8 ctx.incIndex() return code, nil } @@ -652,7 +686,7 @@ func compileUint8Ptr(ctx *compileContext) (*Opcode, error) { func compileUint16(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpUint) - code.setMaskAndRshiftNum(16) + code.NumBitSize = 16 ctx.incIndex() return code, nil } @@ -668,7 +702,7 @@ func compileUint16Ptr(ctx *compileContext) (*Opcode, error) { func compileUint32(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpUint) - code.setMaskAndRshiftNum(32) + code.NumBitSize = 32 ctx.incIndex() return code, nil } @@ -684,7 +718,7 @@ func compileUint32Ptr(ctx *compileContext) (*Opcode, error) { func compileUint64(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpUint) - code.setMaskAndRshiftNum(64) + code.NumBitSize = 64 ctx.incIndex() return code, nil } @@ -700,70 +734,70 @@ func compileUint64Ptr(ctx *compileContext) (*Opcode, error) { func compileIntString(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpIntString) - code.setMaskAndRshiftNum(intSize) + code.NumBitSize = intSize ctx.incIndex() return code, nil } func compileInt8String(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpIntString) - code.setMaskAndRshiftNum(8) + code.NumBitSize = 8 ctx.incIndex() return code, nil } func compileInt16String(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpIntString) - code.setMaskAndRshiftNum(16) + code.NumBitSize = 16 ctx.incIndex() return code, nil } func compileInt32String(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpIntString) - code.setMaskAndRshiftNum(32) + code.NumBitSize = 32 ctx.incIndex() return code, nil } func compileInt64String(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpIntString) - code.setMaskAndRshiftNum(64) + code.NumBitSize = 64 ctx.incIndex() return code, nil } func compileUintString(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpUintString) - code.setMaskAndRshiftNum(intSize) + code.NumBitSize = intSize ctx.incIndex() return code, nil } func compileUint8String(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpUintString) - code.setMaskAndRshiftNum(8) + code.NumBitSize = 8 ctx.incIndex() return code, nil } func compileUint16String(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpUintString) - code.setMaskAndRshiftNum(16) + code.NumBitSize = 16 ctx.incIndex() return code, nil } func compileUint32String(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpUintString) - code.setMaskAndRshiftNum(32) + code.NumBitSize = 32 ctx.incIndex() return code, nil } func compileUint64String(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpUintString) - code.setMaskAndRshiftNum(64) + code.NumBitSize = 64 ctx.incIndex() return code, nil } @@ -879,7 +913,7 @@ func compileSlice(ctx *compileContext) (*Opcode, error) { if err != nil { return nil, err } - code.Indirect = true + code.Flags |= IndirectFlags // header => opcode => elem => end // ^ | @@ -891,7 +925,6 @@ func compileSlice(ctx *compileContext) (*Opcode, error) { end := newOpCode(ctx, OpSliceEnd) ctx.incIndex() - header.Elem = elemCode header.End = end header.Next = code code.BeforeLastCode().Next = (*Opcode)(unsafe.Pointer(elemCode)) @@ -903,7 +936,7 @@ func compileSlice(ctx *compileContext) (*Opcode, error) { func compileListElem(ctx *compileContext) (*Opcode, error) { typ := ctx.typ switch { - case !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType): + case isPtrMarshalJSONType(typ): return compileMarshalJSON(ctx) case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType): return compileMarshalText(ctx) @@ -934,7 +967,7 @@ func compileArray(ctx *compileContext) (*Opcode, error) { if err != nil { return nil, err } - code.Indirect = true + code.Flags |= IndirectFlags // header => opcode => elem => end // ^ | // |________| @@ -945,7 +978,6 @@ func compileArray(ctx *compileContext) (*Opcode, error) { end := newOpCode(ctx, OpArrayEnd) ctx.incIndex() - header.Elem = elemCode header.End = end header.Next = code code.BeforeLastCode().Next = (*Opcode)(unsafe.Pointer(elemCode)) @@ -976,16 +1008,13 @@ func compileMap(ctx *compileContext) (*Opcode, error) { if err != nil { return nil, err } - valueCode.Indirect = true + valueCode.Flags |= IndirectFlags key := newMapKeyCode(ctx, header) ctx.incIndex() ctx = ctx.decIndent() - header.MapKey = key - header.MapValue = value - end := newMapEndCode(ctx, header) ctx.incIndex() @@ -1052,8 +1081,7 @@ func compiledCode(ctx *compileContext) *Opcode { func structHeader(ctx *compileContext, fieldCode *Opcode, valueCode *Opcode, tag *runtime.StructTag) *Opcode { op := optimizeStructHeader(valueCode, tag) fieldCode.Op = op - fieldCode.Mask = valueCode.Mask - fieldCode.RshiftNum = valueCode.RshiftNum + fieldCode.NumBitSize = valueCode.NumBitSize fieldCode.PtrNum = valueCode.PtrNum if op.IsMultipleOpHead() { return valueCode.BeforeLastCode() @@ -1065,9 +1093,8 @@ func structHeader(ctx *compileContext, fieldCode *Opcode, valueCode *Opcode, tag func structField(ctx *compileContext, fieldCode *Opcode, valueCode *Opcode, tag *runtime.StructTag) *Opcode { op := optimizeStructField(valueCode, tag) fieldCode.Op = op + fieldCode.NumBitSize = valueCode.NumBitSize fieldCode.PtrNum = valueCode.PtrNum - fieldCode.Mask = valueCode.Mask - fieldCode.RshiftNum = valueCode.RshiftNum if op.IsMultipleOpField() { return valueCode.BeforeLastCode() } @@ -1082,7 +1109,7 @@ func isNotExistsField(head *Opcode) bool { if head.Op != OpStructHead { return false } - if !head.AnonymousHead { + if (head.Flags & AnonymousHeadFlags) == 0 { return false } if head.Next == nil { @@ -1117,7 +1144,7 @@ func optimizeAnonymousFields(head *Opcode) { if isNotExistsField(code.Next) { code.Next = code.NextField diff := code.Next.DisplayIdx - code.DisplayIdx - for i := 0; i < diff; i++ { + for i := uint32(0); i < diff; i++ { code.Next.decOpcodeIndex() } linkPrevToNextField(code, removedFields) @@ -1147,20 +1174,20 @@ func anonymousStructFieldPairMap(tags runtime.StructTags, named string, valueCod isHeadOp := strings.Contains(f.Op.String(), "Head") if existsKey && f.Next != nil && strings.Contains(f.Next.Op.String(), "Recursive") { // through - } else if isHeadOp && !f.AnonymousHead { + } else if isHeadOp && (f.Flags&AnonymousHeadFlags) == 0 { if existsKey { // TODO: need to remove this head f.Op = OpStructHead - f.AnonymousKey = true - f.AnonymousHead = true + f.Flags |= AnonymousKeyFlags + f.Flags |= AnonymousHeadFlags } else if named == "" { - f.AnonymousHead = true + f.Flags |= AnonymousHeadFlags } } else if named == "" && f.Op == OpStructEnd { f.Op = OpStructAnonymousEnd } else if existsKey { diff := f.NextField.DisplayIdx - f.DisplayIdx - for i := 0; i < diff; i++ { + for i := uint32(0); i < diff; i++ { f.NextField.decOpcodeIndex() } linkPrevToNextField(f, removedFields) @@ -1179,7 +1206,7 @@ func anonymousStructFieldPairMap(tags runtime.StructTags, named string, valueCod anonymousFields[key] = append(anonymousFields[key], structFieldPair{ prevField: prevAnonymousField, curField: f, - isTaggedKey: f.IsTaggedKey, + isTaggedKey: (f.Flags & IsTaggedKeyFlags) != 0, }) if f.Next != nil && f.NextField != f.Next && f.Next.Op.CodeType() == CodeStructField { for k, v := range anonymousFieldPairRecursively(named, f.Next) { @@ -1200,12 +1227,12 @@ func anonymousFieldPairRecursively(named string, valueCode *Opcode) map[string][ f := valueCode var prevAnonymousField *Opcode for { - if f.DisplayKey != "" && f.AnonymousHead { + if f.DisplayKey != "" && (f.Flags&AnonymousHeadFlags) != 0 { key := fmt.Sprintf("%s.%s", named, f.DisplayKey) anonymousFields[key] = append(anonymousFields[key], structFieldPair{ prevField: prevAnonymousField, curField: f, - isTaggedKey: f.IsTaggedKey, + isTaggedKey: (f.Flags & IsTaggedKeyFlags) != 0, }) if f.Next != nil && f.NextField != f.Next && f.Next.Op.CodeType() == CodeStructField { for k, v := range anonymousFieldPairRecursively(named, f.Next) { @@ -1238,11 +1265,11 @@ func optimizeConflictAnonymousFields(anonymousFields map[string][]structFieldPai if fieldPair.prevField == nil { // head operation fieldPair.curField.Op = OpStructHead - fieldPair.curField.AnonymousHead = true - fieldPair.curField.AnonymousKey = true + fieldPair.curField.Flags |= AnonymousHeadFlags + fieldPair.curField.Flags |= AnonymousKeyFlags } else { diff := fieldPair.curField.NextField.DisplayIdx - fieldPair.curField.DisplayIdx - for i := 0; i < diff; i++ { + for i := uint32(0); i < diff; i++ { fieldPair.curField.NextField.decOpcodeIndex() } removedFields[fieldPair.curField] = struct{}{} @@ -1258,12 +1285,12 @@ func optimizeConflictAnonymousFields(anonymousFields map[string][]structFieldPai if fieldPair.prevField == nil { // head operation fieldPair.curField.Op = OpStructHead - fieldPair.curField.AnonymousHead = true - fieldPair.curField.AnonymousKey = true + fieldPair.curField.Flags |= AnonymousHeadFlags + fieldPair.curField.Flags |= AnonymousKeyFlags } else { diff := fieldPair.curField.NextField.DisplayIdx - fieldPair.curField.DisplayIdx removedFields[fieldPair.curField] = struct{}{} - for i := 0; i < diff; i++ { + for i := uint32(0); i < diff; i++ { fieldPair.curField.NextField.decOpcodeIndex() } linkPrevToNextField(fieldPair.curField, removedFields) @@ -1273,7 +1300,7 @@ func optimizeConflictAnonymousFields(anonymousFields map[string][]structFieldPai } } else { for _, fieldPair := range taggedPairs { - fieldPair.curField.IsTaggedKey = false + fieldPair.curField.Flags &= ^IsTaggedKeyFlags } } } @@ -1390,7 +1417,7 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) { valueCode = code } - if field.Anonymous { + if field.Anonymous && !tag.IsTaggedKey { tagKey := "" if tag.IsTaggedKey { tagKey = tag.Key @@ -1398,50 +1425,76 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) { for k, v := range anonymousStructFieldPairMap(tags, tagKey, valueCode) { anonymousFields[k] = append(anonymousFields[k], v...) } + valueCode.decIndent() // fix issue144 if !(isPtr && strings.Contains(valueCode.Op.String(), "Marshal")) { - valueCode.Indirect = indirect + if indirect { + valueCode.Flags |= IndirectFlags + } else { + valueCode.Flags &= ^IndirectFlags + } } } else { if indirect { // if parent is indirect type, set child indirect property to true - valueCode.Indirect = indirect + valueCode.Flags |= IndirectFlags } else { - // if parent is not indirect type and child have only one field, set child indirect property to false - if i == 0 && valueCode.NextField != nil && valueCode.NextField.Op == OpStructEnd { - valueCode.Indirect = indirect + // if parent is not indirect type, set child indirect property to false. + // but if parent's indirect is false and isPtr is true, then indirect must be true. + // Do this only if indirectConversion is enabled at the end of compileStruct. + if i == 0 { + valueCode.Flags &= ^IndirectFlags } } } - key := fmt.Sprintf(`"%s":`, tag.Key) - escapedKey := fmt.Sprintf(`%s:`, string(AppendEscapedString([]byte{}, tag.Key))) + var flags OpFlags + if indirect { + flags |= IndirectFlags + } + if field.Anonymous { + flags |= AnonymousKeyFlags + } + if tag.IsTaggedKey { + flags |= IsTaggedKeyFlags + } + if nilcheck { + flags |= NilCheckFlags + } + if addrForMarshaler { + flags |= AddrForMarshalerFlags + } + if strings.Contains(valueCode.Op.String(), "Ptr") || valueCode.Op == OpInterface { + flags |= IsNextOpPtrTypeFlags + } + if isNilableType { + flags |= IsNilableTypeFlags + } + var key string + if ctx.escapeKey { + rctx := &RuntimeContext{Option: &Option{Flag: HTMLEscapeOption}} + key = fmt.Sprintf(`%s:`, string(AppendString(rctx, []byte{}, tag.Key))) + } else { + key = fmt.Sprintf(`"%s":`, tag.Key) + } fieldCode := &Opcode{ - Type: valueCode.Type, - DisplayIdx: fieldOpcodeIndex, - Idx: opcodeOffset(fieldPtrIndex), - Next: valueCode, - Indent: ctx.indent, - AnonymousKey: field.Anonymous, - Key: []byte(key), - EscapedKey: []byte(escapedKey), - IsTaggedKey: tag.IsTaggedKey, - DisplayKey: tag.Key, - Offset: field.Offset, - Indirect: indirect, - Nilcheck: nilcheck, - AddrForMarshaler: addrForMarshaler, - IsNextOpPtrType: strings.Contains(valueCode.Op.String(), "Ptr") || valueCode.Op == OpInterface, - IsNilableType: isNilableType, + Idx: opcodeOffset(fieldPtrIndex), + Next: valueCode, + Flags: flags, + Key: key, + Offset: uint32(field.Offset), + Type: valueCode.Type, + DisplayIdx: fieldOpcodeIndex, + Indent: ctx.indent, + DisplayKey: tag.Key, } if fieldIdx == 0 { - fieldCode.HeadIdx = fieldCode.Idx code = structHeader(ctx, fieldCode, valueCode, tag) head = fieldCode prevField = fieldCode } else { - fieldCode.HeadIdx = head.HeadIdx + fieldCode.Idx = head.Idx code.Next = fieldCode code = structField(ctx, fieldCode, valueCode, tag) prevField.NextField = fieldCode @@ -1455,7 +1508,6 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) { Op: OpStructEnd, Type: nil, Indent: ctx.indent, - Next: newEndOp(ctx), } ctx = ctx.decIndent() @@ -1464,12 +1516,11 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) { if head == nil { head = &Opcode{ Op: OpStructHead, + Idx: opcodeOffset(ctx.ptrIndex), + NextField: structEndCode, Type: typ, DisplayIdx: ctx.opcodeIndex, - Idx: opcodeOffset(ctx.ptrIndex), - HeadIdx: opcodeOffset(ctx.ptrIndex), Indent: ctx.indent, - NextField: structEndCode, } structEndCode.PrevField = head ctx.incIndex() @@ -1479,6 +1530,7 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) { structEndCode.DisplayIdx = ctx.opcodeIndex structEndCode.Idx = opcodeOffset(ctx.ptrIndex) ctx.incIndex() + structEndCode.Next = newEndOp(ctx) if prevField != nil && prevField.NextField == nil { prevField.NextField = structEndCode @@ -1494,15 +1546,23 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) { delete(ctx.structTypeToCompiledCode, typeptr) - if !disableIndirectConversion && !head.Indirect && isPtr { - head.Indirect = true + if !disableIndirectConversion && (head.Flags&IndirectFlags == 0) && isPtr { + headCode := head + for strings.Contains(headCode.Op.String(), "Head") { + headCode.Flags |= IndirectFlags + headCode = headCode.Next + } } return ret, nil } +func implementsMarshalJSONType(typ *runtime.Type) bool { + return typ.Implements(marshalJSONType) || typ.Implements(marshalJSONContextType) +} + func isPtrMarshalJSONType(typ *runtime.Type) bool { - return !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType) + return !implementsMarshalJSONType(typ) && implementsMarshalJSONType(runtime.PtrTo(typ)) } func isPtrMarshalTextType(typ *runtime.Type) bool { |