summaryrefslogtreecommitdiff
path: root/vendor/github.com/goccy/go-json/internal/encoder/compiler.go
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2021-09-10 14:42:14 +0200
committerLibravatar GitHub <noreply@github.com>2021-09-10 14:42:14 +0200
commitf2e5bedea6fb93fbbf68ed8f7153c353cc57a9f0 (patch)
tree475ae9e7470d0df670ab2a59dce351cd1d07498a /vendor/github.com/goccy/go-json/internal/encoder/compiler.go
parentfixes + db changes (#204) (diff)
downloadgotosocial-f2e5bedea6fb93fbbf68ed8f7153c353cc57a9f0.tar.xz
migrate go version to 1.17 (#203)
* migrate go version to 1.17 * update contributing
Diffstat (limited to 'vendor/github.com/goccy/go-json/internal/encoder/compiler.go')
-rw-r--r--vendor/github.com/goccy/go-json/internal/encoder/compiler.go276
1 files changed, 168 insertions, 108 deletions
diff --git a/vendor/github.com/goccy/go-json/internal/encoder/compiler.go b/vendor/github.com/goccy/go-json/internal/encoder/compiler.go
index 486c80f18..c627ed307 100644
--- a/vendor/github.com/goccy/go-json/internal/encoder/compiler.go
+++ b/vendor/github.com/goccy/go-json/internal/encoder/compiler.go
@@ -1,6 +1,7 @@
package encoder
import (
+ "context"
"encoding"
"encoding/json"
"fmt"
@@ -13,13 +14,18 @@ import (
"github.com/goccy/go-json/internal/runtime"
)
+type marshalerContext interface {
+ MarshalJSON(context.Context) ([]byte, error)
+}
+
var (
- marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
- marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
- jsonNumberType = reflect.TypeOf(json.Number(""))
- cachedOpcodeSets []*OpcodeSet
- cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet
- typeAddr *runtime.TypeAddr
+ marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
+ marshalJSONContextType = reflect.TypeOf((*marshalerContext)(nil)).Elem()
+ marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
+ jsonNumberType = reflect.TypeOf(json.Number(""))
+ cachedOpcodeSets []*OpcodeSet
+ cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet
+ typeAddr *runtime.TypeAddr
)
func init() {
@@ -55,19 +61,36 @@ func compileToGetCodeSetSlowPath(typeptr uintptr) (*OpcodeSet, error) {
// noescape trick for header.typ ( reflect.*rtype )
copiedType := *(**runtime.Type)(unsafe.Pointer(&typeptr))
- code, err := compileHead(&compileContext{
+ noescapeKeyCode, err := compileHead(&compileContext{
+ typ: copiedType,
+ structTypeToCompiledCode: map[uintptr]*CompiledCode{},
+ })
+ if err != nil {
+ return nil, err
+ }
+ escapeKeyCode, err := compileHead(&compileContext{
typ: copiedType,
structTypeToCompiledCode: map[uintptr]*CompiledCode{},
+ escapeKey: true,
})
if err != nil {
return nil, err
}
- code = copyOpcode(code)
- codeLength := code.TotalLength()
+ noescapeKeyCode = copyOpcode(noescapeKeyCode)
+ escapeKeyCode = copyOpcode(escapeKeyCode)
+ setTotalLengthToInterfaceOp(noescapeKeyCode)
+ setTotalLengthToInterfaceOp(escapeKeyCode)
+ interfaceNoescapeKeyCode := copyToInterfaceOpcode(noescapeKeyCode)
+ interfaceEscapeKeyCode := copyToInterfaceOpcode(escapeKeyCode)
+ codeLength := noescapeKeyCode.TotalLength()
codeSet := &OpcodeSet{
- Type: copiedType,
- Code: code,
- CodeLength: codeLength,
+ Type: copiedType,
+ NoescapeKeyCode: noescapeKeyCode,
+ EscapeKeyCode: escapeKeyCode,
+ InterfaceNoescapeKeyCode: interfaceNoescapeKeyCode,
+ InterfaceEscapeKeyCode: interfaceEscapeKeyCode,
+ CodeLength: codeLength,
+ EndCode: ToEndCode(interfaceNoescapeKeyCode),
}
storeOpcodeSet(typeptr, codeSet, opcodeMap)
return codeSet, nil
@@ -100,7 +123,7 @@ func compileHead(ctx *compileContext) (*Opcode, error) {
elem := typ.Elem()
if elem.Kind() == reflect.Uint8 {
p := runtime.PtrTo(elem)
- if !p.Implements(marshalJSONType) && !p.Implements(marshalTextType) {
+ if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
if isPtr {
return compileBytesPtr(ctx)
}
@@ -246,6 +269,7 @@ func linkRecursiveCode(c *Opcode) {
continue
}
code.Jmp.Code = copyOpcode(code.Jmp.Code)
+
c := code.Jmp.Code
c.End.Next = newEndOp(&compileContext{})
c.Op = c.Op.PtrHeadToHead()
@@ -258,8 +282,8 @@ func linkRecursiveCode(c *Opcode) {
lastCode.Length = lastCode.Idx + 2*uintptrSize
// extend length to alloc slot for elemIdx + length
- totalLength := uintptr(code.TotalLength() + 2)
- nextTotalLength := uintptr(c.TotalLength() + 2)
+ totalLength := uintptr(code.TotalLength() + 3)
+ nextTotalLength := uintptr(c.TotalLength() + 3)
c.End.Next.Op = OpRecursiveEnd
@@ -268,6 +292,7 @@ func linkRecursiveCode(c *Opcode) {
code.Jmp.Linked = true
linkRecursiveCode(code.Jmp.Code)
+
code = code.Next
continue
}
@@ -328,14 +353,14 @@ func optimizeStructEnd(c *Opcode) {
}
func implementsMarshalJSON(typ *runtime.Type) bool {
- if !typ.Implements(marshalJSONType) {
+ if !implementsMarshalJSONType(typ) {
return false
}
if typ.Kind() != reflect.Ptr {
return true
}
// type kind is reflect.Ptr
- if !typ.Elem().Implements(marshalJSONType) {
+ if !implementsMarshalJSONType(typ.Elem()) {
return true
}
// needs to dereference
@@ -372,7 +397,7 @@ func compile(ctx *compileContext, isPtr bool) (*Opcode, error) {
elem := typ.Elem()
if elem.Kind() == reflect.Uint8 {
p := runtime.PtrTo(elem)
- if !p.Implements(marshalJSONType) && !p.Implements(marshalTextType) {
+ if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
return compileBytes(ctx)
}
}
@@ -474,8 +499,6 @@ func compileKey(ctx *compileContext) (*Opcode, error) {
switch typ.Kind() {
case reflect.Ptr:
return compilePtr(ctx)
- case reflect.Interface:
- return compileInterface(ctx)
case reflect.String:
return compileString(ctx)
case reflect.Int:
@@ -517,10 +540,17 @@ func compilePtr(ctx *compileContext) (*Opcode, error) {
func compileMarshalJSON(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpMarshalJSON)
typ := ctx.typ
- if !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType) {
- code.AddrForMarshaler = true
+ if isPtrMarshalJSONType(typ) {
+ code.Flags |= AddrForMarshalerFlags
+ }
+ if typ.Implements(marshalJSONContextType) || runtime.PtrTo(typ).Implements(marshalJSONContextType) {
+ code.Flags |= MarshalerContextFlags
+ }
+ if isNilableType(typ) {
+ code.Flags |= IsNilableTypeFlags
+ } else {
+ code.Flags &= ^IsNilableTypeFlags
}
- code.IsNilableType = isNilableType(typ)
ctx.incIndex()
return code, nil
}
@@ -529,9 +559,13 @@ func compileMarshalText(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpMarshalText)
typ := ctx.typ
if !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType) {
- code.AddrForMarshaler = true
+ code.Flags |= AddrForMarshalerFlags
+ }
+ if isNilableType(typ) {
+ code.Flags |= IsNilableTypeFlags
+ } else {
+ code.Flags &= ^IsNilableTypeFlags
}
- code.IsNilableType = isNilableType(typ)
ctx.incIndex()
return code, nil
}
@@ -540,7 +574,7 @@ const intSize = 32 << (^uint(0) >> 63)
func compileInt(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpInt)
- code.setMaskAndRshiftNum(intSize)
+ code.NumBitSize = intSize
ctx.incIndex()
return code, nil
}
@@ -556,7 +590,7 @@ func compileIntPtr(ctx *compileContext) (*Opcode, error) {
func compileInt8(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpInt)
- code.setMaskAndRshiftNum(8)
+ code.NumBitSize = 8
ctx.incIndex()
return code, nil
}
@@ -572,7 +606,7 @@ func compileInt8Ptr(ctx *compileContext) (*Opcode, error) {
func compileInt16(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpInt)
- code.setMaskAndRshiftNum(16)
+ code.NumBitSize = 16
ctx.incIndex()
return code, nil
}
@@ -588,7 +622,7 @@ func compileInt16Ptr(ctx *compileContext) (*Opcode, error) {
func compileInt32(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpInt)
- code.setMaskAndRshiftNum(32)
+ code.NumBitSize = 32
ctx.incIndex()
return code, nil
}
@@ -604,7 +638,7 @@ func compileInt32Ptr(ctx *compileContext) (*Opcode, error) {
func compileInt64(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpInt)
- code.setMaskAndRshiftNum(64)
+ code.NumBitSize = 64
ctx.incIndex()
return code, nil
}
@@ -620,7 +654,7 @@ func compileInt64Ptr(ctx *compileContext) (*Opcode, error) {
func compileUint(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpUint)
- code.setMaskAndRshiftNum(intSize)
+ code.NumBitSize = intSize
ctx.incIndex()
return code, nil
}
@@ -636,7 +670,7 @@ func compileUintPtr(ctx *compileContext) (*Opcode, error) {
func compileUint8(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpUint)
- code.setMaskAndRshiftNum(8)
+ code.NumBitSize = 8
ctx.incIndex()
return code, nil
}
@@ -652,7 +686,7 @@ func compileUint8Ptr(ctx *compileContext) (*Opcode, error) {
func compileUint16(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpUint)
- code.setMaskAndRshiftNum(16)
+ code.NumBitSize = 16
ctx.incIndex()
return code, nil
}
@@ -668,7 +702,7 @@ func compileUint16Ptr(ctx *compileContext) (*Opcode, error) {
func compileUint32(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpUint)
- code.setMaskAndRshiftNum(32)
+ code.NumBitSize = 32
ctx.incIndex()
return code, nil
}
@@ -684,7 +718,7 @@ func compileUint32Ptr(ctx *compileContext) (*Opcode, error) {
func compileUint64(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpUint)
- code.setMaskAndRshiftNum(64)
+ code.NumBitSize = 64
ctx.incIndex()
return code, nil
}
@@ -700,70 +734,70 @@ func compileUint64Ptr(ctx *compileContext) (*Opcode, error) {
func compileIntString(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpIntString)
- code.setMaskAndRshiftNum(intSize)
+ code.NumBitSize = intSize
ctx.incIndex()
return code, nil
}
func compileInt8String(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpIntString)
- code.setMaskAndRshiftNum(8)
+ code.NumBitSize = 8
ctx.incIndex()
return code, nil
}
func compileInt16String(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpIntString)
- code.setMaskAndRshiftNum(16)
+ code.NumBitSize = 16
ctx.incIndex()
return code, nil
}
func compileInt32String(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpIntString)
- code.setMaskAndRshiftNum(32)
+ code.NumBitSize = 32
ctx.incIndex()
return code, nil
}
func compileInt64String(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpIntString)
- code.setMaskAndRshiftNum(64)
+ code.NumBitSize = 64
ctx.incIndex()
return code, nil
}
func compileUintString(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpUintString)
- code.setMaskAndRshiftNum(intSize)
+ code.NumBitSize = intSize
ctx.incIndex()
return code, nil
}
func compileUint8String(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpUintString)
- code.setMaskAndRshiftNum(8)
+ code.NumBitSize = 8
ctx.incIndex()
return code, nil
}
func compileUint16String(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpUintString)
- code.setMaskAndRshiftNum(16)
+ code.NumBitSize = 16
ctx.incIndex()
return code, nil
}
func compileUint32String(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpUintString)
- code.setMaskAndRshiftNum(32)
+ code.NumBitSize = 32
ctx.incIndex()
return code, nil
}
func compileUint64String(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpUintString)
- code.setMaskAndRshiftNum(64)
+ code.NumBitSize = 64
ctx.incIndex()
return code, nil
}
@@ -879,7 +913,7 @@ func compileSlice(ctx *compileContext) (*Opcode, error) {
if err != nil {
return nil, err
}
- code.Indirect = true
+ code.Flags |= IndirectFlags
// header => opcode => elem => end
// ^ |
@@ -891,7 +925,6 @@ func compileSlice(ctx *compileContext) (*Opcode, error) {
end := newOpCode(ctx, OpSliceEnd)
ctx.incIndex()
- header.Elem = elemCode
header.End = end
header.Next = code
code.BeforeLastCode().Next = (*Opcode)(unsafe.Pointer(elemCode))
@@ -903,7 +936,7 @@ func compileSlice(ctx *compileContext) (*Opcode, error) {
func compileListElem(ctx *compileContext) (*Opcode, error) {
typ := ctx.typ
switch {
- case !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType):
+ case isPtrMarshalJSONType(typ):
return compileMarshalJSON(ctx)
case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType):
return compileMarshalText(ctx)
@@ -934,7 +967,7 @@ func compileArray(ctx *compileContext) (*Opcode, error) {
if err != nil {
return nil, err
}
- code.Indirect = true
+ code.Flags |= IndirectFlags
// header => opcode => elem => end
// ^ |
// |________|
@@ -945,7 +978,6 @@ func compileArray(ctx *compileContext) (*Opcode, error) {
end := newOpCode(ctx, OpArrayEnd)
ctx.incIndex()
- header.Elem = elemCode
header.End = end
header.Next = code
code.BeforeLastCode().Next = (*Opcode)(unsafe.Pointer(elemCode))
@@ -976,16 +1008,13 @@ func compileMap(ctx *compileContext) (*Opcode, error) {
if err != nil {
return nil, err
}
- valueCode.Indirect = true
+ valueCode.Flags |= IndirectFlags
key := newMapKeyCode(ctx, header)
ctx.incIndex()
ctx = ctx.decIndent()
- header.MapKey = key
- header.MapValue = value
-
end := newMapEndCode(ctx, header)
ctx.incIndex()
@@ -1052,8 +1081,7 @@ func compiledCode(ctx *compileContext) *Opcode {
func structHeader(ctx *compileContext, fieldCode *Opcode, valueCode *Opcode, tag *runtime.StructTag) *Opcode {
op := optimizeStructHeader(valueCode, tag)
fieldCode.Op = op
- fieldCode.Mask = valueCode.Mask
- fieldCode.RshiftNum = valueCode.RshiftNum
+ fieldCode.NumBitSize = valueCode.NumBitSize
fieldCode.PtrNum = valueCode.PtrNum
if op.IsMultipleOpHead() {
return valueCode.BeforeLastCode()
@@ -1065,9 +1093,8 @@ func structHeader(ctx *compileContext, fieldCode *Opcode, valueCode *Opcode, tag
func structField(ctx *compileContext, fieldCode *Opcode, valueCode *Opcode, tag *runtime.StructTag) *Opcode {
op := optimizeStructField(valueCode, tag)
fieldCode.Op = op
+ fieldCode.NumBitSize = valueCode.NumBitSize
fieldCode.PtrNum = valueCode.PtrNum
- fieldCode.Mask = valueCode.Mask
- fieldCode.RshiftNum = valueCode.RshiftNum
if op.IsMultipleOpField() {
return valueCode.BeforeLastCode()
}
@@ -1082,7 +1109,7 @@ func isNotExistsField(head *Opcode) bool {
if head.Op != OpStructHead {
return false
}
- if !head.AnonymousHead {
+ if (head.Flags & AnonymousHeadFlags) == 0 {
return false
}
if head.Next == nil {
@@ -1117,7 +1144,7 @@ func optimizeAnonymousFields(head *Opcode) {
if isNotExistsField(code.Next) {
code.Next = code.NextField
diff := code.Next.DisplayIdx - code.DisplayIdx
- for i := 0; i < diff; i++ {
+ for i := uint32(0); i < diff; i++ {
code.Next.decOpcodeIndex()
}
linkPrevToNextField(code, removedFields)
@@ -1147,20 +1174,20 @@ func anonymousStructFieldPairMap(tags runtime.StructTags, named string, valueCod
isHeadOp := strings.Contains(f.Op.String(), "Head")
if existsKey && f.Next != nil && strings.Contains(f.Next.Op.String(), "Recursive") {
// through
- } else if isHeadOp && !f.AnonymousHead {
+ } else if isHeadOp && (f.Flags&AnonymousHeadFlags) == 0 {
if existsKey {
// TODO: need to remove this head
f.Op = OpStructHead
- f.AnonymousKey = true
- f.AnonymousHead = true
+ f.Flags |= AnonymousKeyFlags
+ f.Flags |= AnonymousHeadFlags
} else if named == "" {
- f.AnonymousHead = true
+ f.Flags |= AnonymousHeadFlags
}
} else if named == "" && f.Op == OpStructEnd {
f.Op = OpStructAnonymousEnd
} else if existsKey {
diff := f.NextField.DisplayIdx - f.DisplayIdx
- for i := 0; i < diff; i++ {
+ for i := uint32(0); i < diff; i++ {
f.NextField.decOpcodeIndex()
}
linkPrevToNextField(f, removedFields)
@@ -1179,7 +1206,7 @@ func anonymousStructFieldPairMap(tags runtime.StructTags, named string, valueCod
anonymousFields[key] = append(anonymousFields[key], structFieldPair{
prevField: prevAnonymousField,
curField: f,
- isTaggedKey: f.IsTaggedKey,
+ isTaggedKey: (f.Flags & IsTaggedKeyFlags) != 0,
})
if f.Next != nil && f.NextField != f.Next && f.Next.Op.CodeType() == CodeStructField {
for k, v := range anonymousFieldPairRecursively(named, f.Next) {
@@ -1200,12 +1227,12 @@ func anonymousFieldPairRecursively(named string, valueCode *Opcode) map[string][
f := valueCode
var prevAnonymousField *Opcode
for {
- if f.DisplayKey != "" && f.AnonymousHead {
+ if f.DisplayKey != "" && (f.Flags&AnonymousHeadFlags) != 0 {
key := fmt.Sprintf("%s.%s", named, f.DisplayKey)
anonymousFields[key] = append(anonymousFields[key], structFieldPair{
prevField: prevAnonymousField,
curField: f,
- isTaggedKey: f.IsTaggedKey,
+ isTaggedKey: (f.Flags & IsTaggedKeyFlags) != 0,
})
if f.Next != nil && f.NextField != f.Next && f.Next.Op.CodeType() == CodeStructField {
for k, v := range anonymousFieldPairRecursively(named, f.Next) {
@@ -1238,11 +1265,11 @@ func optimizeConflictAnonymousFields(anonymousFields map[string][]structFieldPai
if fieldPair.prevField == nil {
// head operation
fieldPair.curField.Op = OpStructHead
- fieldPair.curField.AnonymousHead = true
- fieldPair.curField.AnonymousKey = true
+ fieldPair.curField.Flags |= AnonymousHeadFlags
+ fieldPair.curField.Flags |= AnonymousKeyFlags
} else {
diff := fieldPair.curField.NextField.DisplayIdx - fieldPair.curField.DisplayIdx
- for i := 0; i < diff; i++ {
+ for i := uint32(0); i < diff; i++ {
fieldPair.curField.NextField.decOpcodeIndex()
}
removedFields[fieldPair.curField] = struct{}{}
@@ -1258,12 +1285,12 @@ func optimizeConflictAnonymousFields(anonymousFields map[string][]structFieldPai
if fieldPair.prevField == nil {
// head operation
fieldPair.curField.Op = OpStructHead
- fieldPair.curField.AnonymousHead = true
- fieldPair.curField.AnonymousKey = true
+ fieldPair.curField.Flags |= AnonymousHeadFlags
+ fieldPair.curField.Flags |= AnonymousKeyFlags
} else {
diff := fieldPair.curField.NextField.DisplayIdx - fieldPair.curField.DisplayIdx
removedFields[fieldPair.curField] = struct{}{}
- for i := 0; i < diff; i++ {
+ for i := uint32(0); i < diff; i++ {
fieldPair.curField.NextField.decOpcodeIndex()
}
linkPrevToNextField(fieldPair.curField, removedFields)
@@ -1273,7 +1300,7 @@ func optimizeConflictAnonymousFields(anonymousFields map[string][]structFieldPai
}
} else {
for _, fieldPair := range taggedPairs {
- fieldPair.curField.IsTaggedKey = false
+ fieldPair.curField.Flags &= ^IsTaggedKeyFlags
}
}
}
@@ -1390,7 +1417,7 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) {
valueCode = code
}
- if field.Anonymous {
+ if field.Anonymous && !tag.IsTaggedKey {
tagKey := ""
if tag.IsTaggedKey {
tagKey = tag.Key
@@ -1398,50 +1425,76 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) {
for k, v := range anonymousStructFieldPairMap(tags, tagKey, valueCode) {
anonymousFields[k] = append(anonymousFields[k], v...)
}
+
valueCode.decIndent()
// fix issue144
if !(isPtr && strings.Contains(valueCode.Op.String(), "Marshal")) {
- valueCode.Indirect = indirect
+ if indirect {
+ valueCode.Flags |= IndirectFlags
+ } else {
+ valueCode.Flags &= ^IndirectFlags
+ }
}
} else {
if indirect {
// if parent is indirect type, set child indirect property to true
- valueCode.Indirect = indirect
+ valueCode.Flags |= IndirectFlags
} else {
- // if parent is not indirect type and child have only one field, set child indirect property to false
- if i == 0 && valueCode.NextField != nil && valueCode.NextField.Op == OpStructEnd {
- valueCode.Indirect = indirect
+ // if parent is not indirect type, set child indirect property to false.
+ // but if parent's indirect is false and isPtr is true, then indirect must be true.
+ // Do this only if indirectConversion is enabled at the end of compileStruct.
+ if i == 0 {
+ valueCode.Flags &= ^IndirectFlags
}
}
}
- key := fmt.Sprintf(`"%s":`, tag.Key)
- escapedKey := fmt.Sprintf(`%s:`, string(AppendEscapedString([]byte{}, tag.Key)))
+ var flags OpFlags
+ if indirect {
+ flags |= IndirectFlags
+ }
+ if field.Anonymous {
+ flags |= AnonymousKeyFlags
+ }
+ if tag.IsTaggedKey {
+ flags |= IsTaggedKeyFlags
+ }
+ if nilcheck {
+ flags |= NilCheckFlags
+ }
+ if addrForMarshaler {
+ flags |= AddrForMarshalerFlags
+ }
+ if strings.Contains(valueCode.Op.String(), "Ptr") || valueCode.Op == OpInterface {
+ flags |= IsNextOpPtrTypeFlags
+ }
+ if isNilableType {
+ flags |= IsNilableTypeFlags
+ }
+ var key string
+ if ctx.escapeKey {
+ rctx := &RuntimeContext{Option: &Option{Flag: HTMLEscapeOption}}
+ key = fmt.Sprintf(`%s:`, string(AppendString(rctx, []byte{}, tag.Key)))
+ } else {
+ key = fmt.Sprintf(`"%s":`, tag.Key)
+ }
fieldCode := &Opcode{
- Type: valueCode.Type,
- DisplayIdx: fieldOpcodeIndex,
- Idx: opcodeOffset(fieldPtrIndex),
- Next: valueCode,
- Indent: ctx.indent,
- AnonymousKey: field.Anonymous,
- Key: []byte(key),
- EscapedKey: []byte(escapedKey),
- IsTaggedKey: tag.IsTaggedKey,
- DisplayKey: tag.Key,
- Offset: field.Offset,
- Indirect: indirect,
- Nilcheck: nilcheck,
- AddrForMarshaler: addrForMarshaler,
- IsNextOpPtrType: strings.Contains(valueCode.Op.String(), "Ptr") || valueCode.Op == OpInterface,
- IsNilableType: isNilableType,
+ Idx: opcodeOffset(fieldPtrIndex),
+ Next: valueCode,
+ Flags: flags,
+ Key: key,
+ Offset: uint32(field.Offset),
+ Type: valueCode.Type,
+ DisplayIdx: fieldOpcodeIndex,
+ Indent: ctx.indent,
+ DisplayKey: tag.Key,
}
if fieldIdx == 0 {
- fieldCode.HeadIdx = fieldCode.Idx
code = structHeader(ctx, fieldCode, valueCode, tag)
head = fieldCode
prevField = fieldCode
} else {
- fieldCode.HeadIdx = head.HeadIdx
+ fieldCode.Idx = head.Idx
code.Next = fieldCode
code = structField(ctx, fieldCode, valueCode, tag)
prevField.NextField = fieldCode
@@ -1455,7 +1508,6 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) {
Op: OpStructEnd,
Type: nil,
Indent: ctx.indent,
- Next: newEndOp(ctx),
}
ctx = ctx.decIndent()
@@ -1464,12 +1516,11 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) {
if head == nil {
head = &Opcode{
Op: OpStructHead,
+ Idx: opcodeOffset(ctx.ptrIndex),
+ NextField: structEndCode,
Type: typ,
DisplayIdx: ctx.opcodeIndex,
- Idx: opcodeOffset(ctx.ptrIndex),
- HeadIdx: opcodeOffset(ctx.ptrIndex),
Indent: ctx.indent,
- NextField: structEndCode,
}
structEndCode.PrevField = head
ctx.incIndex()
@@ -1479,6 +1530,7 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) {
structEndCode.DisplayIdx = ctx.opcodeIndex
structEndCode.Idx = opcodeOffset(ctx.ptrIndex)
ctx.incIndex()
+ structEndCode.Next = newEndOp(ctx)
if prevField != nil && prevField.NextField == nil {
prevField.NextField = structEndCode
@@ -1494,15 +1546,23 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) {
delete(ctx.structTypeToCompiledCode, typeptr)
- if !disableIndirectConversion && !head.Indirect && isPtr {
- head.Indirect = true
+ if !disableIndirectConversion && (head.Flags&IndirectFlags == 0) && isPtr {
+ headCode := head
+ for strings.Contains(headCode.Op.String(), "Head") {
+ headCode.Flags |= IndirectFlags
+ headCode = headCode.Next
+ }
}
return ret, nil
}
+func implementsMarshalJSONType(typ *runtime.Type) bool {
+ return typ.Implements(marshalJSONType) || typ.Implements(marshalJSONContextType)
+}
+
func isPtrMarshalJSONType(typ *runtime.Type) bool {
- return !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType)
+ return !implementsMarshalJSONType(typ) && implementsMarshalJSONType(runtime.PtrTo(typ))
}
func isPtrMarshalTextType(typ *runtime.Type) bool {