summaryrefslogtreecommitdiff
path: root/vendor/github.com/abema/go-mp4/marshaller.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/abema/go-mp4/marshaller.go')
-rw-r--r--vendor/github.com/abema/go-mp4/marshaller.go639
1 files changed, 639 insertions, 0 deletions
diff --git a/vendor/github.com/abema/go-mp4/marshaller.go b/vendor/github.com/abema/go-mp4/marshaller.go
new file mode 100644
index 000000000..b5c66860b
--- /dev/null
+++ b/vendor/github.com/abema/go-mp4/marshaller.go
@@ -0,0 +1,639 @@
+package mp4
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "reflect"
+
+ "github.com/abema/go-mp4/bitio"
+)
+
+const (
+ anyVersion = math.MaxUint8
+)
+
+var ErrUnsupportedBoxVersion = errors.New("unsupported box version")
+
+type marshaller struct {
+ writer bitio.Writer
+ wbits uint64
+ src IImmutableBox
+ ctx Context
+}
+
+func Marshal(w io.Writer, src IImmutableBox, ctx Context) (n uint64, err error) {
+ boxDef := src.GetType().getBoxDef(ctx)
+ if boxDef == nil {
+ return 0, ErrBoxInfoNotFound
+ }
+
+ v := reflect.ValueOf(src).Elem()
+
+ m := &marshaller{
+ writer: bitio.NewWriter(w),
+ src: src,
+ ctx: ctx,
+ }
+
+ if err := m.marshalStruct(v, boxDef.fields); err != nil {
+ return 0, err
+ }
+
+ if m.wbits%8 != 0 {
+ return 0, fmt.Errorf("box size is not multiple of 8 bits: type=%s, bits=%d", src.GetType().String(), m.wbits)
+ }
+
+ return m.wbits / 8, nil
+}
+
+func (m *marshaller) marshal(v reflect.Value, fi *fieldInstance) error {
+ switch v.Type().Kind() {
+ case reflect.Ptr:
+ return m.marshalPtr(v, fi)
+ case reflect.Struct:
+ return m.marshalStruct(v, fi.children)
+ case reflect.Array:
+ return m.marshalArray(v, fi)
+ case reflect.Slice:
+ return m.marshalSlice(v, fi)
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return m.marshalInt(v, fi)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ return m.marshalUint(v, fi)
+ case reflect.Bool:
+ return m.marshalBool(v, fi)
+ case reflect.String:
+ return m.marshalString(v)
+ default:
+ return fmt.Errorf("unsupported type: %s", v.Type().Kind())
+ }
+}
+
+func (m *marshaller) marshalPtr(v reflect.Value, fi *fieldInstance) error {
+ return m.marshal(v.Elem(), fi)
+}
+
+func (m *marshaller) marshalStruct(v reflect.Value, fs []*field) error {
+ for _, f := range fs {
+ fi := resolveFieldInstance(f, m.src, v, m.ctx)
+
+ if !isTargetField(m.src, fi, m.ctx) {
+ continue
+ }
+
+ wbits, override, err := fi.cfo.OnWriteField(f.name, m.writer, m.ctx)
+ if err != nil {
+ return err
+ }
+ m.wbits += wbits
+ if override {
+ continue
+ }
+
+ err = m.marshal(v.FieldByName(f.name), fi)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (m *marshaller) marshalArray(v reflect.Value, fi *fieldInstance) error {
+ size := v.Type().Size()
+ for i := 0; i < int(size)/int(v.Type().Elem().Size()); i++ {
+ var err error
+ err = m.marshal(v.Index(i), fi)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (m *marshaller) marshalSlice(v reflect.Value, fi *fieldInstance) error {
+ length := uint64(v.Len())
+ if fi.length != LengthUnlimited {
+ if length < uint64(fi.length) {
+ return fmt.Errorf("the slice has too few elements: required=%d actual=%d", fi.length, length)
+ }
+ length = uint64(fi.length)
+ }
+
+ elemType := v.Type().Elem()
+ if elemType.Kind() == reflect.Uint8 && fi.size == 8 && m.wbits%8 == 0 {
+ if _, err := io.CopyN(m.writer, bytes.NewBuffer(v.Bytes()), int64(length)); err != nil {
+ return err
+ }
+ m.wbits += length * 8
+ return nil
+ }
+
+ for i := 0; i < int(length); i++ {
+ m.marshal(v.Index(i), fi)
+ }
+ return nil
+}
+
+func (m *marshaller) marshalInt(v reflect.Value, fi *fieldInstance) error {
+ signed := v.Int()
+
+ if fi.is(fieldVarint) {
+ return errors.New("signed varint is unsupported")
+ }
+
+ signBit := signed < 0
+ val := uint64(signed)
+ for i := uint(0); i < fi.size; i += 8 {
+ v := val
+ size := uint(8)
+ if fi.size > i+8 {
+ v = v >> (fi.size - (i + 8))
+ } else if fi.size < i+8 {
+ size = fi.size - i
+ }
+
+ // set sign bit
+ if i == 0 {
+ if signBit {
+ v |= 0x1 << (size - 1)
+ } else {
+ v &= 0x1<<(size-1) - 1
+ }
+ }
+
+ if err := m.writer.WriteBits([]byte{byte(v)}, size); err != nil {
+ return err
+ }
+ m.wbits += uint64(size)
+ }
+
+ return nil
+}
+
+func (m *marshaller) marshalUint(v reflect.Value, fi *fieldInstance) error {
+ val := v.Uint()
+
+ if fi.is(fieldVarint) {
+ m.writeUvarint(val)
+ return nil
+ }
+
+ for i := uint(0); i < fi.size; i += 8 {
+ v := val
+ size := uint(8)
+ if fi.size > i+8 {
+ v = v >> (fi.size - (i + 8))
+ } else if fi.size < i+8 {
+ size = fi.size - i
+ }
+ if err := m.writer.WriteBits([]byte{byte(v)}, size); err != nil {
+ return err
+ }
+ m.wbits += uint64(size)
+ }
+
+ return nil
+}
+
+func (m *marshaller) marshalBool(v reflect.Value, fi *fieldInstance) error {
+ var val byte
+ if v.Bool() {
+ val = 0xff
+ } else {
+ val = 0x00
+ }
+ if err := m.writer.WriteBits([]byte{val}, fi.size); err != nil {
+ return err
+ }
+ m.wbits += uint64(fi.size)
+ return nil
+}
+
+func (m *marshaller) marshalString(v reflect.Value) error {
+ data := []byte(v.String())
+ for _, b := range data {
+ if err := m.writer.WriteBits([]byte{b}, 8); err != nil {
+ return err
+ }
+ m.wbits += 8
+ }
+ // null character
+ if err := m.writer.WriteBits([]byte{0x00}, 8); err != nil {
+ return err
+ }
+ m.wbits += 8
+ return nil
+}
+
+func (m *marshaller) writeUvarint(u uint64) error {
+ for i := 21; i > 0; i -= 7 {
+ if err := m.writer.WriteBits([]byte{(byte(u >> uint(i))) | 0x80}, 8); err != nil {
+ return err
+ }
+ m.wbits += 8
+ }
+
+ if err := m.writer.WriteBits([]byte{byte(u) & 0x7f}, 8); err != nil {
+ return err
+ }
+ m.wbits += 8
+
+ return nil
+}
+
+type unmarshaller struct {
+ reader bitio.ReadSeeker
+ dst IBox
+ size uint64
+ rbits uint64
+ ctx Context
+}
+
+func UnmarshalAny(r io.ReadSeeker, boxType BoxType, payloadSize uint64, ctx Context) (box IBox, n uint64, err error) {
+ dst, err := boxType.New(ctx)
+ if err != nil {
+ return nil, 0, err
+ }
+ n, err = Unmarshal(r, payloadSize, dst, ctx)
+ return dst, n, err
+}
+
+func Unmarshal(r io.ReadSeeker, payloadSize uint64, dst IBox, ctx Context) (n uint64, err error) {
+ boxDef := dst.GetType().getBoxDef(ctx)
+ if boxDef == nil {
+ return 0, ErrBoxInfoNotFound
+ }
+
+ v := reflect.ValueOf(dst).Elem()
+
+ dst.SetVersion(anyVersion)
+
+ u := &unmarshaller{
+ reader: bitio.NewReadSeeker(r),
+ dst: dst,
+ size: payloadSize,
+ ctx: ctx,
+ }
+
+ if n, override, err := dst.BeforeUnmarshal(r, payloadSize, u.ctx); err != nil {
+ return 0, err
+ } else if override {
+ return n, nil
+ } else {
+ u.rbits = n * 8
+ }
+
+ sn, err := r.Seek(0, io.SeekCurrent)
+ if err != nil {
+ return 0, err
+ }
+
+ if err := u.unmarshalStruct(v, boxDef.fields); err != nil {
+ if err == ErrUnsupportedBoxVersion {
+ r.Seek(sn, io.SeekStart)
+ }
+ return 0, err
+ }
+
+ if u.rbits%8 != 0 {
+ return 0, fmt.Errorf("box size is not multiple of 8 bits: type=%s, size=%d, bits=%d", dst.GetType().String(), u.size, u.rbits)
+ }
+
+ if u.rbits > u.size*8 {
+ return 0, fmt.Errorf("overrun error: type=%s, size=%d, bits=%d", dst.GetType().String(), u.size, u.rbits)
+ }
+
+ return u.rbits / 8, nil
+}
+
+func (u *unmarshaller) unmarshal(v reflect.Value, fi *fieldInstance) error {
+ var err error
+ switch v.Type().Kind() {
+ case reflect.Ptr:
+ err = u.unmarshalPtr(v, fi)
+ case reflect.Struct:
+ err = u.unmarshalStructInternal(v, fi)
+ case reflect.Array:
+ err = u.unmarshalArray(v, fi)
+ case reflect.Slice:
+ err = u.unmarshalSlice(v, fi)
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ err = u.unmarshalInt(v, fi)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ err = u.unmarshalUint(v, fi)
+ case reflect.Bool:
+ err = u.unmarshalBool(v, fi)
+ case reflect.String:
+ err = u.unmarshalString(v, fi)
+ default:
+ return fmt.Errorf("unsupported type: %s", v.Type().Kind())
+ }
+ return err
+}
+
+func (u *unmarshaller) unmarshalPtr(v reflect.Value, fi *fieldInstance) error {
+ v.Set(reflect.New(v.Type().Elem()))
+ return u.unmarshal(v.Elem(), fi)
+}
+
+func (u *unmarshaller) unmarshalStructInternal(v reflect.Value, fi *fieldInstance) error {
+ if fi.size != 0 && fi.size%8 == 0 {
+ u2 := *u
+ u2.size = uint64(fi.size / 8)
+ u2.rbits = 0
+ if err := u2.unmarshalStruct(v, fi.children); err != nil {
+ return err
+ }
+ u.rbits += u2.rbits
+ if u2.rbits != uint64(fi.size) {
+ return errors.New("invalid alignment")
+ }
+ return nil
+ }
+
+ return u.unmarshalStruct(v, fi.children)
+}
+
+func (u *unmarshaller) unmarshalStruct(v reflect.Value, fs []*field) error {
+ for _, f := range fs {
+ fi := resolveFieldInstance(f, u.dst, v, u.ctx)
+
+ if !isTargetField(u.dst, fi, u.ctx) {
+ continue
+ }
+
+ rbits, override, err := fi.cfo.OnReadField(f.name, u.reader, u.size*8-u.rbits, u.ctx)
+ if err != nil {
+ return err
+ }
+ u.rbits += rbits
+ if override {
+ continue
+ }
+
+ err = u.unmarshal(v.FieldByName(f.name), fi)
+ if err != nil {
+ return err
+ }
+
+ if v.FieldByName(f.name).Type() == reflect.TypeOf(FullBox{}) && !u.dst.GetType().IsSupportedVersion(u.dst.GetVersion(), u.ctx) {
+ return ErrUnsupportedBoxVersion
+ }
+ }
+
+ return nil
+}
+
+func (u *unmarshaller) unmarshalArray(v reflect.Value, fi *fieldInstance) error {
+ size := v.Type().Size()
+ for i := 0; i < int(size)/int(v.Type().Elem().Size()); i++ {
+ var err error
+ err = u.unmarshal(v.Index(i), fi)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (u *unmarshaller) unmarshalSlice(v reflect.Value, fi *fieldInstance) error {
+ var slice reflect.Value
+ elemType := v.Type().Elem()
+
+ length := uint64(fi.length)
+ if fi.length == LengthUnlimited {
+ if fi.size != 0 {
+ left := (u.size)*8 - u.rbits
+ if left%uint64(fi.size) != 0 {
+ return errors.New("invalid alignment")
+ }
+ length = left / uint64(fi.size)
+ } else {
+ length = 0
+ }
+ }
+
+ if length > math.MaxInt32 {
+ return fmt.Errorf("out of memory: requestedSize=%d", length)
+ }
+
+ if fi.size != 0 && fi.size%8 == 0 && u.rbits%8 == 0 && elemType.Kind() == reflect.Uint8 && fi.size == 8 {
+ totalSize := length * uint64(fi.size) / 8
+ buf := bytes.NewBuffer(make([]byte, 0, totalSize))
+ if _, err := io.CopyN(buf, u.reader, int64(totalSize)); err != nil {
+ return err
+ }
+ slice = reflect.ValueOf(buf.Bytes())
+ u.rbits += uint64(totalSize) * 8
+
+ } else {
+ slice = reflect.MakeSlice(v.Type(), 0, int(length))
+ for i := 0; ; i++ {
+ if fi.length != LengthUnlimited && uint(i) >= fi.length {
+ break
+ }
+ if fi.length == LengthUnlimited && u.rbits >= u.size*8 {
+ break
+ }
+ slice = reflect.Append(slice, reflect.Zero(elemType))
+ if err := u.unmarshal(slice.Index(i), fi); err != nil {
+ return err
+ }
+ if u.rbits > u.size*8 {
+ return fmt.Errorf("failed to read array completely: fieldName=\"%s\"", fi.name)
+ }
+ }
+ }
+
+ v.Set(slice)
+ return nil
+}
+
+func (u *unmarshaller) unmarshalInt(v reflect.Value, fi *fieldInstance) error {
+ if fi.is(fieldVarint) {
+ return errors.New("signed varint is unsupported")
+ }
+
+ if fi.size == 0 {
+ return fmt.Errorf("size must not be zero: %s", fi.name)
+ }
+
+ data, err := u.reader.ReadBits(fi.size)
+ if err != nil {
+ return err
+ }
+ u.rbits += uint64(fi.size)
+
+ signBit := false
+ if len(data) > 0 {
+ signMask := byte(0x01) << ((fi.size - 1) % 8)
+ signBit = data[0]&signMask != 0
+ if signBit {
+ data[0] |= ^(signMask - 1)
+ }
+ }
+
+ var val uint64
+ if signBit {
+ val = ^uint64(0)
+ }
+ for i := range data {
+ val <<= 8
+ val |= uint64(data[i])
+ }
+ v.SetInt(int64(val))
+ return nil
+}
+
+func (u *unmarshaller) unmarshalUint(v reflect.Value, fi *fieldInstance) error {
+ if fi.is(fieldVarint) {
+ val, err := u.readUvarint()
+ if err != nil {
+ return err
+ }
+ v.SetUint(val)
+ return nil
+ }
+
+ if fi.size == 0 {
+ return fmt.Errorf("size must not be zero: %s", fi.name)
+ }
+
+ data, err := u.reader.ReadBits(fi.size)
+ if err != nil {
+ return err
+ }
+ u.rbits += uint64(fi.size)
+
+ val := uint64(0)
+ for i := range data {
+ val <<= 8
+ val |= uint64(data[i])
+ }
+ v.SetUint(val)
+
+ return nil
+}
+
+func (u *unmarshaller) unmarshalBool(v reflect.Value, fi *fieldInstance) error {
+ if fi.size == 0 {
+ return fmt.Errorf("size must not be zero: %s", fi.name)
+ }
+
+ data, err := u.reader.ReadBits(fi.size)
+ if err != nil {
+ return err
+ }
+ u.rbits += uint64(fi.size)
+
+ val := false
+ for _, b := range data {
+ val = val || (b != byte(0))
+ }
+ v.SetBool(val)
+
+ return nil
+}
+
+func (u *unmarshaller) unmarshalString(v reflect.Value, fi *fieldInstance) error {
+ switch fi.strType {
+ case stringType_C:
+ return u.unmarshalStringC(v)
+ case stringType_C_P:
+ return u.unmarshalStringCP(v, fi)
+ default:
+ return fmt.Errorf("unknown string type: %d", fi.strType)
+ }
+}
+
+func (u *unmarshaller) unmarshalStringC(v reflect.Value) error {
+ data := make([]byte, 0, 16)
+ for {
+ if u.rbits >= u.size*8 {
+ break
+ }
+
+ c, err := u.reader.ReadBits(8)
+ if err != nil {
+ return err
+ }
+ u.rbits += 8
+
+ if c[0] == 0 {
+ break // null character
+ }
+
+ data = append(data, c[0])
+ }
+ v.SetString(string(data))
+
+ return nil
+}
+
+func (u *unmarshaller) unmarshalStringCP(v reflect.Value, fi *fieldInstance) error {
+ if ok, err := u.tryReadPString(v, fi); err != nil {
+ return err
+ } else if ok {
+ return nil
+ }
+ return u.unmarshalStringC(v)
+}
+
+func (u *unmarshaller) tryReadPString(v reflect.Value, fi *fieldInstance) (ok bool, err error) {
+ remainingSize := (u.size*8 - u.rbits) / 8
+ if remainingSize < 2 {
+ return false, nil
+ }
+
+ offset, err := u.reader.Seek(0, io.SeekCurrent)
+ if err != nil {
+ return false, err
+ }
+ defer func() {
+ if err == nil && !ok {
+ _, err = u.reader.Seek(offset, io.SeekStart)
+ }
+ }()
+
+ buf0 := make([]byte, 1)
+ if _, err := io.ReadFull(u.reader, buf0); err != nil {
+ return false, err
+ }
+ remainingSize--
+ plen := buf0[0]
+ if uint64(plen) > remainingSize {
+ return false, nil
+ }
+ buf := make([]byte, int(plen))
+ if _, err := io.ReadFull(u.reader, buf); err != nil {
+ return false, err
+ }
+ remainingSize -= uint64(plen)
+ if fi.cfo.IsPString(fi.name, buf, remainingSize, u.ctx) {
+ u.rbits += uint64(len(buf)+1) * 8
+ v.SetString(string(buf))
+ return true, nil
+ }
+ return false, nil
+}
+
+func (u *unmarshaller) readUvarint() (uint64, error) {
+ var val uint64
+ for {
+ octet, err := u.reader.ReadBits(8)
+ if err != nil {
+ return 0, err
+ }
+ u.rbits += 8
+
+ val = (val << 7) + uint64(octet[0]&0x7f)
+
+ if octet[0]&0x80 == 0 {
+ return val, nil
+ }
+ }
+}