summaryrefslogtreecommitdiff
path: root/vendor/github.com/goccy/go-json/internal/encoder/compiler.go
diff options
context:
space:
mode:
authorLibravatar Tobi Smethurst <31960611+tsmethurst@users.noreply.github.com>2021-08-12 21:03:24 +0200
committerLibravatar GitHub <noreply@github.com>2021-08-12 21:03:24 +0200
commit98263a7de64269898a2f81207e38943b5c8e8653 (patch)
tree743c90f109a6c5d27832d1dcef2388d939f0f77a /vendor/github.com/goccy/go-json/internal/encoder/compiler.go
parentText duplication fix (#137) (diff)
downloadgotosocial-98263a7de64269898a2f81207e38943b5c8e8653.tar.xz
Grand test fixup (#138)
* start fixing up tests * fix up tests + automate with drone * fiddle with linting * messing about with drone.yml * some more fiddling * hmmm * add cache * add vendor directory * verbose * ci updates * update some little things * update sig
Diffstat (limited to 'vendor/github.com/goccy/go-json/internal/encoder/compiler.go')
-rw-r--r--vendor/github.com/goccy/go-json/internal/encoder/compiler.go1510
1 files changed, 1510 insertions, 0 deletions
diff --git a/vendor/github.com/goccy/go-json/internal/encoder/compiler.go b/vendor/github.com/goccy/go-json/internal/encoder/compiler.go
new file mode 100644
index 000000000..486c80f18
--- /dev/null
+++ b/vendor/github.com/goccy/go-json/internal/encoder/compiler.go
@@ -0,0 +1,1510 @@
+package encoder
+
+import (
+ "encoding"
+ "encoding/json"
+ "fmt"
+ "reflect"
+ "strings"
+ "sync/atomic"
+ "unsafe"
+
+ "github.com/goccy/go-json/internal/errors"
+ "github.com/goccy/go-json/internal/runtime"
+)
+
+var (
+ marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
+ marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
+ jsonNumberType = reflect.TypeOf(json.Number(""))
+ cachedOpcodeSets []*OpcodeSet
+ cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet
+ typeAddr *runtime.TypeAddr
+)
+
+func init() {
+ typeAddr = runtime.AnalyzeTypeAddr()
+ if typeAddr == nil {
+ typeAddr = &runtime.TypeAddr{}
+ }
+ cachedOpcodeSets = make([]*OpcodeSet, typeAddr.AddrRange>>typeAddr.AddrShift)
+}
+
+func loadOpcodeMap() map[uintptr]*OpcodeSet {
+ p := atomic.LoadPointer(&cachedOpcodeMap)
+ return *(*map[uintptr]*OpcodeSet)(unsafe.Pointer(&p))
+}
+
+func storeOpcodeSet(typ uintptr, set *OpcodeSet, m map[uintptr]*OpcodeSet) {
+ newOpcodeMap := make(map[uintptr]*OpcodeSet, len(m)+1)
+ newOpcodeMap[typ] = set
+
+ for k, v := range m {
+ newOpcodeMap[k] = v
+ }
+
+ atomic.StorePointer(&cachedOpcodeMap, *(*unsafe.Pointer)(unsafe.Pointer(&newOpcodeMap)))
+}
+
+func compileToGetCodeSetSlowPath(typeptr uintptr) (*OpcodeSet, error) {
+ opcodeMap := loadOpcodeMap()
+ if codeSet, exists := opcodeMap[typeptr]; exists {
+ return codeSet, nil
+ }
+
+ // noescape trick for header.typ ( reflect.*rtype )
+ copiedType := *(**runtime.Type)(unsafe.Pointer(&typeptr))
+
+ code, err := compileHead(&compileContext{
+ typ: copiedType,
+ structTypeToCompiledCode: map[uintptr]*CompiledCode{},
+ })
+ if err != nil {
+ return nil, err
+ }
+ code = copyOpcode(code)
+ codeLength := code.TotalLength()
+ codeSet := &OpcodeSet{
+ Type: copiedType,
+ Code: code,
+ CodeLength: codeLength,
+ }
+ storeOpcodeSet(typeptr, codeSet, opcodeMap)
+ return codeSet, nil
+}
+
+func compileHead(ctx *compileContext) (*Opcode, error) {
+ typ := ctx.typ
+ switch {
+ case implementsMarshalJSON(typ):
+ return compileMarshalJSON(ctx)
+ case implementsMarshalText(typ):
+ return compileMarshalText(ctx)
+ }
+
+ isPtr := false
+ orgType := typ
+ if typ.Kind() == reflect.Ptr {
+ typ = typ.Elem()
+ isPtr = true
+ }
+ switch {
+ case implementsMarshalJSON(typ):
+ return compileMarshalJSON(ctx)
+ case implementsMarshalText(typ):
+ return compileMarshalText(ctx)
+ }
+ switch typ.Kind() {
+ case reflect.Slice:
+ ctx := ctx.withType(typ)
+ elem := typ.Elem()
+ if elem.Kind() == reflect.Uint8 {
+ p := runtime.PtrTo(elem)
+ if !p.Implements(marshalJSONType) && !p.Implements(marshalTextType) {
+ if isPtr {
+ return compileBytesPtr(ctx)
+ }
+ return compileBytes(ctx)
+ }
+ }
+ code, err := compileSlice(ctx)
+ if err != nil {
+ return nil, err
+ }
+ optimizeStructEnd(code)
+ linkRecursiveCode(code)
+ return code, nil
+ case reflect.Map:
+ if isPtr {
+ return compilePtr(ctx.withType(runtime.PtrTo(typ)))
+ }
+ code, err := compileMap(ctx.withType(typ))
+ if err != nil {
+ return nil, err
+ }
+ optimizeStructEnd(code)
+ linkRecursiveCode(code)
+ return code, nil
+ case reflect.Struct:
+ code, err := compileStruct(ctx.withType(typ), isPtr)
+ if err != nil {
+ return nil, err
+ }
+ optimizeStructEnd(code)
+ linkRecursiveCode(code)
+ return code, nil
+ case reflect.Int:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileIntPtr(ctx)
+ }
+ return compileInt(ctx)
+ case reflect.Int8:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileInt8Ptr(ctx)
+ }
+ return compileInt8(ctx)
+ case reflect.Int16:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileInt16Ptr(ctx)
+ }
+ return compileInt16(ctx)
+ case reflect.Int32:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileInt32Ptr(ctx)
+ }
+ return compileInt32(ctx)
+ case reflect.Int64:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileInt64Ptr(ctx)
+ }
+ return compileInt64(ctx)
+ case reflect.Uint, reflect.Uintptr:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileUintPtr(ctx)
+ }
+ return compileUint(ctx)
+ case reflect.Uint8:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileUint8Ptr(ctx)
+ }
+ return compileUint8(ctx)
+ case reflect.Uint16:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileUint16Ptr(ctx)
+ }
+ return compileUint16(ctx)
+ case reflect.Uint32:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileUint32Ptr(ctx)
+ }
+ return compileUint32(ctx)
+ case reflect.Uint64:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileUint64Ptr(ctx)
+ }
+ return compileUint64(ctx)
+ case reflect.Float32:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileFloat32Ptr(ctx)
+ }
+ return compileFloat32(ctx)
+ case reflect.Float64:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileFloat64Ptr(ctx)
+ }
+ return compileFloat64(ctx)
+ case reflect.String:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileStringPtr(ctx)
+ }
+ return compileString(ctx)
+ case reflect.Bool:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileBoolPtr(ctx)
+ }
+ return compileBool(ctx)
+ case reflect.Interface:
+ ctx := ctx.withType(typ)
+ if isPtr {
+ return compileInterfacePtr(ctx)
+ }
+ return compileInterface(ctx)
+ default:
+ if isPtr && typ.Implements(marshalTextType) {
+ typ = orgType
+ }
+ code, err := compile(ctx.withType(typ), isPtr)
+ if err != nil {
+ return nil, err
+ }
+ optimizeStructEnd(code)
+ linkRecursiveCode(code)
+ return code, nil
+ }
+}
+
+func linkRecursiveCode(c *Opcode) {
+ for code := c; code.Op != OpEnd && code.Op != OpRecursiveEnd; {
+ switch code.Op {
+ case OpRecursive, OpRecursivePtr:
+ if code.Jmp.Linked {
+ code = code.Next
+ continue
+ }
+ code.Jmp.Code = copyOpcode(code.Jmp.Code)
+ c := code.Jmp.Code
+ c.End.Next = newEndOp(&compileContext{})
+ c.Op = c.Op.PtrHeadToHead()
+
+ beforeLastCode := c.End
+ lastCode := beforeLastCode.Next
+
+ lastCode.Idx = beforeLastCode.Idx + uintptrSize
+ lastCode.ElemIdx = lastCode.Idx + uintptrSize
+ lastCode.Length = lastCode.Idx + 2*uintptrSize
+
+ // extend length to alloc slot for elemIdx + length
+ totalLength := uintptr(code.TotalLength() + 2)
+ nextTotalLength := uintptr(c.TotalLength() + 2)
+
+ c.End.Next.Op = OpRecursiveEnd
+
+ code.Jmp.CurLen = totalLength
+ code.Jmp.NextLen = nextTotalLength
+ code.Jmp.Linked = true
+
+ linkRecursiveCode(code.Jmp.Code)
+ code = code.Next
+ continue
+ }
+ switch code.Op.CodeType() {
+ case CodeArrayElem, CodeSliceElem, CodeMapKey:
+ code = code.End
+ default:
+ code = code.Next
+ }
+ }
+}
+
+func optimizeStructEnd(c *Opcode) {
+ for code := c; code.Op != OpEnd; {
+ if code.Op == OpRecursive || code.Op == OpRecursivePtr {
+ // ignore if exists recursive operation
+ return
+ }
+ switch code.Op.CodeType() {
+ case CodeArrayElem, CodeSliceElem, CodeMapKey:
+ code = code.End
+ default:
+ code = code.Next
+ }
+ }
+
+ for code := c; code.Op != OpEnd; {
+ switch code.Op.CodeType() {
+ case CodeArrayElem, CodeSliceElem, CodeMapKey:
+ code = code.End
+ case CodeStructEnd:
+ switch code.Op {
+ case OpStructEnd:
+ prev := code.PrevField
+ prevOp := prev.Op.String()
+ if strings.Contains(prevOp, "Head") ||
+ strings.Contains(prevOp, "Slice") ||
+ strings.Contains(prevOp, "Array") ||
+ strings.Contains(prevOp, "Map") ||
+ strings.Contains(prevOp, "MarshalJSON") ||
+ strings.Contains(prevOp, "MarshalText") {
+ // not exists field
+ code = code.Next
+ break
+ }
+ if prev.Op != prev.Op.FieldToEnd() {
+ prev.Op = prev.Op.FieldToEnd()
+ prev.Next = code.Next
+ }
+ code = code.Next
+ default:
+ code = code.Next
+ }
+ default:
+ code = code.Next
+ }
+ }
+}
+
+func implementsMarshalJSON(typ *runtime.Type) bool {
+ if !typ.Implements(marshalJSONType) {
+ return false
+ }
+ if typ.Kind() != reflect.Ptr {
+ return true
+ }
+ // type kind is reflect.Ptr
+ if !typ.Elem().Implements(marshalJSONType) {
+ return true
+ }
+ // needs to dereference
+ return false
+}
+
+func implementsMarshalText(typ *runtime.Type) bool {
+ if !typ.Implements(marshalTextType) {
+ return false
+ }
+ if typ.Kind() != reflect.Ptr {
+ return true
+ }
+ // type kind is reflect.Ptr
+ if !typ.Elem().Implements(marshalTextType) {
+ return true
+ }
+ // needs to dereference
+ return false
+}
+
+func compile(ctx *compileContext, isPtr bool) (*Opcode, error) {
+ typ := ctx.typ
+ switch {
+ case implementsMarshalJSON(typ):
+ return compileMarshalJSON(ctx)
+ case implementsMarshalText(typ):
+ return compileMarshalText(ctx)
+ }
+ switch typ.Kind() {
+ case reflect.Ptr:
+ return compilePtr(ctx)
+ case reflect.Slice:
+ elem := typ.Elem()
+ if elem.Kind() == reflect.Uint8 {
+ p := runtime.PtrTo(elem)
+ if !p.Implements(marshalJSONType) && !p.Implements(marshalTextType) {
+ return compileBytes(ctx)
+ }
+ }
+ return compileSlice(ctx)
+ case reflect.Array:
+ return compileArray(ctx)
+ case reflect.Map:
+ return compileMap(ctx)
+ case reflect.Struct:
+ return compileStruct(ctx, isPtr)
+ case reflect.Interface:
+ return compileInterface(ctx)
+ case reflect.Int:
+ return compileInt(ctx)
+ case reflect.Int8:
+ return compileInt8(ctx)
+ case reflect.Int16:
+ return compileInt16(ctx)
+ case reflect.Int32:
+ return compileInt32(ctx)
+ case reflect.Int64:
+ return compileInt64(ctx)
+ case reflect.Uint:
+ return compileUint(ctx)
+ case reflect.Uint8:
+ return compileUint8(ctx)
+ case reflect.Uint16:
+ return compileUint16(ctx)
+ case reflect.Uint32:
+ return compileUint32(ctx)
+ case reflect.Uint64:
+ return compileUint64(ctx)
+ case reflect.Uintptr:
+ return compileUint(ctx)
+ case reflect.Float32:
+ return compileFloat32(ctx)
+ case reflect.Float64:
+ return compileFloat64(ctx)
+ case reflect.String:
+ return compileString(ctx)
+ case reflect.Bool:
+ return compileBool(ctx)
+ }
+ return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)}
+}
+
+func convertPtrOp(code *Opcode) OpType {
+ ptrHeadOp := code.Op.HeadToPtrHead()
+ if code.Op != ptrHeadOp {
+ if code.PtrNum > 0 {
+ // ptr field and ptr head
+ code.PtrNum--
+ }
+ return ptrHeadOp
+ }
+ switch code.Op {
+ case OpInt:
+ return OpIntPtr
+ case OpUint:
+ return OpUintPtr
+ case OpFloat32:
+ return OpFloat32Ptr
+ case OpFloat64:
+ return OpFloat64Ptr
+ case OpString:
+ return OpStringPtr
+ case OpBool:
+ return OpBoolPtr
+ case OpBytes:
+ return OpBytesPtr
+ case OpNumber:
+ return OpNumberPtr
+ case OpArray:
+ return OpArrayPtr
+ case OpSlice:
+ return OpSlicePtr
+ case OpMap:
+ return OpMapPtr
+ case OpMarshalJSON:
+ return OpMarshalJSONPtr
+ case OpMarshalText:
+ return OpMarshalTextPtr
+ case OpInterface:
+ return OpInterfacePtr
+ case OpRecursive:
+ return OpRecursivePtr
+ }
+ return code.Op
+}
+
+func compileKey(ctx *compileContext) (*Opcode, error) {
+ typ := ctx.typ
+ switch {
+ case implementsMarshalJSON(typ):
+ return compileMarshalJSON(ctx)
+ case implementsMarshalText(typ):
+ return compileMarshalText(ctx)
+ }
+ switch typ.Kind() {
+ case reflect.Ptr:
+ return compilePtr(ctx)
+ case reflect.Interface:
+ return compileInterface(ctx)
+ case reflect.String:
+ return compileString(ctx)
+ case reflect.Int:
+ return compileIntString(ctx)
+ case reflect.Int8:
+ return compileInt8String(ctx)
+ case reflect.Int16:
+ return compileInt16String(ctx)
+ case reflect.Int32:
+ return compileInt32String(ctx)
+ case reflect.Int64:
+ return compileInt64String(ctx)
+ case reflect.Uint:
+ return compileUintString(ctx)
+ case reflect.Uint8:
+ return compileUint8String(ctx)
+ case reflect.Uint16:
+ return compileUint16String(ctx)
+ case reflect.Uint32:
+ return compileUint32String(ctx)
+ case reflect.Uint64:
+ return compileUint64String(ctx)
+ case reflect.Uintptr:
+ return compileUintString(ctx)
+ }
+ return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)}
+}
+
+func compilePtr(ctx *compileContext) (*Opcode, error) {
+ code, err := compile(ctx.withType(ctx.typ.Elem()), true)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = convertPtrOp(code)
+ code.PtrNum++
+ return code, nil
+}
+
+func compileMarshalJSON(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpMarshalJSON)
+ typ := ctx.typ
+ if !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType) {
+ code.AddrForMarshaler = true
+ }
+ code.IsNilableType = isNilableType(typ)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileMarshalText(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpMarshalText)
+ typ := ctx.typ
+ if !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType) {
+ code.AddrForMarshaler = true
+ }
+ code.IsNilableType = isNilableType(typ)
+ ctx.incIndex()
+ return code, nil
+}
+
+const intSize = 32 << (^uint(0) >> 63)
+
+func compileInt(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpInt)
+ code.setMaskAndRshiftNum(intSize)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileIntPtr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileInt(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpIntPtr
+ return code, nil
+}
+
+func compileInt8(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpInt)
+ code.setMaskAndRshiftNum(8)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileInt8Ptr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileInt8(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpIntPtr
+ return code, nil
+}
+
+func compileInt16(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpInt)
+ code.setMaskAndRshiftNum(16)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileInt16Ptr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileInt16(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpIntPtr
+ return code, nil
+}
+
+func compileInt32(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpInt)
+ code.setMaskAndRshiftNum(32)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileInt32Ptr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileInt32(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpIntPtr
+ return code, nil
+}
+
+func compileInt64(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpInt)
+ code.setMaskAndRshiftNum(64)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileInt64Ptr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileInt64(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpIntPtr
+ return code, nil
+}
+
+func compileUint(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpUint)
+ code.setMaskAndRshiftNum(intSize)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileUintPtr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileUint(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpUintPtr
+ return code, nil
+}
+
+func compileUint8(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpUint)
+ code.setMaskAndRshiftNum(8)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileUint8Ptr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileUint8(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpUintPtr
+ return code, nil
+}
+
+func compileUint16(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpUint)
+ code.setMaskAndRshiftNum(16)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileUint16Ptr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileUint16(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpUintPtr
+ return code, nil
+}
+
+func compileUint32(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpUint)
+ code.setMaskAndRshiftNum(32)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileUint32Ptr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileUint32(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpUintPtr
+ return code, nil
+}
+
+func compileUint64(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpUint)
+ code.setMaskAndRshiftNum(64)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileUint64Ptr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileUint64(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpUintPtr
+ return code, nil
+}
+
+func compileIntString(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpIntString)
+ code.setMaskAndRshiftNum(intSize)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileInt8String(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpIntString)
+ code.setMaskAndRshiftNum(8)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileInt16String(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpIntString)
+ code.setMaskAndRshiftNum(16)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileInt32String(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpIntString)
+ code.setMaskAndRshiftNum(32)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileInt64String(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpIntString)
+ code.setMaskAndRshiftNum(64)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileUintString(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpUintString)
+ code.setMaskAndRshiftNum(intSize)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileUint8String(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpUintString)
+ code.setMaskAndRshiftNum(8)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileUint16String(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpUintString)
+ code.setMaskAndRshiftNum(16)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileUint32String(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpUintString)
+ code.setMaskAndRshiftNum(32)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileUint64String(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpUintString)
+ code.setMaskAndRshiftNum(64)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileFloat32(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpFloat32)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileFloat32Ptr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileFloat32(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpFloat32Ptr
+ return code, nil
+}
+
+func compileFloat64(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpFloat64)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileFloat64Ptr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileFloat64(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpFloat64Ptr
+ return code, nil
+}
+
+func compileString(ctx *compileContext) (*Opcode, error) {
+ var op OpType
+ if ctx.typ == runtime.Type2RType(jsonNumberType) {
+ op = OpNumber
+ } else {
+ op = OpString
+ }
+ code := newOpCode(ctx, op)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileStringPtr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileString(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if code.Op == OpNumber {
+ code.Op = OpNumberPtr
+ } else {
+ code.Op = OpStringPtr
+ }
+ return code, nil
+}
+
+func compileBool(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpBool)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileBoolPtr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileBool(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpBoolPtr
+ return code, nil
+}
+
+func compileBytes(ctx *compileContext) (*Opcode, error) {
+ code := newOpCode(ctx, OpBytes)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileBytesPtr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileBytes(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpBytesPtr
+ return code, nil
+}
+
+func compileInterface(ctx *compileContext) (*Opcode, error) {
+ code := newInterfaceCode(ctx)
+ ctx.incIndex()
+ return code, nil
+}
+
+func compileInterfacePtr(ctx *compileContext) (*Opcode, error) {
+ code, err := compileInterface(ctx)
+ if err != nil {
+ return nil, err
+ }
+ code.Op = OpInterfacePtr
+ return code, nil
+}
+
+func compileSlice(ctx *compileContext) (*Opcode, error) {
+ elem := ctx.typ.Elem()
+ size := elem.Size()
+
+ header := newSliceHeaderCode(ctx)
+ ctx.incIndex()
+
+ code, err := compileListElem(ctx.withType(elem).incIndent())
+ if err != nil {
+ return nil, err
+ }
+ code.Indirect = true
+
+ // header => opcode => elem => end
+ // ^ |
+ // |________|
+
+ elemCode := newSliceElemCode(ctx, header, size)
+ ctx.incIndex()
+
+ end := newOpCode(ctx, OpSliceEnd)
+ ctx.incIndex()
+
+ header.Elem = elemCode
+ header.End = end
+ header.Next = code
+ code.BeforeLastCode().Next = (*Opcode)(unsafe.Pointer(elemCode))
+ elemCode.Next = code
+ elemCode.End = end
+ return (*Opcode)(unsafe.Pointer(header)), nil
+}
+
+func compileListElem(ctx *compileContext) (*Opcode, error) {
+ typ := ctx.typ
+ switch {
+ case !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType):
+ return compileMarshalJSON(ctx)
+ case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType):
+ return compileMarshalText(ctx)
+ case typ.Kind() == reflect.Map:
+ return compilePtr(ctx.withType(runtime.PtrTo(typ)))
+ default:
+ code, err := compile(ctx, false)
+ if err != nil {
+ return nil, err
+ }
+ if code.Op == OpMapPtr {
+ code.PtrNum++
+ }
+ return code, nil
+ }
+}
+
+func compileArray(ctx *compileContext) (*Opcode, error) {
+ typ := ctx.typ
+ elem := typ.Elem()
+ alen := typ.Len()
+ size := elem.Size()
+
+ header := newArrayHeaderCode(ctx, alen)
+ ctx.incIndex()
+
+ code, err := compileListElem(ctx.withType(elem).incIndent())
+ if err != nil {
+ return nil, err
+ }
+ code.Indirect = true
+ // header => opcode => elem => end
+ // ^ |
+ // |________|
+
+ elemCode := newArrayElemCode(ctx, header, alen, size)
+ ctx.incIndex()
+
+ end := newOpCode(ctx, OpArrayEnd)
+ ctx.incIndex()
+
+ header.Elem = elemCode
+ header.End = end
+ header.Next = code
+ code.BeforeLastCode().Next = (*Opcode)(unsafe.Pointer(elemCode))
+ elemCode.Next = code
+ elemCode.End = end
+ return (*Opcode)(unsafe.Pointer(header)), nil
+}
+
+func compileMap(ctx *compileContext) (*Opcode, error) {
+ // header => code => value => code => key => code => value => code => end
+ // ^ |
+ // |_______________________|
+ ctx = ctx.incIndent()
+ header := newMapHeaderCode(ctx)
+ ctx.incIndex()
+
+ typ := ctx.typ
+ keyType := ctx.typ.Key()
+ keyCode, err := compileKey(ctx.withType(keyType))
+ if err != nil {
+ return nil, err
+ }
+
+ value := newMapValueCode(ctx, header)
+ ctx.incIndex()
+
+ valueCode, err := compileMapValue(ctx.withType(typ.Elem()))
+ if err != nil {
+ return nil, err
+ }
+ valueCode.Indirect = true
+
+ key := newMapKeyCode(ctx, header)
+ ctx.incIndex()
+
+ ctx = ctx.decIndent()
+
+ header.MapKey = key
+ header.MapValue = value
+
+ end := newMapEndCode(ctx, header)
+ ctx.incIndex()
+
+ header.Next = keyCode
+ keyCode.BeforeLastCode().Next = (*Opcode)(unsafe.Pointer(value))
+ value.Next = valueCode
+ valueCode.BeforeLastCode().Next = (*Opcode)(unsafe.Pointer(key))
+ key.Next = keyCode
+
+ header.End = end
+ key.End = end
+ value.End = end
+
+ return (*Opcode)(unsafe.Pointer(header)), nil
+}
+
+func compileMapValue(ctx *compileContext) (*Opcode, error) {
+ switch ctx.typ.Kind() {
+ case reflect.Map:
+ return compilePtr(ctx.withType(runtime.PtrTo(ctx.typ)))
+ default:
+ code, err := compile(ctx, false)
+ if err != nil {
+ return nil, err
+ }
+ if code.Op == OpMapPtr {
+ code.PtrNum++
+ }
+ return code, nil
+ }
+}
+
+func optimizeStructHeader(code *Opcode, tag *runtime.StructTag) OpType {
+ headType := code.ToHeaderType(tag.IsString)
+ if tag.IsOmitEmpty {
+ headType = headType.HeadToOmitEmptyHead()
+ }
+ return headType
+}
+
+func optimizeStructField(code *Opcode, tag *runtime.StructTag) OpType {
+ fieldType := code.ToFieldType(tag.IsString)
+ if tag.IsOmitEmpty {
+ fieldType = fieldType.FieldToOmitEmptyField()
+ }
+ return fieldType
+}
+
+func recursiveCode(ctx *compileContext, jmp *CompiledCode) *Opcode {
+ code := newRecursiveCode(ctx, jmp)
+ ctx.incIndex()
+ return code
+}
+
+func compiledCode(ctx *compileContext) *Opcode {
+ typ := ctx.typ
+ typeptr := uintptr(unsafe.Pointer(typ))
+ if cc, exists := ctx.structTypeToCompiledCode[typeptr]; exists {
+ return recursiveCode(ctx, cc)
+ }
+ return nil
+}
+
+func structHeader(ctx *compileContext, fieldCode *Opcode, valueCode *Opcode, tag *runtime.StructTag) *Opcode {
+ op := optimizeStructHeader(valueCode, tag)
+ fieldCode.Op = op
+ fieldCode.Mask = valueCode.Mask
+ fieldCode.RshiftNum = valueCode.RshiftNum
+ fieldCode.PtrNum = valueCode.PtrNum
+ if op.IsMultipleOpHead() {
+ return valueCode.BeforeLastCode()
+ }
+ ctx.decOpcodeIndex()
+ return fieldCode
+}
+
+func structField(ctx *compileContext, fieldCode *Opcode, valueCode *Opcode, tag *runtime.StructTag) *Opcode {
+ op := optimizeStructField(valueCode, tag)
+ fieldCode.Op = op
+ fieldCode.PtrNum = valueCode.PtrNum
+ fieldCode.Mask = valueCode.Mask
+ fieldCode.RshiftNum = valueCode.RshiftNum
+ if op.IsMultipleOpField() {
+ return valueCode.BeforeLastCode()
+ }
+ ctx.decIndex()
+ return fieldCode
+}
+
+func isNotExistsField(head *Opcode) bool {
+ if head == nil {
+ return false
+ }
+ if head.Op != OpStructHead {
+ return false
+ }
+ if !head.AnonymousHead {
+ return false
+ }
+ if head.Next == nil {
+ return false
+ }
+ if head.NextField == nil {
+ return false
+ }
+ if head.NextField.Op != OpStructAnonymousEnd {
+ return false
+ }
+ if head.Next.Op == OpStructAnonymousEnd {
+ return true
+ }
+ if head.Next.Op.CodeType() != CodeStructField {
+ return false
+ }
+ return isNotExistsField(head.Next)
+}
+
+func optimizeAnonymousFields(head *Opcode) {
+ code := head
+ var prev *Opcode
+ removedFields := map[*Opcode]struct{}{}
+ for {
+ if code.Op == OpStructEnd {
+ break
+ }
+ if code.Op == OpStructField {
+ codeType := code.Next.Op.CodeType()
+ if codeType == CodeStructField {
+ if isNotExistsField(code.Next) {
+ code.Next = code.NextField
+ diff := code.Next.DisplayIdx - code.DisplayIdx
+ for i := 0; i < diff; i++ {
+ code.Next.decOpcodeIndex()
+ }
+ linkPrevToNextField(code, removedFields)
+ code = prev
+ }
+ }
+ }
+ prev = code
+ code = code.NextField
+ }
+}
+
+type structFieldPair struct {
+ prevField *Opcode
+ curField *Opcode
+ isTaggedKey bool
+ linked bool
+}
+
+func anonymousStructFieldPairMap(tags runtime.StructTags, named string, valueCode *Opcode) map[string][]structFieldPair {
+ anonymousFields := map[string][]structFieldPair{}
+ f := valueCode
+ var prevAnonymousField *Opcode
+ removedFields := map[*Opcode]struct{}{}
+ for {
+ existsKey := tags.ExistsKey(f.DisplayKey)
+ isHeadOp := strings.Contains(f.Op.String(), "Head")
+ if existsKey && f.Next != nil && strings.Contains(f.Next.Op.String(), "Recursive") {
+ // through
+ } else if isHeadOp && !f.AnonymousHead {
+ if existsKey {
+ // TODO: need to remove this head
+ f.Op = OpStructHead
+ f.AnonymousKey = true
+ f.AnonymousHead = true
+ } else if named == "" {
+ f.AnonymousHead = true
+ }
+ } else if named == "" && f.Op == OpStructEnd {
+ f.Op = OpStructAnonymousEnd
+ } else if existsKey {
+ diff := f.NextField.DisplayIdx - f.DisplayIdx
+ for i := 0; i < diff; i++ {
+ f.NextField.decOpcodeIndex()
+ }
+ linkPrevToNextField(f, removedFields)
+ }
+
+ if f.DisplayKey == "" {
+ if f.NextField == nil {
+ break
+ }
+ prevAnonymousField = f
+ f = f.NextField
+ continue
+ }
+
+ key := fmt.Sprintf("%s.%s", named, f.DisplayKey)
+ anonymousFields[key] = append(anonymousFields[key], structFieldPair{
+ prevField: prevAnonymousField,
+ curField: f,
+ isTaggedKey: f.IsTaggedKey,
+ })
+ if f.Next != nil && f.NextField != f.Next && f.Next.Op.CodeType() == CodeStructField {
+ for k, v := range anonymousFieldPairRecursively(named, f.Next) {
+ anonymousFields[k] = append(anonymousFields[k], v...)
+ }
+ }
+ if f.NextField == nil {
+ break
+ }
+ prevAnonymousField = f
+ f = f.NextField
+ }
+ return anonymousFields
+}
+
+func anonymousFieldPairRecursively(named string, valueCode *Opcode) map[string][]structFieldPair {
+ anonymousFields := map[string][]structFieldPair{}
+ f := valueCode
+ var prevAnonymousField *Opcode
+ for {
+ if f.DisplayKey != "" && f.AnonymousHead {
+ key := fmt.Sprintf("%s.%s", named, f.DisplayKey)
+ anonymousFields[key] = append(anonymousFields[key], structFieldPair{
+ prevField: prevAnonymousField,
+ curField: f,
+ isTaggedKey: f.IsTaggedKey,
+ })
+ if f.Next != nil && f.NextField != f.Next && f.Next.Op.CodeType() == CodeStructField {
+ for k, v := range anonymousFieldPairRecursively(named, f.Next) {
+ anonymousFields[k] = append(anonymousFields[k], v...)
+ }
+ }
+ }
+ if f.NextField == nil {
+ break
+ }
+ prevAnonymousField = f
+ f = f.NextField
+ }
+ return anonymousFields
+}
+
+func optimizeConflictAnonymousFields(anonymousFields map[string][]structFieldPair) {
+ removedFields := map[*Opcode]struct{}{}
+ for _, fieldPairs := range anonymousFields {
+ if len(fieldPairs) == 1 {
+ continue
+ }
+ // conflict anonymous fields
+ taggedPairs := []structFieldPair{}
+ for _, fieldPair := range fieldPairs {
+ if fieldPair.isTaggedKey {
+ taggedPairs = append(taggedPairs, fieldPair)
+ } else {
+ if !fieldPair.linked {
+ if fieldPair.prevField == nil {
+ // head operation
+ fieldPair.curField.Op = OpStructHead
+ fieldPair.curField.AnonymousHead = true
+ fieldPair.curField.AnonymousKey = true
+ } else {
+ diff := fieldPair.curField.NextField.DisplayIdx - fieldPair.curField.DisplayIdx
+ for i := 0; i < diff; i++ {
+ fieldPair.curField.NextField.decOpcodeIndex()
+ }
+ removedFields[fieldPair.curField] = struct{}{}
+ linkPrevToNextField(fieldPair.curField, removedFields)
+ }
+ fieldPair.linked = true
+ }
+ }
+ }
+ if len(taggedPairs) > 1 {
+ for _, fieldPair := range taggedPairs {
+ if !fieldPair.linked {
+ if fieldPair.prevField == nil {
+ // head operation
+ fieldPair.curField.Op = OpStructHead
+ fieldPair.curField.AnonymousHead = true
+ fieldPair.curField.AnonymousKey = true
+ } else {
+ diff := fieldPair.curField.NextField.DisplayIdx - fieldPair.curField.DisplayIdx
+ removedFields[fieldPair.curField] = struct{}{}
+ for i := 0; i < diff; i++ {
+ fieldPair.curField.NextField.decOpcodeIndex()
+ }
+ linkPrevToNextField(fieldPair.curField, removedFields)
+ }
+ fieldPair.linked = true
+ }
+ }
+ } else {
+ for _, fieldPair := range taggedPairs {
+ fieldPair.curField.IsTaggedKey = false
+ }
+ }
+ }
+}
+
+func isNilableType(typ *runtime.Type) bool {
+ switch typ.Kind() {
+ case reflect.Ptr:
+ return true
+ case reflect.Map:
+ return true
+ case reflect.Func:
+ return true
+ default:
+ return false
+ }
+}
+
+func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) {
+ if code := compiledCode(ctx); code != nil {
+ return code, nil
+ }
+ typ := ctx.typ
+ typeptr := uintptr(unsafe.Pointer(typ))
+ compiled := &CompiledCode{}
+ ctx.structTypeToCompiledCode[typeptr] = compiled
+ // header => code => structField => code => end
+ // ^ |
+ // |__________|
+ fieldNum := typ.NumField()
+ indirect := runtime.IfaceIndir(typ)
+ fieldIdx := 0
+ disableIndirectConversion := false
+ var (
+ head *Opcode
+ code *Opcode
+ prevField *Opcode
+ )
+ ctx = ctx.incIndent()
+ tags := runtime.StructTags{}
+ anonymousFields := map[string][]structFieldPair{}
+ for i := 0; i < fieldNum; i++ {
+ field := typ.Field(i)
+ if runtime.IsIgnoredStructField(field) {
+ continue
+ }
+ tags = append(tags, runtime.StructTagFromField(field))
+ }
+ for i, tag := range tags {
+ field := tag.Field
+ fieldType := runtime.Type2RType(field.Type)
+ fieldOpcodeIndex := ctx.opcodeIndex
+ fieldPtrIndex := ctx.ptrIndex
+ ctx.incIndex()
+
+ nilcheck := true
+ addrForMarshaler := false
+ isIndirectSpecialCase := isPtr && i == 0 && fieldNum == 1
+ isNilableType := isNilableType(fieldType)
+
+ var valueCode *Opcode
+ switch {
+ case isIndirectSpecialCase && !isNilableType && isPtrMarshalJSONType(fieldType):
+ // *struct{ field T } => struct { field *T }
+ // func (*T) MarshalJSON() ([]byte, error)
+ // move pointer position from head to first field
+ code, err := compileMarshalJSON(ctx.withType(fieldType))
+ if err != nil {
+ return nil, err
+ }
+ addrForMarshaler = true
+ valueCode = code
+ nilcheck = false
+ indirect = false
+ disableIndirectConversion = true
+ case isIndirectSpecialCase && !isNilableType && isPtrMarshalTextType(fieldType):
+ // *struct{ field T } => struct { field *T }
+ // func (*T) MarshalText() ([]byte, error)
+ // move pointer position from head to first field
+ code, err := compileMarshalText(ctx.withType(fieldType))
+ if err != nil {
+ return nil, err
+ }
+ addrForMarshaler = true
+ valueCode = code
+ nilcheck = false
+ indirect = false
+ disableIndirectConversion = true
+ case isPtr && isPtrMarshalJSONType(fieldType):
+ // *struct{ field T }
+ // func (*T) MarshalJSON() ([]byte, error)
+ code, err := compileMarshalJSON(ctx.withType(fieldType))
+ if err != nil {
+ return nil, err
+ }
+ addrForMarshaler = true
+ nilcheck = false
+ valueCode = code
+ case isPtr && isPtrMarshalTextType(fieldType):
+ // *struct{ field T }
+ // func (*T) MarshalText() ([]byte, error)
+ code, err := compileMarshalText(ctx.withType(fieldType))
+ if err != nil {
+ return nil, err
+ }
+ addrForMarshaler = true
+ nilcheck = false
+ valueCode = code
+ default:
+ code, err := compile(ctx.withType(fieldType), isPtr)
+ if err != nil {
+ return nil, err
+ }
+ valueCode = code
+ }
+
+ if field.Anonymous {
+ tagKey := ""
+ if tag.IsTaggedKey {
+ tagKey = tag.Key
+ }
+ for k, v := range anonymousStructFieldPairMap(tags, tagKey, valueCode) {
+ anonymousFields[k] = append(anonymousFields[k], v...)
+ }
+ valueCode.decIndent()
+
+ // fix issue144
+ if !(isPtr && strings.Contains(valueCode.Op.String(), "Marshal")) {
+ valueCode.Indirect = indirect
+ }
+ } else {
+ if indirect {
+ // if parent is indirect type, set child indirect property to true
+ valueCode.Indirect = indirect
+ } else {
+ // if parent is not indirect type and child have only one field, set child indirect property to false
+ if i == 0 && valueCode.NextField != nil && valueCode.NextField.Op == OpStructEnd {
+ valueCode.Indirect = indirect
+ }
+ }
+ }
+ key := fmt.Sprintf(`"%s":`, tag.Key)
+ escapedKey := fmt.Sprintf(`%s:`, string(AppendEscapedString([]byte{}, tag.Key)))
+ fieldCode := &Opcode{
+ Type: valueCode.Type,
+ DisplayIdx: fieldOpcodeIndex,
+ Idx: opcodeOffset(fieldPtrIndex),
+ Next: valueCode,
+ Indent: ctx.indent,
+ AnonymousKey: field.Anonymous,
+ Key: []byte(key),
+ EscapedKey: []byte(escapedKey),
+ IsTaggedKey: tag.IsTaggedKey,
+ DisplayKey: tag.Key,
+ Offset: field.Offset,
+ Indirect: indirect,
+ Nilcheck: nilcheck,
+ AddrForMarshaler: addrForMarshaler,
+ IsNextOpPtrType: strings.Contains(valueCode.Op.String(), "Ptr") || valueCode.Op == OpInterface,
+ IsNilableType: isNilableType,
+ }
+ if fieldIdx == 0 {
+ fieldCode.HeadIdx = fieldCode.Idx
+ code = structHeader(ctx, fieldCode, valueCode, tag)
+ head = fieldCode
+ prevField = fieldCode
+ } else {
+ fieldCode.HeadIdx = head.HeadIdx
+ code.Next = fieldCode
+ code = structField(ctx, fieldCode, valueCode, tag)
+ prevField.NextField = fieldCode
+ fieldCode.PrevField = prevField
+ prevField = fieldCode
+ }
+ fieldIdx++
+ }
+
+ structEndCode := &Opcode{
+ Op: OpStructEnd,
+ Type: nil,
+ Indent: ctx.indent,
+ Next: newEndOp(ctx),
+ }
+
+ ctx = ctx.decIndent()
+
+ // no struct field
+ if head == nil {
+ head = &Opcode{
+ Op: OpStructHead,
+ Type: typ,
+ DisplayIdx: ctx.opcodeIndex,
+ Idx: opcodeOffset(ctx.ptrIndex),
+ HeadIdx: opcodeOffset(ctx.ptrIndex),
+ Indent: ctx.indent,
+ NextField: structEndCode,
+ }
+ structEndCode.PrevField = head
+ ctx.incIndex()
+ code = head
+ }
+
+ structEndCode.DisplayIdx = ctx.opcodeIndex
+ structEndCode.Idx = opcodeOffset(ctx.ptrIndex)
+ ctx.incIndex()
+
+ if prevField != nil && prevField.NextField == nil {
+ prevField.NextField = structEndCode
+ structEndCode.PrevField = prevField
+ }
+
+ head.End = structEndCode
+ code.Next = structEndCode
+ optimizeConflictAnonymousFields(anonymousFields)
+ optimizeAnonymousFields(head)
+ ret := (*Opcode)(unsafe.Pointer(head))
+ compiled.Code = ret
+
+ delete(ctx.structTypeToCompiledCode, typeptr)
+
+ if !disableIndirectConversion && !head.Indirect && isPtr {
+ head.Indirect = true
+ }
+
+ return ret, nil
+}
+
+func isPtrMarshalJSONType(typ *runtime.Type) bool {
+ return !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType)
+}
+
+func isPtrMarshalTextType(typ *runtime.Type) bool {
+ return !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType)
+}