diff --git a/interface.go b/interface.go index 954b58f..2632267 100644 --- a/interface.go +++ b/interface.go @@ -22,6 +22,9 @@ type Reader interface { BeginContainer(field string) error EndContainer() error + BeginCollection(field string) error + EndCollection() error + ContainerSize() (int, error) IsContainerAssoc() (bool, error) diff --git a/msgpack_reader.go b/msgpack_reader.go index fd83985..0799666 100644 --- a/msgpack_reader.go +++ b/msgpack_reader.go @@ -1,7 +1,6 @@ package meta import ( - "bytes" "io" "github.com/pkg/errors" @@ -9,9 +8,10 @@ import ( "github.com/vmihailenco/msgpack/v5/msgpcode" ) -var FieldNotFound = errors.New("field not found") - -var NoOpenContainer = errors.New("there is no open container") +var ( + FieldNotFound = errors.New("field not found") + NoOpenContainer = errors.New("there is no open container") +) type msgpackReader struct { dec *msgpack.Decoder @@ -20,514 +20,83 @@ type msgpackReader struct { } type readContainer struct { - started bool - length int - assoc bool - values map[string]msgpack.RawMessage - reader io.Reader - readCnt int + length int } func NewMsgpackReader(r io.Reader) Reader { return &msgpackReader{ dec: msgpack.NewDecoder(r), stack: make([]readContainer, 0, 2), - curr: readContainer{ - reader: r, - }, + curr: readContainer{}, } } -func (rd *msgpackReader) readField() (string, error) { - field, err := rd.dec.DecodeString() - if err != nil { - return "", errors.WithStack(err) - } - return field, nil -} - func (rd *msgpackReader) ReadInt8(v *int8, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeInt8(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeInt8(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeInt8(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - rd.curr.readCnt = i + 1 - } - - return FieldNotFound + return decodeInt8(rd.dec, v) } func (rd *msgpackReader) ReadInt16(v *int16, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeInt16(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeInt16(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeInt16(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - rd.curr.readCnt = i + 1 - } - - return FieldNotFound + return decodeInt16(rd.dec, v) } func (rd *msgpackReader) ReadInt32(v *int32, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeInt32(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeInt32(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeInt32(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - rd.curr.readCnt = i + 1 - } - - return FieldNotFound + return decodeInt32(rd.dec, v) } func (rd *msgpackReader) ReadInt64(v *int64, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeInt64(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeInt64(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeInt64(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - rd.curr.readCnt = i + 1 - } - - return FieldNotFound + return decodeInt64(rd.dec, v) } func (rd *msgpackReader) ReadUint8(v *uint8, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeUint8(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeUint8(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeUint8(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - rd.curr.readCnt = i + 1 - } - - return FieldNotFound + return decodeUint8(rd.dec, v) } func (rd *msgpackReader) ReadUint16(v *uint16, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeUint16(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeUint16(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeUint16(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - rd.curr.readCnt = i + 1 - } - - return FieldNotFound + return decodeUint16(rd.dec, v) } func (rd *msgpackReader) ReadUint32(v *uint32, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeUint32(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeUint32(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeUint32(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - rd.curr.readCnt = i + 1 - } - - return FieldNotFound + return decodeUint32(rd.dec, v) } func (rd *msgpackReader) ReadUint64(v *uint64, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeUint64(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeUint64(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeUint64(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - rd.curr.readCnt = i + 1 - } - - return FieldNotFound + return decodeUint64(rd.dec, v) } func (rd *msgpackReader) ReadBool(v *bool, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeBool(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeBool(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeBool(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - rd.curr.readCnt = i + 1 - } - - return FieldNotFound + return decodeBool(rd.dec, v) } func (rd *msgpackReader) ReadFloat32(v *float32, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeFloat32(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeFloat32(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeFloat32(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - rd.curr.readCnt = i + 1 - } - - return FieldNotFound + return decodeFloat32(rd.dec, v) } func (rd *msgpackReader) ReadFloat64(v *float64, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeFloat64(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeFloat64(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeFloat64(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - rd.curr.readCnt = i + 1 - } - - return FieldNotFound + return decodeFloat64(rd.dec, v) } func (rd *msgpackReader) ReadString(v *string, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeString(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeString(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeString(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - rd.curr.readCnt = i + 1 - } - - return FieldNotFound + return decodeString(rd.dec, v) } func (rd *msgpackReader) ReadBytes(v *[]byte, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeBytes(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeBytes(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeBytes(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - rd.curr.readCnt = i + 1 - } - - return FieldNotFound + return decodeBytes(rd.dec, v) } func (rd *msgpackReader) BeginContainer(targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return rd.beginContainer(targetField) - } - - if b, ok := rd.curr.values[targetField]; ok { - rd.dec.Reset(bytes.NewReader(b)) - return rd.beginContainer(targetField) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return rd.beginContainer(targetField) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - rd.curr.readCnt = i + 1 - } - - return FieldNotFound -} - -func (rd *msgpackReader) beginContainer(field string) error { - code, err := rd.dec.PeekCode() - if err != nil { - return errors.WithStack(err) - } - - switch { - case msgpcode.IsFixedMap(code), code == msgpcode.Map16, code == msgpcode.Map32: - l, err := rd.dec.DecodeMapLen() - if err != nil { - return errors.WithStack(err) - } - rd.stack = append(rd.stack, rd.curr) - rd.curr = readContainer{ - started: true, - length: l, - assoc: true, - values: make(map[string]msgpack.RawMessage, l), - reader: rd.dec.Buffered(), - } - - case msgpcode.IsFixedArray(code), code == msgpcode.Array16, code == msgpcode.Array32: - l, err := rd.dec.DecodeArrayLen() - if err != nil { - return errors.WithStack(err) - } - rd.stack = append(rd.stack, rd.curr) - rd.curr = readContainer{ - started: true, - length: l, - assoc: false, - reader: rd.dec.Buffered(), - } - - default: - return errors.Errorf("there is no container for field `%s`", field) - } - return nil + return rd.beginContainer(targetField) } func (rd *msgpackReader) EndContainer() error { - if len(rd.stack) == 0 { - return NoOpenContainer - } - rd.curr = rd.stack[len(rd.stack)-1] - rd.stack = rd.stack[:len(rd.stack)-1] - rd.dec.Reset(rd.curr.reader) - return nil + return rd.endContainer() +} + +func (rd *msgpackReader) BeginCollection(targetField string) error { + return rd.beginContainer(targetField) +} + +func (rd *msgpackReader) EndCollection() error { + return rd.endContainer() } func (rd *msgpackReader) ContainerSize() (int, error) { @@ -535,7 +104,7 @@ func (rd *msgpackReader) ContainerSize() (int, error) { } func (rd *msgpackReader) IsContainerAssoc() (bool, error) { - return rd.curr.assoc, nil + return false, nil } func (rd *msgpackReader) Skip() error { @@ -543,10 +112,6 @@ func (rd *msgpackReader) Skip() error { } func (rd *msgpackReader) TryReadMask() (bool, FieldsMask, error) { - if rd.curr.assoc { - return false, FieldsMask{}, nil - } - maskLen, err := rd.dec.DecodeArrayLen() if err != nil { if err == io.EOF { @@ -567,6 +132,38 @@ func (rd *msgpackReader) TryReadMask() (bool, FieldsMask, error) { return true, mask, nil } +func (rd *msgpackReader) beginContainer(field string) error { + code, err := rd.dec.PeekCode() + if err != nil { + return errors.WithStack(err) + } + + switch { + case msgpcode.IsFixedArray(code), code == msgpcode.Array16, code == msgpcode.Array32: + l, err := rd.dec.DecodeArrayLen() + if err != nil { + return errors.WithStack(err) + } + rd.stack = append(rd.stack, rd.curr) + rd.curr = readContainer{ + length: l, + } + + default: + return errors.Errorf("field `%s` is not an array", field) + } + return nil +} + +func (rd *msgpackReader) endContainer() error { + if len(rd.stack) == 0 { + return NoOpenContainer + } + rd.curr = rd.stack[len(rd.stack)-1] + rd.stack = rd.stack[:len(rd.stack)-1] + return nil +} + func decodeUint8(dec *msgpack.Decoder, v *uint8) error { tmp, err := dec.DecodeUint8() if err != nil { diff --git a/msgpack_reader_test.go b/msgpack_reader_test.go index 12931cd..8210096 100644 --- a/msgpack_reader_test.go +++ b/msgpack_reader_test.go @@ -10,61 +10,7 @@ import ( ) func TestMsgpackReader(t *testing.T) { - t.Run("reading struct as map", func(t *testing.T) { - // {"f1":"blabla","f3":[2,4,6],"f2":{"field":1},"f4":[{"field":10},{"field":1024}]} - src := "84a26631a6626c61626c61a2663393020406a2663281a56669656c6401a266349281a56669656c640a81a56669656c64cd0400" - - expected := TestParent{ - Field1: "blabla", - Field2: TestFoo{ - Field: 1, - }, - Field3: []int8{2, 4, 6}, - Field4: []TestFoo{ - {Field: 10}, - {Field: 1024}, - }, - } - - data, err := hex.DecodeString(src) - require.NoError(t, err) - - rdr := meta.NewMsgpackReader(bytes.NewReader(data)) - - var actual TestParent - readErr := actual.Read(rdr) - require.NoError(t, readErr) - require.EqualValues(t, expected, actual) - }) - - t.Run("reading struct as array with maps", func(t *testing.T) { - // ["blabla",{"field":1},[2,4,6],[{"field":10},{"field":1024}]] - src := "94a6626c61626c6181a56669656c6401930204069281a56669656c640a81a56669656c64cd0400" - - expected := TestParent{ - Field1: "blabla", - Field2: TestFoo{ - Field: 1, - }, - Field3: []int8{2, 4, 6}, - Field4: []TestFoo{ - {Field: 10}, - {Field: 1024}, - }, - } - - data, err := hex.DecodeString(src) - require.NoError(t, err) - - rdr := meta.NewMsgpackReader(bytes.NewReader(data)) - - var actual TestParent - readErr := actual.Read(rdr) - require.NoError(t, readErr) - require.EqualValues(t, expected, actual) - }) - - t.Run("reading struct only as array", func(t *testing.T) { + t.Run("reading parent", func(t *testing.T) { // ["blabla",[1],[2,4,6],[[10],[1024]]] src := "94a6626c61626c6191019302040692910a91cd0400" @@ -91,37 +37,7 @@ func TestMsgpackReader(t *testing.T) { require.EqualValues(t, expected, actual) }) - t.Run("reading child struct as map", func(t *testing.T) { - // {"f":"qwerty","f1":"blabla","f3":[2,4,6],"f2":{"field":1},"f4":[{"field":10},{"field":1024}]} - src := "85a166a6717765727479a26631a6626c61626c61a2663393020406a2663281a56669656c6401a266349281a56669656c640a81a56669656c64cd0400" - - expected := TestChild{ - Field: "qwerty", - TestParent: TestParent{ - Field1: "blabla", - Field2: TestFoo{ - Field: 1, - }, - Field3: []int8{2, 4, 6}, - Field4: []TestFoo{ - {Field: 10}, - {Field: 1024}, - }, - }, - } - - data, err := hex.DecodeString(src) - require.NoError(t, err) - - rdr := meta.NewMsgpackReader(bytes.NewReader(data)) - - var actual TestChild - readErr := actual.Read(rdr) - require.NoError(t, readErr) - require.EqualValues(t, expected, actual) - }) - - t.Run("reading child struct as array", func(t *testing.T) { + t.Run("reading child", func(t *testing.T) { // ["blabla",[1],[2,4,6],[[10],[1024]],"qwerty"] src := "95a6626c61626c6191019302040692910a91cd0400a6717765727479" @@ -150,4 +66,59 @@ func TestMsgpackReader(t *testing.T) { require.NoError(t, readErr) require.EqualValues(t, expected, actual) }) + + t.Run("fail reading parent as map", func(t *testing.T) { + // {"f1":"blabla","f3":[2,4,6],"f2":{"field":1},"f4":[{"field":10},{"field":1024}]} + src := "84a26631a6626c61626c61a2663393020406a2663281a56669656c6401a266349281a56669656c640a81a56669656c64cd0400" + + expected := TestParent{} + + data, err := hex.DecodeString(src) + require.NoError(t, err) + + rdr := meta.NewMsgpackReader(bytes.NewReader(data)) + + var actual TestParent + readErr := actual.Read(rdr) + require.ErrorContains(t, readErr, "field `` is not an array") + require.EqualValues(t, expected, actual) + }) + + t.Run("fail reading parent as array with maps", func(t *testing.T) { + // ["blabla",{"field":1},[2,4,6],[{"field":10},{"field":1024}]] + src := "94a6626c61626c6181a56669656c6401930204069281a56669656c640a81a56669656c64cd0400" + + expected := TestParent{ + Field1: "blabla", + Field3: []int8{}, + Field4: []TestFoo{}, + } + + data, err := hex.DecodeString(src) + require.NoError(t, err) + + rdr := meta.NewMsgpackReader(bytes.NewReader(data)) + + var actual TestParent + readErr := actual.Read(rdr) + require.ErrorContains(t, readErr, "field `f2` is not an array") + require.EqualValues(t, expected, actual) + }) + + t.Run("fail reading child struct as map", func(t *testing.T) { + // {"f":"qwerty","f1":"blabla","f3":[2,4,6],"f2":{"field":1},"f4":[{"field":10},{"field":1024}]} + src := "85a166a6717765727479a26631a6626c61626c61a2663393020406a2663281a56669656c6401a266349281a56669656c640a81a56669656c64cd0400" + + expected := TestChild{} + + data, err := hex.DecodeString(src) + require.NoError(t, err) + + rdr := meta.NewMsgpackReader(bytes.NewReader(data)) + + var actual TestChild + readErr := actual.Read(rdr) + require.ErrorContains(t, readErr, "field `` is not an array") + require.EqualValues(t, expected, actual) + }) } diff --git a/structs_test.go b/structs_test.go index d6ee4ca..d3d9ab6 100644 --- a/structs_test.go +++ b/structs_test.go @@ -79,7 +79,7 @@ func (s *TestParent) ReadFields(reader meta.Reader) error { } contSize-- - if err := reader.BeginContainer("f3"); err != nil { + if err := reader.BeginCollection("f3"); err != nil { return err } field3Size, err := reader.ContainerSize() @@ -94,7 +94,7 @@ func (s *TestParent) ReadFields(reader meta.Reader) error { s.Field3 = append(s.Field3, tmpField3) } - if err := reader.EndContainer(); err != nil { + if err := reader.EndCollection(); err != nil { return err } @@ -103,7 +103,7 @@ func (s *TestParent) ReadFields(reader meta.Reader) error { } contSize-- - if err := reader.BeginContainer("f4"); err != nil { + if err := reader.BeginCollection("f4"); err != nil { return err } field4Size, err := reader.ContainerSize() @@ -124,7 +124,7 @@ func (s *TestParent) ReadFields(reader meta.Reader) error { s.Field4 = append(s.Field4, tmpField4) } - if err := reader.EndContainer(); err != nil { + if err := reader.EndCollection(); err != nil { return err }