diff --git a/changed_fields.go b/changed_fields.go new file mode 100644 index 0000000..405f966 --- /dev/null +++ b/changed_fields.go @@ -0,0 +1,37 @@ +package meta + +type ChangedFields struct { + fieldNames map[string]struct{} +} + +func NewChangedFields(fieldCount int) ChangedFields { + cf := ChangedFields{ + fieldNames: make(map[string]struct{}, fieldCount), + } + return cf +} + +func (cf *ChangedFields) Reset() { + if cf.fieldNames == nil { + cf.fieldNames = make(map[string]struct{}) + } else { + for k := range cf.fieldNames { + delete(cf.fieldNames, k) + } + } +} + +func (cf ChangedFields) Changed(field string) bool { + _, ok := cf.fieldNames[field] + return ok +} + +func (cf *ChangedFields) SetChanged(fields ...string) { + for _, field := range fields { + cf.fieldNames[field] = struct{}{} + } +} + +func (cf ChangedFields) Empty() bool { + return len(cf.fieldNames) == 0 +} diff --git a/fields_mask.go b/fields_mask.go index 02e0f4d..2179582 100644 --- a/fields_mask.go +++ b/fields_mask.go @@ -2,49 +2,49 @@ package meta const ( FieldsMaskCapacity = 4 - FieldsMaskItemBitSize = 64 + FieldsMaskPartBitSize = 64 ) func MakeFieldsMaskFromInt64(v int64) FieldsMask { var mask FieldsMask - mask.SetItemFromInt64(0, v) + mask.SetPartFromInt64(0, v) return mask } -type FieldsMaskItem uint64 +type FieldsMaskPart uint64 -func (fmi FieldsMaskItem) FieldIsDirty(index uint64) bool { - return (1< 0 { return true } @@ -52,10 +52,10 @@ func (fm FieldsMask) IsFilled() bool { return false } -func (fm FieldsMask) itemIndex(index uint64) uint64 { - return index / FieldsMaskItemBitSize +func (fm FieldsMask) partIndex(index uint64) uint64 { + return index / FieldsMaskPartBitSize } -func (fm FieldsMask) maskIndex(index uint64) uint64 { - return index % FieldsMaskItemBitSize +func (fm FieldsMask) fieldIndex(index uint64) uint64 { + return index % FieldsMaskPartBitSize } diff --git a/fields_mask_test.go b/fields_mask_test.go index f56e102..476559b 100644 --- a/fields_mask_test.go +++ b/fields_mask_test.go @@ -19,7 +19,7 @@ func TestFieldsMask(t *testing.T) { }) t.Run("filled value", func(t *testing.T) { var mask meta.FieldsMask - mask.SetItemFromInt64(0, 1) + mask.SetPartFromInt64(0, 1) actualIsFilled := mask.IsFilled() require.True(t, actualIsFilled) @@ -38,7 +38,7 @@ func TestFieldsMask(t *testing.T) { var mask meta.FieldsMask var fieldIndex uint64 = 4 - mask.SetItemFromInt64(0, 16) + mask.SetPartFromInt64(0, 16) fieldChanged := mask.FieldChanged(fieldIndex) require.True(t, fieldChanged) @@ -47,7 +47,7 @@ func TestFieldsMask(t *testing.T) { var mask meta.FieldsMask var fieldIndex uint64 = 68 - mask.SetItemFromInt64(1, 16) + mask.SetPartFromInt64(1, 16) fieldChanged := mask.FieldChanged(fieldIndex) require.True(t, fieldChanged) @@ -56,7 +56,7 @@ func TestFieldsMask(t *testing.T) { var mask meta.FieldsMask var fieldIndex uint64 = 131 - mask.SetItemFromInt64(2, 8) + mask.SetPartFromInt64(2, 8) fieldChanged := mask.FieldChanged(fieldIndex) require.True(t, fieldChanged) @@ -65,7 +65,7 @@ func TestFieldsMask(t *testing.T) { var mask meta.FieldsMask var fieldIndex uint64 = 194 - mask.SetItemFromInt64(3, 4) + mask.SetPartFromInt64(3, 4) fieldChanged := mask.FieldChanged(fieldIndex) require.True(t, fieldChanged) @@ -76,7 +76,7 @@ func TestFieldsMask(t *testing.T) { func TestFieldsMaskItem(t *testing.T) { t.Run("FieldIsDirty", func(t *testing.T) { cases := []struct { - maskItem meta.FieldsMaskItem + maskItem meta.FieldsMaskPart expectedDirtyIndexes []uint64 }{ { @@ -102,8 +102,8 @@ func TestFieldsMaskItem(t *testing.T) { } for i, c := range cases { - actualDirtyIndexes := make([]uint64, 0, meta.FieldsMaskItemBitSize) - for j := uint64(0); j < meta.FieldsMaskItemBitSize; j++ { + actualDirtyIndexes := make([]uint64, 0, meta.FieldsMaskPartBitSize) + for j := uint64(0); j < meta.FieldsMaskPartBitSize; j++ { if c.maskItem.FieldIsDirty(j) { actualDirtyIndexes = append(actualDirtyIndexes, j) } diff --git a/go.mod b/go.mod index a64fdf8..f8e9ae8 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,13 @@ require ( git.bit5.ru/backend/msgpack v1.0.0 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.7.3 + github.com/vmihailenco/msgpack/v5 v5.3.5 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/vmihailenco/bufio v0.0.0-20140618134113-fe7b595919de // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3252ab5..3458c15 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.3 h1:dAm0YRdRQlWojc3CrCRgPBzG5f941d0zvAKu7qY4e+I= github.com/stretchr/testify v1.7.3/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -18,6 +19,10 @@ github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0 github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= github.com/vmihailenco/bufio v0.0.0-20140618134113-fe7b595919de h1:U+I4zEVstMdfNES/2UO8iqkIf214SDMRhdaFTE3A5rA= github.com/vmihailenco/bufio v0.0.0-20140618134113-fe7b595919de/go.mod h1:ghSGoeEoFFkXNguSget72dMA0+OLq3AGZiqRohVojxI= +github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= +github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= gopkg.in/bufio.v1 v1.0.0-20140618132640-567b2bfa514e h1:wGA78yza6bu/mWcc4QfBuIEHEtc06xdiU0X8sY36yUU= gopkg.in/bufio.v1 v1.0.0-20140618132640-567b2bfa514e/go.mod h1:xsQCaysVCudhrYTfzYWe577fCe7Ceci+6qjO2Rdc0Z4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/interface.go b/interface.go index 5d0b03a..0b2b4ad 100644 --- a/interface.go +++ b/interface.go @@ -1,99 +1,83 @@ package meta type Reader interface { - ReadI8(v *int8, field string) error - ReadU8(v *uint8, field string) error - ReadI16(v *int16, field string) error - ReadU16(v *uint16, field string) error - ReadI32(v *int32, field string) error - ReadU32(v *uint32, field string) error - ReadI64(v *int64, field string) error - ReadU64(v *uint64, field string) error + ReadInt8(v *int8, field string) error + ReadInt16(v *int16, field string) error + ReadInt32(v *int32, field string) error + ReadInt64(v *int64, field string) error + + ReadUint8(v *uint8, field string) error + ReadUint16(v *uint16, field string) error + ReadUint32(v *uint32, field string) error + ReadUint64(v *uint64, field string) error + ReadBool(v *bool, field string) error - ReadFloat(v *float32, field string) error - ReadDouble(v *float64, field string) error + + ReadFloat32(v *float32, field string) error + ReadFloat64(v *float64, field string) error + ReadString(v *string, field string) error - ReadBlob(v *[]byte, field string) error + ReadBytes(v *[]byte, field string) error + BeginContainer(field string) error EndContainer() error - GetContainerSize() (int, error) + + ContainerSize() (int, error) + IsContainerAssoc() (bool, error) + Skip() error TryReadMask() (bool, FieldsMask, error) } type Writer interface { - WriteI8(v int8, field string) error - WriteU8(v uint8, field string) error - WriteI16(v int16, field string) error - WriteU16(v uint16, field string) error - WriteI32(v int32, field string) error - WriteU32(v uint32, field string) error - WriteU64(v uint64, field string) error - WriteI64(v int64, field string) error + WriteInt8(v int8, field string) error + WriteInt16(v int16, field string) error + WriteInt32(v int32, field string) error + WriteInt64(v int64, field string) error + + WriteUint8(v uint8, field string) error + WriteUint16(v uint16, field string) error + WriteUint32(v uint32, field string) error + WriteUint64(v uint64, field string) error + WriteBool(v bool, field string) error - WriteFloat(v float32, field string) error - WriteDouble(v float64, field string) error + + WriteFloat32(v float32, field string) error + WriteFloat64(v float64, field string) error + WriteString(v string, field string) error - WriteBlob(v []byte, field string) error - BeginContainer(field string) + WriteBytes(v []byte, field string) error + + BeginContainer(length int, field string) error + BeginAssocContainer(length int, field string) error EndContainer() error - GetData() ([]byte, error) } -type ClassFieldsProps map[string]map[string]string - -type IClassProps interface { - CLASS_ID() uint32 - CLASS_NAME() string - CLASS_PROPS() *map[string]string - CLASS_FIELDS() []string - CLASS_FIELDS_PROPS() *ClassFieldsProps +type Class interface { + ClassId() uint32 + ClassName() string } -type MetaFactory func(classId uint32) (IMetaStruct, error) - -type IMetaStruct interface { - IClassProps - Read(reader Reader) error - Write(writer Writer) error - ReadFields(reader Reader) error - WriteFields(writer Writer) error +type Readable interface { Reset() + Read(Reader) error + ReadFields(Reader) error } -type IMetaDataItem interface { - IClassProps - GetDbTableName() string - GetDbFields() []string - GetOwnerFieldName() string - GetIdFieldName() string - GetIdValue() uint64 - Import(interface{}) - Export([]interface{}) - NewInstance() IMetaDataItem +type Writable interface { + Write(Writer) error + WriteFields(Writer) error } -type IRemovedIds interface { - GetList(classId uint32) []uint64 - Add(classId uint32, id uint64) - HasList(classId uint32) bool +type Struct interface { + Class + Readable + Writable } -type IBitmasked interface { +type Bitmasked interface { + FieldChanged(index uint64) bool SetFieldChanged(index uint64) - HasValue(index uint64) bool - IsMaskFilled() bool - GetMask() FieldsMask -} - -type RPCFactory func(classId uint32) (IRPC, error) - -type IRPC interface { - GetCode() int32 - GetName() string - GetRequest() IMetaStruct - GetResponse() IMetaStruct - SetError(int32, string) - GetError() (int32, string) - Execute(interface{}) error + HasChangedFields() bool + FieldsMask() FieldsMask } diff --git a/meta.go b/meta.go deleted file mode 100644 index 2398539..0000000 --- a/meta.go +++ /dev/null @@ -1,64 +0,0 @@ -package meta - -func Read(reader Reader, createById MetaFactory) (v IMetaStruct, err error) { - return ReadStructGeneric(reader, createById, "") -} - -func Write(writer Writer, m IMetaStruct) error { - return WriteStructGeneric(writer, m, "") -} - -func ReadStructGeneric(reader Reader, createById MetaFactory, field string) (v IMetaStruct, err error) { - if err = reader.BeginContainer(field); err != nil { - return - } - var classId uint32 - if err = reader.ReadU32(&classId, ""); err != nil { - return - } - if v, err = createById(classId); err != nil { - return - } - if err = v.ReadFields(reader); err != nil { - return - } - err = reader.EndContainer() - return -} - -func WriteStructGeneric(writer Writer, m IMetaStruct, field string) error { - writer.BeginContainer(field) - if err := writer.WriteU32(m.CLASS_ID(), ""); err != nil { - return err - } - if err := m.WriteFields(writer); err != nil { - return err - } - return writer.EndContainer() -} - -func ReadStruct(reader Reader, m IMetaStruct, field string) error { - err := reader.BeginContainer(field) - if err != nil { - return err - } - err = m.ReadFields(reader) - if err != nil { - return err - } - return reader.EndContainer() -} - -func WriteStruct(writer Writer, m IMetaStruct, field string) error { - writer.BeginContainer(field) - err := m.WriteFields(writer) - if err != nil { - return err - } - return writer.EndContainer() -} - -func IsRPCFailed(rpc IRPC) bool { - code, _ := rpc.GetError() - return code != 0 -} diff --git a/msgpack_reader.go b/msgpack_reader.go index 2ebbebb..13ce160 100644 --- a/msgpack_reader.go +++ b/msgpack_reader.go @@ -1,412 +1,235 @@ package meta import ( - "github.com/pkg/errors" + "io" - "git.bit5.ru/backend/msgpack" + "github.com/pkg/errors" + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" ) -type msgpReader struct { - stack []*msgpReadState - current *msgpReadState +type msgpackReader struct { + dec *msgpack.Decoder + containers []msgpackContainer } -type msgpReadState struct { - data []interface{} - idx int +type msgpackContainer struct { + length int + assoc bool } -func NewMsgpackReader(bytes []byte) (Reader, error) { - arr := make([]interface{}, 1) - err := msgpack.Unmarshal(bytes, &arr[0]) +func NewMsgpackReader(r io.Reader) Reader { + return &msgpackReader{ + dec: msgpack.NewDecoder(r), + containers: make([]msgpackContainer, 0, 2), + } +} + +func (rd *msgpackReader) currentContainer() msgpackContainer { + if rd == nil || len(rd.containers) == 0 { + return msgpackContainer{} + } + + return rd.containers[len(rd.containers)-1] +} + +func (rd *msgpackReader) ReadInt8(v *int8, field string) error { + tmp, err := rd.dec.DecodeInt8() if err != nil { - return nil, errors.WithStack(err) + return errors.WithStack(err) } - - rd := &msgpReader{} - rd.stack = make([]*msgpReadState, 0, 2) - rd.current = &msgpReadState{arr, 0} - rd.stack = append(rd.stack, rd.current) - - return rd, nil -} - -func (state *msgpReadState) clear() { - state.data = nil -} - -func (state *msgpReadState) read(field string) (v interface{}, err error) { - if len(state.data) <= state.idx { - err = errors.Errorf("No more data for read, field:%s", field) - return - } - v = state.data[state.idx] - state.idx++ - return -} - -func (state *msgpReadState) backCursor() error { - if state.data == nil { - return errors.New("No more data for read") - } - state.idx-- + *v = tmp return nil } -func (state *msgpReadState) skip() error { - if state.data == nil || len(state.data) <= state.idx { - return errors.New("No more data for read") +func (rd *msgpackReader) ReadInt16(v *int16, field string) error { + tmp, err := rd.dec.DecodeInt16() + if err != nil { + return errors.WithStack(err) } - state.idx++ + *v = tmp return nil } -func (state *msgpReadState) size() (int, error) { - return len(state.data) - state.idx, nil // relative to current idx? -} - -func (rd *msgpReader) ReadI8(v *int8, field string) (err error) { - var c int32 - err = rd.ReadI32(&c, field) - *v = int8(c) - return -} - -func (rd *msgpReader) ReadU8(v *uint8, field string) (err error) { - var c uint32 - err = rd.ReadU32(&c, field) - *v = uint8(c) - return -} - -func (rd *msgpReader) ReadI16(v *int16, field string) (err error) { - var c int32 - err = rd.ReadI32(&c, field) - *v = int16(c) - return -} - -func (rd *msgpReader) ReadU16(v *uint16, field string) (err error) { - var c uint32 - err = rd.ReadU32(&c, field) - *v = uint16(c) - return -} - -func (rd *msgpReader) ReadI32(v *int32, field string) error { - value, err := rd.current.read(field) +func (rd *msgpackReader) ReadInt32(v *int32, field string) error { + tmp, err := rd.dec.DecodeInt32() if err != nil { - return err - } - switch t := value.(type) { - case float64: - *v = int32(t) - case int8: - *v = int32(t) - case uint8: - *v = int32(t) - case int16: - *v = int32(t) - case uint16: - *v = int32(t) - case int32: - *v = int32(t) - case uint32: - *v = int32(t) - case int64: - *v = int32(t) - case uint64: - *v = int32(t) - case bool: - if t { - *v = int32(1) - } else { - *v = int32(0) - } - default: - return errors.Errorf("Can't convert to int32 %v (%T), field:%s", t, t, field) + return errors.WithStack(err) } + *v = tmp return nil } -func (rd *msgpReader) ReadU32(v *uint32, field string) error { - value, err := rd.current.read(field) +func (rd *msgpackReader) ReadInt64(v *int64, field string) error { + tmp, err := rd.dec.DecodeInt64() if err != nil { - return err - } - switch t := value.(type) { - case float64: - *v = uint32(t) - case int8: - *v = uint32(t) - case uint8: - *v = uint32(t) - case int16: - *v = uint32(t) - case uint16: - *v = uint32(t) - case int32: - *v = uint32(t) - case uint32: - *v = uint32(t) - case uint64: - *v = uint32(t) - case int64: - *v = uint32(t) - case bool: - if t { - *v = uint32(1) - } else { - *v = uint32(0) - } - default: - return errors.Errorf("Can't convert to uint32 (%T), field:%s", t, field) + return errors.WithStack(err) } + *v = tmp return nil } -func (rd *msgpReader) ReadI64(v *int64, field string) error { - value, err := rd.current.read(field) +func (rd *msgpackReader) ReadUint8(v *uint8, field string) error { + tmp, err := rd.dec.DecodeUint8() if err != nil { - return err - } - switch t := value.(type) { - case float64: - *v = int64(t) - case int8: - *v = int64(t) - case uint8: - *v = int64(t) - case int16: - *v = int64(t) - case uint16: - *v = int64(t) - case int32: - *v = int64(t) - case uint32: - *v = int64(t) - case uint64: - *v = int64(t) - case int64: - *v = int64(t) - default: - return errors.Errorf("Can't convert to uint64 (%T), field:%s", t, field) + return errors.WithStack(err) } + *v = tmp return nil } -func (rd *msgpReader) ReadU64(v *uint64, field string) error { - value, err := rd.current.read(field) +func (rd *msgpackReader) ReadUint16(v *uint16, field string) error { + tmp, err := rd.dec.DecodeUint16() if err != nil { - return err - } - switch t := value.(type) { - case float64: - *v = uint64(t) - case int8: - *v = uint64(t) - case uint8: - *v = uint64(t) - case int16: - *v = uint64(t) - case uint16: - *v = uint64(t) - case int32: - *v = uint64(t) - case uint32: - *v = uint64(t) - case uint64: - *v = uint64(t) - case int64: - *v = uint64(t) - default: - return errors.Errorf("Can't convert to uint64 (%T), field:%s", t, field) + return errors.WithStack(err) } + *v = tmp return nil } -func (rd *msgpReader) ReadBool(v *bool, field string) (err error) { - value, err := rd.current.read(field) +func (rd *msgpackReader) ReadUint32(v *uint32, field string) error { + tmp, err := rd.dec.DecodeUint32() if err != nil { - return err - } - switch t := value.(type) { - case bool: - *v = bool(t) - default: - return errors.Errorf("Can't convert to bool (%T), field:%s", t, field) + return errors.WithStack(err) } + *v = tmp return nil } -func (rd *msgpReader) ReadFloat(v *float32, field string) error { - value, err := rd.current.read(field) +func (rd *msgpackReader) ReadUint64(v *uint64, field string) error { + tmp, err := rd.dec.DecodeUint64() if err != nil { - return err - } - switch t := value.(type) { - case float64: - //NOTE: possible type coersion error! - *v = float32(t) - case float32: - *v = t - case int64: - *v = float32(t) - case uint64: - *v = float32(t) - default: - return errors.Errorf("Can't convert to float32 (%T), field:%s", t, field) + return errors.WithStack(err) } + *v = tmp return nil } -func (rd *msgpReader) ReadDouble(v *float64, field string) error { - value, err := rd.current.read(field) +func (rd *msgpackReader) ReadBool(v *bool, field string) error { + tmp, err := rd.dec.DecodeBool() if err != nil { - return err - } - switch t := value.(type) { - case float64: - *v = t - case float32: - *v = float64(t) - case int64: - *v = float64(t) - case uint64: - *v = float64(t) - default: - return errors.Errorf("Can't convert to float64 (%T), field:%s", t, field) + return errors.WithStack(err) } + *v = tmp return nil } -func (rd *msgpReader) ReadString(v *string, field string) error { - value, err := rd.current.read(field) +func (rd *msgpackReader) ReadFloat32(v *float32, field string) error { + tmp, err := rd.dec.DecodeFloat32() if err != nil { - return err - } - switch t := value.(type) { - case string: - *v = t - - default: - return errors.Errorf("Can't convert to string %s (%T)", field, t) + return errors.WithStack(err) } + *v = tmp return nil } -func (rd *msgpReader) ReadBlob(v *[]byte, field string) error { - value, err := rd.current.read(field) +func (rd *msgpackReader) ReadFloat64(v *float64, field string) error { + tmp, err := rd.dec.DecodeFloat64() if err != nil { - return err - } - switch t := value.(type) { - case []byte: - *v = t - - default: - return errors.Errorf("Can't convert to []byte %s (%T)", field, t) + return errors.WithStack(err) } + *v = tmp return nil } -func (rd *msgpReader) BeginContainer(field string) error { - v, err := rd.current.read(field) +func (rd *msgpackReader) ReadString(v *string, field string) error { + tmp, err := rd.dec.DecodeString() if err != nil { - return err - } - - if data, ok := v.([]interface{}); ok { - cur := &msgpReadState{} - cur.data = data - rd.stack = append(rd.stack, cur) - rd.current = cur - return nil - } else { - return errors.Errorf("Next value isn't array but:%T, field:%s", v, field) - } -} - -func (rd *msgpReader) EndContainer() error { - rd.current.clear() - rd.stack[len(rd.stack)-1] = nil - rd.stack = rd.stack[0 : len(rd.stack)-1] - if len(rd.stack) > 0 { - rd.current = rd.stack[len(rd.stack)-1] - } else { - rd.current = nil + return errors.WithStack(err) } + *v = tmp return nil } -func (rd *msgpReader) GetContainerSize() (int, error) { - return rd.current.size() -} - -func (rd *msgpReader) Skip() error { - return rd.current.skip() -} - -func (rd *msgpReader) TryReadMask() (bool, FieldsMask, error) { - value, err := rd.current.read("mask") +func (rd *msgpackReader) ReadBytes(v *[]byte, field string) error { + tmp, err := rd.dec.DecodeBytes() if err != nil { - return false, FieldsMask{}, err + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func (rd *msgpackReader) BeginContainer(field string) error { + code, err := rd.dec.PeekCode() + if err != nil { + return errors.WithStack(err) } - if value == nil { //mask detected - newMask, err := rd.ReadNewMask() - if err == nil { - return true, newMask, nil - } - - if err := rd.current.backCursor(); err != nil { - return false, FieldsMask{}, err - } - - oldMask, err := rd.ReadOldMask() + switch { + case msgpcode.IsFixedMap(code), code == msgpcode.Map16, code == msgpcode.Map32: + l, err := rd.dec.DecodeMapLen() if err != nil { - return false, FieldsMask{}, err + return errors.WithStack(err) } - return true, oldMask, nil - } + rd.containers = append(rd.containers, msgpackContainer{ + length: l, + assoc: true, + }) - if err := rd.current.backCursor(); err != nil { - return false, FieldsMask{}, err - } + case msgpcode.IsFixedArray(code), code == msgpcode.Array16, code == msgpcode.Array32: + l, err := rd.dec.DecodeArrayLen() + if err != nil { + return errors.WithStack(err) + } + rd.containers = append(rd.containers, msgpackContainer{ + length: l, + assoc: false, + }) - return false, FieldsMask{}, nil + default: + return errors.Errorf("there is no container for field `%s`", field) + } + return nil } -func (rd *msgpReader) ReadNewMask() (FieldsMask, error) { - if err := rd.BeginContainer("new_mask"); err != nil { - return FieldsMask{}, err +func (rd *msgpackReader) EndContainer() error { + if len(rd.containers) == 0 { + return errors.New("there is no open containers") } - maskSize, err := rd.GetContainerSize() + rd.containers = rd.containers[:len(rd.containers)-1] + return nil +} + +func (rd *msgpackReader) ContainerSize() (int, error) { + return rd.currentContainer().length, nil +} + +func (rd *msgpackReader) IsContainerAssoc() (bool, error) { + return rd.currentContainer().assoc, nil +} + +func (rd *msgpackReader) Skip() error { + return errors.WithStack(rd.dec.Skip()) +} + +func (rd *msgpackReader) TryReadMask() (bool, FieldsMask, error) { + code, err := rd.dec.PeekCode() if err != nil { - return FieldsMask{}, err + return false, FieldsMask{}, errors.WithStack(err) + } + + if code != msgpcode.Nil { + return false, FieldsMask{}, nil + } + + if err := rd.dec.Skip(); err != nil { + return false, FieldsMask{}, errors.WithStack(err) } var mask FieldsMask - for i := 0; i < maskSize; i++ { - var maskItem uint64 - if err := rd.ReadU64(&maskItem, "mask_item"); err != nil { - return FieldsMask{}, err + maskLen, err := rd.dec.DecodeArrayLen() + if err != nil { + return false, FieldsMask{}, errors.WithStack(err) + } + + for i := 0; i < maskLen; i++ { + maskPart, err := rd.dec.DecodeUint64() + if err != nil { + return false, FieldsMask{}, errors.WithStack(err) } - mask.SetItemFromUint64(i, maskItem) + mask.SetPartFromUint64(i, maskPart) } - if err := rd.EndContainer(); err != nil { - return FieldsMask{}, err - } - - return mask, nil -} - -func (rd *msgpReader) ReadOldMask() (FieldsMask, error) { - var mask int64 - if err := rd.ReadI64(&mask, "mask"); err != nil { - return FieldsMask{}, err - } - - return MakeFieldsMaskFromInt64(mask), nil + return true, mask, nil } diff --git a/msgpack_reader_test.go b/msgpack_reader_test.go new file mode 100644 index 0000000..684af6d --- /dev/null +++ b/msgpack_reader_test.go @@ -0,0 +1,153 @@ +package meta_test + +import ( + "bytes" + "encoding/hex" + "testing" + + "git.bit5.ru/backend/meta" + "github.com/stretchr/testify/require" +) + +func TestMsgpackReader(t *testing.T) { + t.Run("reading struct as map", func(t *testing.T) { + // {"f1":"blabla","f3":[2,4,6],"f2":{"field":1},"f4":[{"field":10},{"field":1024}]} + src := "84a26631a6626c61626c61a2663393020406a2663281a56669656c6401a266349281a56669656c640a81a56669656c64cd0400" + + expected := TestParent{ + Field1: "blabla", + Field2: TestFoo{ + Field: 1, + }, + Field3: []int8{2, 4, 6}, + Field4: []TestFoo{ + {Field: 10}, + {Field: 1024}, + }, + } + + data, err := hex.DecodeString(src) + require.NoError(t, err) + + rdr := meta.NewMsgpackReader(bytes.NewReader(data)) + + var actual TestParent + readErr := actual.Read(rdr) + require.NoError(t, readErr) + require.EqualValues(t, expected, actual) + }) + + t.Run("reading struct as array with maps", func(t *testing.T) { + // ["blabla",{"field":1},[2,4,6],[{"field":10},{"field":1024}]] + src := "94a6626c61626c6181a56669656c6401930204069281a56669656c640a81a56669656c64cd0400" + + expected := TestParent{ + Field1: "blabla", + Field2: TestFoo{ + Field: 1, + }, + Field3: []int8{2, 4, 6}, + Field4: []TestFoo{ + {Field: 10}, + {Field: 1024}, + }, + } + + data, err := hex.DecodeString(src) + require.NoError(t, err) + + rdr := meta.NewMsgpackReader(bytes.NewReader(data)) + + var actual TestParent + readErr := actual.Read(rdr) + require.NoError(t, readErr) + require.EqualValues(t, expected, actual) + }) + + t.Run("reading struct only as array", func(t *testing.T) { + // ["blabla",[1],[2,4,6],[[10],[1024]]] + src := "94a6626c61626c6191019302040692910a91cd0400" + + expected := TestParent{ + Field1: "blabla", + Field2: TestFoo{ + Field: 1, + }, + Field3: []int8{2, 4, 6}, + Field4: []TestFoo{ + {Field: 10}, + {Field: 1024}, + }, + } + + data, err := hex.DecodeString(src) + require.NoError(t, err) + + rdr := meta.NewMsgpackReader(bytes.NewReader(data)) + + var actual TestParent + readErr := actual.Read(rdr) + require.NoError(t, readErr) + require.EqualValues(t, expected, actual) + }) + + t.Run("reading child struct as map", func(t *testing.T) { + // {"f":"qwerty","f1":"blabla","f3":[2,4,6],"f2":{"field":1},"f4":[{"field":10},{"field":1024}]} + src := "85a166a6717765727479a26631a6626c61626c61a2663393020406a2663281a56669656c6401a266349281a56669656c640a81a56669656c64cd0400" + + expected := TestChild{ + Field: "qwerty", + TestParent: TestParent{ + Field1: "blabla", + Field2: TestFoo{ + Field: 1, + }, + Field3: []int8{2, 4, 6}, + Field4: []TestFoo{ + {Field: 10}, + {Field: 1024}, + }, + }, + } + + data, err := hex.DecodeString(src) + require.NoError(t, err) + + rdr := meta.NewMsgpackReader(bytes.NewReader(data)) + + var actual TestChild + readErr := actual.Read(rdr) + require.NoError(t, readErr) + require.EqualValues(t, expected, actual) + }) + + t.Run("reading child struct as array", func(t *testing.T) { + // ["blabla",[1],[2,4,6],[[10],[1024]],"qwerty"] + src := "95a6626c61626c6191019302040692910a91cd0400a6717765727479" + + expected := TestChild{ + Field: "qwerty", + TestParent: TestParent{ + Field1: "blabla", + Field2: TestFoo{ + Field: 1, + }, + Field3: []int8{2, 4, 6}, + Field4: []TestFoo{ + {Field: 10}, + {Field: 1024}, + }, + }, + } + + data, err := hex.DecodeString(src) + require.NoError(t, err) + + rdr := meta.NewMsgpackReader(bytes.NewReader(data)) + + var actual TestChild + readErr := actual.Read(rdr) + require.NoError(t, readErr) + require.EqualValues(t, expected, actual) + }) +} diff --git a/msgpack_writer.go b/msgpack_writer.go index b8316bd..90236d3 100644 --- a/msgpack_writer.go +++ b/msgpack_writer.go @@ -4,126 +4,163 @@ import ( "io" "github.com/pkg/errors" - - "git.bit5.ru/backend/msgpack" + "github.com/vmihailenco/msgpack/v5" ) -type msgpWriter struct { - stack []*msgpWriteContainer - current *msgpWriteContainer +type msgpackWriter struct { + enc *msgpack.Encoder + containers []msgpackContainer } -type msgpWriteContainer struct { - arr []interface{} -} - -func NewMsgpackWriter() Writer { - wr := &msgpWriter{} - wr.stack = make([]*msgpWriteContainer, 0, 1) - wr.current = newMsgpWriteContainer() - wr.stack = append(wr.stack, wr.current) - return wr -} - -func newMsgpWriteContainer() *msgpWriteContainer { - return &msgpWriteContainer{make([]interface{}, 0, 1)} -} - -func (state *msgpWriteContainer) add(v interface{}) error { - state.arr = append(state.arr, v) - return nil -} - -func (wr *msgpWriter) WriteI8(v int8, field string) error { - return wr.current.add(v) -} - -func (wr *msgpWriter) WriteU8(v uint8, field string) error { - return wr.current.add(v) -} -func (wr *msgpWriter) WriteI16(v int16, field string) error { - return wr.current.add(v) -} -func (wr *msgpWriter) WriteU16(v uint16, field string) error { - return wr.current.add(v) -} - -func (wr *msgpWriter) WriteI32(v int32, field string) error { - return wr.current.add(v) -} -func (wr *msgpWriter) WriteU32(v uint32, field string) error { - return wr.current.add(v) -} - -func (wr *msgpWriter) WriteI64(v int64, field string) error { - return wr.current.add(v) -} -func (wr *msgpWriter) WriteU64(v uint64, field string) error { - return wr.current.add(v) -} - -func (wr *msgpWriter) WriteBool(v bool, field string) error { - return wr.current.add(v) -} - -func (wr *msgpWriter) WriteFloat(v float32, field string) error { - return wr.current.add(v) -} - -func (wr *msgpWriter) WriteDouble(v float64, field string) error { - return wr.current.add(v) -} - -func (wr *msgpWriter) WriteString(v string, field string) error { - return wr.current.add(v) -} - -func (wr *msgpWriter) WriteBlob(v []byte, field string) error { - return wr.current.add(v) -} - -func (wr *msgpWriter) BeginContainer(field string) { - wr.current = newMsgpWriteContainer() - wr.stack = append(wr.stack, wr.current) -} - -func (wr *msgpWriter) EndContainer() error { - if len(wr.stack) <= 1 { - return errors.New("No open container") +func NewMsgpackWriter(w io.Writer) Writer { + return &msgpackWriter{ + enc: msgpack.NewEncoder(w), + containers: make([]msgpackContainer, 0, 1), } - last := wr.current - wr.stack[len(wr.stack)-1] = nil - wr.stack = wr.stack[0 : len(wr.stack)-1] - wr.current = wr.stack[len(wr.stack)-1] +} - bytes, err := msgpack.Marshal(last.arr) - if err != nil { +func (wr *msgpackWriter) currentContainer() msgpackContainer { + if wr == nil || len(wr.containers) == 0 { + return msgpackContainer{} + } + return wr.containers[len(wr.containers)-1] +} + +func (wr *msgpackWriter) writeFieldName(field string) error { + if !wr.currentContainer().assoc { + return nil + } + + return errors.WithStack(wr.enc.EncodeString(field)) +} + +func (wr *msgpackWriter) WriteInt8(v int8, field string) error { + if err := wr.writeFieldName(field); err != nil { return err } + return errors.WithStack(wr.enc.EncodeInt(int64(v))) +} - //NOTE: using custom msgpack encoder - wr.current.add(&msgpCustomBytes{bytes}) +func (wr *msgpackWriter) WriteInt16(v int16, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeInt(int64(v))) +} +func (wr *msgpackWriter) WriteInt32(v int32, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeInt(int64(v))) +} + +func (wr *msgpackWriter) WriteInt64(v int64, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeInt(v)) +} + +func (wr *msgpackWriter) WriteUint8(v uint8, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeUint(uint64(v))) +} + +func (wr *msgpackWriter) WriteUint16(v uint16, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeUint(uint64(v))) +} + +func (wr *msgpackWriter) WriteUint32(v uint32, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeUint(uint64(v))) +} + +func (wr *msgpackWriter) WriteUint64(v uint64, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeUint(uint64(v))) +} + +func (wr *msgpackWriter) WriteBool(v bool, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeBool(v)) +} + +func (wr *msgpackWriter) WriteFloat32(v float32, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeFloat32(v)) +} + +func (wr *msgpackWriter) WriteFloat64(v float64, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeFloat64(v)) +} + +func (wr *msgpackWriter) WriteString(v string, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeString(v)) +} + +func (wr *msgpackWriter) WriteBytes(v []byte, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeBytes(v)) +} + +func (wr *msgpackWriter) BeginContainer(length int, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + if err := wr.enc.EncodeArrayLen(length); err != nil { + return errors.WithStack(err) + } + wr.containers = append(wr.containers, msgpackContainer{ + length: length, + assoc: false, + }) return nil } -func (wr *msgpWriter) GetData() ([]byte, error) { - if len(wr.stack) != 1 { - return nil, errors.New("Stack isn't empty") +func (wr *msgpackWriter) BeginAssocContainer(length int, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err } - - if len(wr.current.arr) != 1 { - return nil, errors.New("Arr size isn't valid") + if err := wr.enc.EncodeMapLen(length); err != nil { + return errors.WithStack(err) } - - return msgpack.Marshal(wr.current.arr[0]) + wr.containers = append(wr.containers, msgpackContainer{ + length: length, + assoc: true, + }) + return nil } -type msgpCustomBytes struct { - v []byte +func (wr *msgpackWriter) EndContainer() error { + if len(wr.containers) == 0 { + return errors.New("there is no open containers") + } + wr.containers = wr.containers[:len(wr.containers)-1] + return nil } -func (msgp *msgpCustomBytes) EncodeMsgpack(writer io.Writer) error { - _, err := writer.Write(msgp.v) - return errors.WithStack(err) +func (wr *msgpackWriter) GetData() ([]byte, error) { + return nil, nil } diff --git a/msgpack_writer_test.go b/msgpack_writer_test.go new file mode 100644 index 0000000..56f5187 --- /dev/null +++ b/msgpack_writer_test.go @@ -0,0 +1,63 @@ +package meta_test + +import ( + "bytes" + "encoding/hex" + "testing" + + "git.bit5.ru/backend/meta" + "github.com/stretchr/testify/require" +) + +func TestMsgpackWriter(t *testing.T) { + t.Run("write struct", func(t *testing.T) { + var buf bytes.Buffer + wr := meta.NewMsgpackWriter(&buf) + + s := TestParent{ + Field1: "blabla", + Field2: TestFoo{ + Field: 1, + }, + Field3: []int8{2, 4, 6}, + Field4: []TestFoo{ + {Field: 10}, + {Field: 1024}, + }, + } + + err := s.Write(wr) + require.NoError(t, err) + + expected := "84a26631a6626c61626c61a2663281a56669656c6401a2663393020406a266349281a56669656c640a81a56669656c64cd0400" + actual := hex.EncodeToString(buf.Bytes()) + require.EqualValues(t, expected, actual) + }) + + t.Run("write child struct", func(t *testing.T) { + var buf bytes.Buffer + wr := meta.NewMsgpackWriter(&buf) + + s := TestChild{ + Field: "qwerty", + TestParent: TestParent{ + Field1: "blabla", + Field2: TestFoo{ + Field: 1, + }, + Field3: []int8{2, 4, 6}, + Field4: []TestFoo{ + {Field: 10}, + {Field: 1024}, + }, + }, + } + + err := s.Write(wr) + require.NoError(t, err) + + expected := "85a26631a6626c61626c61a2663281a56669656c6401a2663393020406a266349281a56669656c640a81a56669656c64cd0400a166a6717765727479" + actual := hex.EncodeToString(buf.Bytes()) + require.EqualValues(t, expected, actual) + }) +} diff --git a/structs_test.go b/structs_test.go new file mode 100644 index 0000000..6af51f5 --- /dev/null +++ b/structs_test.go @@ -0,0 +1,581 @@ +package meta_test + +import ( + "git.bit5.ru/backend/meta" + "github.com/pkg/errors" +) + +type TestParent struct { + Field1 string `json:"f1" msgpack:"f1"` + Field2 TestFoo `json:"f2" msgpack:"f2"` + Field3 []int8 `json:"f3" msgpack:"f3"` + Field4 []TestFoo `json:"f4" msgpack:"f4"` +} + +var _TestParentRequiredFields map[string]struct{} + +func init() { + _TestParentRequiredFields = map[string]struct{}{ + "f1": {}, + "f2": {}, + "f3": {}, + "f4": {}, + } +} + +func (s *TestParent) Reset() { + + s.Field1 = "" + s.Field2.Reset() + if s.Field3 == nil { + s.Field3 = make([]int8, 0) + } else { + s.Field3 = s.Field3[:0] + } + if s.Field4 == nil { + s.Field4 = make([]TestFoo, 0) + } else { + s.Field4 = s.Field4[:0] + } + +} + +func (s *TestParent) Read(reader meta.Reader) error { + if err := reader.BeginContainer(""); err != nil { + return err + } + if err := s.ReadFields(reader); err != nil { + return err + } + return reader.EndContainer() +} + +func (s *TestParent) ReadFields(reader meta.Reader) error { + s.Reset() + + readAsMap, err := reader.IsContainerAssoc() + if err != nil { + return err + } + if readAsMap { + return s.readFieldsAssociative(reader) + } + return s.readFields(reader) +} + +func (s *TestParent) readFieldsAssociative(reader meta.Reader) error { + size, err := reader.ContainerSize() + if err != nil { + return err + } + + readFields := make(map[string]struct{}, 4) + for ; size > 0; size-- { + var field string + if err := reader.ReadString(&field, ""); err != nil { + return err + } + + switch field { + case "f1": + if err := reader.ReadString(&s.Field1, "f1"); err != nil { + return err + } + + case "f2": + if err := s.Field2.Read(reader); err != nil { + return err + } + + case "f3": + if err := reader.BeginContainer("f3"); err != nil { + return err + } + field3Size, err := reader.ContainerSize() + if err != nil { + return err + } + for ; field3Size > 0; field3Size-- { + var tmpField3 int8 + if err := reader.ReadInt8(&tmpField3, ""); err != nil { + return err + } + + s.Field3 = append(s.Field3, tmpField3) + } + if err := reader.EndContainer(); err != nil { + return err + } + + case "f4": + if err := reader.BeginContainer("f4"); err != nil { + return err + } + field4Size, err := reader.ContainerSize() + if err != nil { + return err + } + for ; field4Size > 0; field4Size-- { + var tmpField4 TestFoo + if err := tmpField4.Read(reader); err != nil { + return err + } + + s.Field4 = append(s.Field4, tmpField4) + } + if err := reader.EndContainer(); err != nil { + return err + } + + default: + return errors.Errorf("unexpected field `%s`", field) + } + + readFields[field] = struct{}{} + } + + for field := range _TestParentRequiredFields { + if _, ok := readFields[field]; !ok { + return errors.Errorf("field `%s` is not present", field) + } + } + + return nil +} + +func (s *TestParent) readFields(reader meta.Reader) error { + + contSize, err := reader.ContainerSize() + if err != nil { + return err + } + + if contSize < 4 { + contSize = 4 + } + + if contSize <= 0 { + return nil + } + + contSize-- + if err := reader.ReadString(&s.Field1, "f1"); err != nil { + return err + } + + if contSize <= 0 { + return nil + } + + contSize-- + if err := s.Field2.Read(reader); err != nil { + return err + } + + if contSize <= 0 { + return nil + } + + contSize-- + if err := reader.BeginContainer("f3"); err != nil { + return err + } + field3Size, err := reader.ContainerSize() + if err != nil { + return err + } + for ; field3Size > 0; field3Size-- { + var tmpField3 int8 + if err := reader.ReadInt8(&tmpField3, ""); err != nil { + return err + } + + s.Field3 = append(s.Field3, tmpField3) + } + if err := reader.EndContainer(); err != nil { + return err + } + + if contSize <= 0 { + return nil + } + + contSize-- + if err := reader.BeginContainer("f4"); err != nil { + return err + } + field4Size, err := reader.ContainerSize() + if err != nil { + return err + } + for ; field4Size > 0; field4Size-- { + var tmpField4 TestFoo + if err := tmpField4.Read(reader); err != nil { + return err + } + + s.Field4 = append(s.Field4, tmpField4) + } + if err := reader.EndContainer(); err != nil { + return err + } + + return nil +} + +func (s *TestParent) Write(writer meta.Writer) error { + if err := writer.BeginAssocContainer(4, ""); err != nil { + return err + } + if err := s.WriteFields(writer); err != nil { + return err + } + return writer.EndContainer() +} + +func (s *TestParent) WriteFields(writer meta.Writer) error { + + if err := writer.WriteString(s.Field1, "f1"); err != nil { + return err + } + + if err := writer.BeginAssocContainer(1, "f2"); err != nil { + return err + } + if err := s.Field2.WriteFields(writer); err != nil { + return err + } + if err := writer.EndContainer(); err != nil { + return err + } + + if err := writer.BeginContainer(len(s.Field3), "f3"); err != nil { + return err + } + for _, v := range s.Field3 { + if err := writer.WriteInt8(v, ""); err != nil { + return err + } + + } + if err := writer.EndContainer(); err != nil { + return err + } + + if err := writer.BeginContainer(len(s.Field4), "f4"); err != nil { + return err + } + for _, v := range s.Field4 { + if err := writer.BeginAssocContainer(1, ""); err != nil { + return err + } + if err := v.WriteFields(writer); err != nil { + return err + } + if err := writer.EndContainer(); err != nil { + return err + } + + } + if err := writer.EndContainer(); err != nil { + return err + } + + return nil +} + +type TestChild struct { + TestParent + + Field string `json:"f" msgpack:"f"` +} + +var _TestChildRequiredFields map[string]struct{} + +func init() { + _TestChildRequiredFields = map[string]struct{}{ + "f": {}, + } +} + +func (s *TestChild) Reset() { + s.TestParent.Reset() + + s.Field = "" + +} + +func (s *TestChild) Read(reader meta.Reader) error { + if err := reader.BeginContainer(""); err != nil { + return err + } + if err := s.ReadFields(reader); err != nil { + return err + } + return reader.EndContainer() +} + +func (s *TestChild) ReadFields(reader meta.Reader) error { + s.Reset() + + readAsMap, err := reader.IsContainerAssoc() + if err != nil { + return err + } + if readAsMap { + return s.readFieldsAssociative(reader) + } + return s.readFields(reader) +} + +func (s *TestChild) readFieldsAssociative(reader meta.Reader) error { + size, err := reader.ContainerSize() + if err != nil { + return err + } + + readFields := make(map[string]struct{}, 5) + for ; size > 0; size-- { + var field string + if err := reader.ReadString(&field, ""); err != nil { + return err + } + + switch field { + case "f1": + if err := reader.ReadString(&s.Field1, "f1"); err != nil { + return err + } + + case "f2": + if err := s.Field2.Read(reader); err != nil { + return err + } + + case "f3": + if err := reader.BeginContainer("f3"); err != nil { + return err + } + field3Size, err := reader.ContainerSize() + if err != nil { + return err + } + for ; field3Size > 0; field3Size-- { + var tmpField3 int8 + if err := reader.ReadInt8(&tmpField3, ""); err != nil { + return err + } + + s.Field3 = append(s.Field3, tmpField3) + } + if err := reader.EndContainer(); err != nil { + return err + } + + case "f4": + if err := reader.BeginContainer("f4"); err != nil { + return err + } + field4Size, err := reader.ContainerSize() + if err != nil { + return err + } + for ; field4Size > 0; field4Size-- { + var tmpField4 TestFoo + if err := tmpField4.Read(reader); err != nil { + return err + } + + s.Field4 = append(s.Field4, tmpField4) + } + if err := reader.EndContainer(); err != nil { + return err + } + + case "f": + if err := reader.ReadString(&s.Field, "f"); err != nil { + return err + } + + default: + return errors.Errorf("unexpected field `%s`", field) + } + + readFields[field] = struct{}{} + } + + for field := range _TestChildRequiredFields { + if _, ok := readFields[field]; !ok { + return errors.Errorf("field `%s` is not present", field) + } + } + + return nil +} + +func (s *TestChild) readFields(reader meta.Reader) error { + + contSize, err := reader.ContainerSize() + if err != nil { + return err + } + + if contSize < 1 { + contSize = 1 + } + + if err := s.TestParent.ReadFields(reader); err != nil { + return err + } + + if contSize <= 0 { + return nil + } + + contSize-- + if err := reader.ReadString(&s.Field, "f"); err != nil { + return err + } + + return nil +} + +func (s *TestChild) Write(writer meta.Writer) error { + if err := writer.BeginAssocContainer(5, ""); err != nil { + return err + } + if err := s.WriteFields(writer); err != nil { + return err + } + return writer.EndContainer() +} + +func (s *TestChild) WriteFields(writer meta.Writer) error { + if err := s.TestParent.WriteFields(writer); err != nil { + return err + } + + if err := writer.WriteString(s.Field, "f"); err != nil { + return err + } + + return nil +} + +type TestFoo struct { + Field int64 `json:"field" msgpack:"field"` +} + +var _TestFooRequiredFields map[string]struct{} + +func init() { + _TestFooRequiredFields = map[string]struct{}{ + "field": {}, + } +} + +func (s *TestFoo) Reset() { + + s.Field = 0 + +} + +func (s *TestFoo) Read(reader meta.Reader) error { + if err := reader.BeginContainer(""); err != nil { + return err + } + if err := s.ReadFields(reader); err != nil { + return err + } + return reader.EndContainer() +} + +func (s *TestFoo) ReadFields(reader meta.Reader) error { + s.Reset() + + readAsMap, err := reader.IsContainerAssoc() + if err != nil { + return err + } + if readAsMap { + return s.readFieldsAssociative(reader) + } + return s.readFields(reader) +} + +func (s *TestFoo) readFieldsAssociative(reader meta.Reader) error { + size, err := reader.ContainerSize() + if err != nil { + return err + } + + readFields := make(map[string]struct{}, 1) + for ; size > 0; size-- { + var field string + if err := reader.ReadString(&field, ""); err != nil { + return err + } + + switch field { + case "field": + if err := reader.ReadInt64(&s.Field, "field"); err != nil { + return err + } + + default: + return errors.Errorf("unexpected field `%s`", field) + } + + readFields[field] = struct{}{} + } + + for field := range _TestFooRequiredFields { + if _, ok := readFields[field]; !ok { + return errors.Errorf("field `%s` is not present", field) + } + } + + return nil +} + +func (s *TestFoo) readFields(reader meta.Reader) error { + + contSize, err := reader.ContainerSize() + if err != nil { + return err + } + + if contSize < 1 { + contSize = 1 + } + + if contSize <= 0 { + return nil + } + + contSize-- + if err := reader.ReadInt64(&s.Field, "field"); err != nil { + return err + } + + return nil +} + +func (s *TestFoo) Write(writer meta.Writer) error { + if err := writer.BeginAssocContainer(1, ""); err != nil { + return err + } + if err := s.WriteFields(writer); err != nil { + return err + } + return writer.EndContainer() +} + +func (s *TestFoo) WriteFields(writer meta.Writer) error { + + if err := writer.WriteInt64(s.Field, "field"); err != nil { + return err + } + + return nil +}