diff --git a/msgpack_assoc_reader.go b/msgpack_assoc_reader.go new file mode 100644 index 0000000..01dfc27 --- /dev/null +++ b/msgpack_assoc_reader.go @@ -0,0 +1,616 @@ +package meta + +import ( + "bytes" + "io" + + "github.com/pkg/errors" + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" +) + +type msgpackAssocReader struct { + dec *msgpack.Decoder + stack []assocReadContainer + curr assocReadContainer +} + +type assocReadContainer struct { + started bool + length int + assoc bool + values map[string]msgpack.RawMessage + reader io.Reader + readCnt int +} + +func NewMsgpackAssocReader(r io.Reader) Reader { + return &msgpackAssocReader{ + dec: msgpack.NewDecoder(r), + stack: make([]assocReadContainer, 0, 2), + curr: assocReadContainer{ + reader: r, + }, + } +} + +func (rd *msgpackAssocReader) readField() (string, error) { + field, err := rd.dec.DecodeString() + if err != nil { + return "", errors.WithStack(err) + } + return field, nil +} + +func (rd *msgpackAssocReader) ReadInt8(v *int8, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeInt8(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeInt8(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeInt8(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) ReadInt16(v *int16, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeInt16(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeInt16(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeInt16(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) ReadInt32(v *int32, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeInt32(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeInt32(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeInt32(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) ReadInt64(v *int64, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeInt64(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeInt64(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeInt64(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) ReadUint8(v *uint8, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeUint8(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeUint8(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeUint8(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) ReadUint16(v *uint16, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeUint16(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeUint16(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeUint16(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) ReadUint32(v *uint32, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeUint32(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeUint32(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeUint32(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) ReadUint64(v *uint64, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeUint64(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeUint64(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeUint64(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) ReadBool(v *bool, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeBool(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeBool(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeBool(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) ReadFloat32(v *float32, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeFloat32(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeFloat32(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeFloat32(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) ReadFloat64(v *float64, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeFloat64(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeFloat64(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeFloat64(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) ReadString(v *string, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeString(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeString(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeString(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) ReadBytes(v *[]byte, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeBytes(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeBytes(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeBytes(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) BeginContainer(targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return rd.beginContainer(targetField) + } + + if b, ok := rd.curr.values[targetField]; ok { + rd.dec.Reset(bytes.NewReader(b)) + return rd.beginContainer(targetField) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return rd.beginContainer(targetField) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) EndContainer() error { + return rd.endContainer() +} + +func (rd *msgpackAssocReader) BeginCollection(targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return rd.beginCollection(targetField) + } + + if b, ok := rd.curr.values[targetField]; ok { + rd.dec.Reset(bytes.NewReader(b)) + return rd.beginCollection(targetField) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return rd.beginCollection(targetField) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + rd.curr.readCnt = i + 1 + } + + return FieldNotFound +} + +func (rd *msgpackAssocReader) EndCollection() error { + return rd.endContainer() +} + +func (rd *msgpackAssocReader) ContainerSize() (int, error) { + return rd.curr.length, nil +} + +func (rd *msgpackAssocReader) IsContainerAssoc() (bool, error) { + return rd.curr.assoc, nil +} + +func (rd *msgpackAssocReader) Skip() error { + return errors.WithStack(rd.dec.Skip()) +} + +func (rd *msgpackAssocReader) TryReadMask() (bool, FieldsMask, error) { + if rd.curr.assoc { + return false, FieldsMask{}, nil + } + + maskLen, err := rd.dec.DecodeArrayLen() + if err != nil { + if err == io.EOF { + return false, FieldsMask{}, nil + } + return false, FieldsMask{}, errors.WithStack(err) + } + + var mask FieldsMask + for i := 0; i < maskLen; i++ { + maskPart, err := rd.dec.DecodeUint64() + if err != nil { + return false, FieldsMask{}, errors.WithStack(err) + } + mask.SetPartFromUint64(i, maskPart) + } + + return true, mask, nil +} + +func (rd *msgpackAssocReader) beginContainer(field string) error { + code, err := rd.dec.PeekCode() + if err != nil { + return errors.WithStack(err) + } + + switch { + case msgpcode.IsFixedMap(code), code == msgpcode.Map16, code == msgpcode.Map32: + l, err := rd.dec.DecodeMapLen() + if err != nil { + return errors.WithStack(err) + } + rd.stack = append(rd.stack, rd.curr) + rd.curr = assocReadContainer{ + started: true, + length: l, + assoc: true, + values: make(map[string]msgpack.RawMessage, l), + reader: rd.dec.Buffered(), + } + + default: + return errors.Errorf("field `%s` is not a map", field) + } + return nil +} + +func (rd *msgpackAssocReader) beginCollection(field string) error { + code, err := rd.dec.PeekCode() + if err != nil { + return errors.WithStack(err) + } + + switch { + case msgpcode.IsFixedArray(code), code == msgpcode.Array16, code == msgpcode.Array32: + l, err := rd.dec.DecodeArrayLen() + if err != nil { + return errors.WithStack(err) + } + rd.stack = append(rd.stack, rd.curr) + rd.curr = assocReadContainer{ + started: true, + length: l, + assoc: false, + reader: rd.dec.Buffered(), + } + + default: + return errors.Errorf("field `%s` is not an array", field) + } + return nil +} + +func (rd *msgpackAssocReader) endContainer() error { + if len(rd.stack) == 0 { + return NoOpenContainer + } + rd.curr = rd.stack[len(rd.stack)-1] + rd.stack = rd.stack[:len(rd.stack)-1] + rd.dec.Reset(rd.curr.reader) + return nil +} diff --git a/msgpack_assoc_reader_test.go b/msgpack_assoc_reader_test.go new file mode 100644 index 0000000..acf9cac --- /dev/null +++ b/msgpack_assoc_reader_test.go @@ -0,0 +1,120 @@ +package meta_test + +import ( + "bytes" + "encoding/hex" + "testing" + + "git.bit5.ru/backend/meta/v4" + "github.com/stretchr/testify/require" +) + +func TestAssocMsgpackReader(t *testing.T) { + t.Run("reading parent", func(t *testing.T) { + // {"f1":"blabla","f3":[2,4,6],"f2":{"field":1},"f4":[{"field":10},{"field":1024}]} + src := "84a26631a6626c61626c61a2663393020406a2663281a56669656c6401a266349281a56669656c640a81a56669656c64cd0400" + + expected := TestParent{ + Field1: "blabla", + Field2: TestFoo{ + Field: 1, + }, + Field3: []int8{2, 4, 6}, + Field4: []TestFoo{ + {Field: 10}, + {Field: 1024}, + }, + } + + data, err := hex.DecodeString(src) + require.NoError(t, err) + + rdr := meta.NewMsgpackAssocReader(bytes.NewReader(data)) + + var actual TestParent + readErr := actual.Read(rdr) + require.NoError(t, readErr) + require.EqualValues(t, expected, actual) + }) + + t.Run("reading child", func(t *testing.T) { + // {"f":"qwerty","f1":"blabla","f3":[2,4,6],"f2":{"field":1},"f4":[{"field":10},{"field":1024}]} + src := "85a166a6717765727479a26631a6626c61626c61a2663393020406a2663281a56669656c6401a266349281a56669656c640a81a56669656c64cd0400" + + expected := TestChild{ + Field: "qwerty", + TestParent: TestParent{ + Field1: "blabla", + Field2: TestFoo{ + Field: 1, + }, + Field3: []int8{2, 4, 6}, + Field4: []TestFoo{ + {Field: 10}, + {Field: 1024}, + }, + }, + } + + data, err := hex.DecodeString(src) + require.NoError(t, err) + + rdr := meta.NewMsgpackAssocReader(bytes.NewReader(data)) + + var actual TestChild + readErr := actual.Read(rdr) + require.NoError(t, readErr) + require.EqualValues(t, expected, actual) + }) + + t.Run("fail reading parent as array", func(t *testing.T) { + // ["blabla",[1],[2,4,6],[[10],[1024]]] + src := "94a6626c61626c6191019302040692910a91cd0400" + + expected := TestParent{} + + data, err := hex.DecodeString(src) + require.NoError(t, err) + + rdr := meta.NewMsgpackAssocReader(bytes.NewReader(data)) + + var actual TestParent + readErr := actual.Read(rdr) + require.ErrorContains(t, readErr, "field `` is not a map") + require.EqualValues(t, expected, actual) + }) + + t.Run("fail reading parent as array with maps", func(t *testing.T) { + // ["blabla",{"field":1},[2,4,6],[{"field":10},{"field":1024}]] + src := "94a6626c61626c6181a56669656c6401930204069281a56669656c640a81a56669656c64cd0400" + + expected := TestParent{} + + data, err := hex.DecodeString(src) + require.NoError(t, err) + + rdr := meta.NewMsgpackAssocReader(bytes.NewReader(data)) + + var actual TestParent + readErr := actual.Read(rdr) + require.ErrorContains(t, readErr, "field `` is not a map") + require.EqualValues(t, expected, actual) + }) + + t.Run("fail reading child as array", func(t *testing.T) { + // ["blabla",[1],[2,4,6],[[10],[1024]],"qwerty"] + src := "95a6626c61626c6191019302040692910a91cd0400a6717765727479" + + expected := TestChild{} + + data, err := hex.DecodeString(src) + require.NoError(t, err) + + rdr := meta.NewMsgpackAssocReader(bytes.NewReader(data)) + + var actual TestChild + readErr := actual.Read(rdr) + require.ErrorContains(t, readErr, "field `` is not a map") + require.EqualValues(t, expected, actual) + }) +}