From 31aa90bbaaf5e729b027579bece7760ead9cd238 Mon Sep 17 00:00:00 2001 From: Pavel Merzlyakov Date: Thu, 8 Jun 2023 10:06:25 +0300 Subject: [PATCH] stateful reader --- msgpack_stateful_reader.go | 672 ++++++++++++++++++++++++++++++++ msgpack_stateful_reader_test.go | 124 ++++++ 2 files changed, 796 insertions(+) create mode 100644 msgpack_stateful_reader.go create mode 100644 msgpack_stateful_reader_test.go diff --git a/msgpack_stateful_reader.go b/msgpack_stateful_reader.go new file mode 100644 index 0000000..9dd5961 --- /dev/null +++ b/msgpack_stateful_reader.go @@ -0,0 +1,672 @@ +package meta + +import ( + "bytes" + "io" + + "github.com/pkg/errors" + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" +) + +type msgpackStatefulReader struct { + dec *msgpack.Decoder + stack []container + curr container +} + +type container struct { + started bool + length int + assoc bool + values map[string]msgpack.RawMessage + reader io.Reader + readCnt int +} + +func NewMsgpackStatefulReader(r io.Reader) Reader { + return &msgpackStatefulReader{ + dec: msgpack.NewDecoder(r), + stack: make([]container, 0, 2), + curr: container{ + reader: r, + }, + } +} + +func (rd *msgpackStatefulReader) readField() (string, error) { + field, err := rd.dec.DecodeString() + if err != nil { + return "", errors.WithStack(err) + } + return field, nil +} + +func (rd *msgpackStatefulReader) ReadInt8(v *int8, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeInt8(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeInt8(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeInt8(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackStatefulReader) ReadInt16(v *int16, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeInt16(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeInt16(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeInt16(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackStatefulReader) ReadInt32(v *int32, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeInt32(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeInt32(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeInt32(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackStatefulReader) ReadInt64(v *int64, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeInt64(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeInt64(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeInt64(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackStatefulReader) ReadUint8(v *uint8, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeUint8(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeUint8(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeUint8(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackStatefulReader) ReadUint16(v *uint16, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeUint16(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeUint16(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeUint16(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackStatefulReader) ReadUint32(v *uint32, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeUint32(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeUint32(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeUint32(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackStatefulReader) ReadUint64(v *uint64, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeUint64(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeUint64(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeUint64(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackStatefulReader) ReadBool(v *bool, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeBool(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeBool(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeBool(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackStatefulReader) ReadFloat32(v *float32, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeFloat32(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeFloat32(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeFloat32(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackStatefulReader) ReadFloat64(v *float64, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeFloat64(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeFloat64(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeFloat64(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackStatefulReader) ReadString(v *string, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeString(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeString(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeString(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackStatefulReader) ReadBytes(v *[]byte, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeBytes(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeBytes(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeBytes(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackStatefulReader) BeginContainer(targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return rd.beginContainer(targetField) + } + + if b, ok := rd.curr.values[targetField]; ok { + rd.dec.Reset(bytes.NewReader(b)) + return rd.beginContainer(targetField) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return rd.beginContainer(targetField) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackStatefulReader) beginContainer(field string) error { + code, err := rd.dec.PeekCode() + if err != nil { + return errors.WithStack(err) + } + + switch { + case msgpcode.IsFixedMap(code), code == msgpcode.Map16, code == msgpcode.Map32: + l, err := rd.dec.DecodeMapLen() + if err != nil { + return errors.WithStack(err) + } + rd.stack = append(rd.stack, rd.curr) + rd.curr = container{ + started: true, + length: l, + assoc: true, + values: make(map[string]msgpack.RawMessage, l), + reader: rd.dec.Buffered(), + } + + case msgpcode.IsFixedArray(code), code == msgpcode.Array16, code == msgpcode.Array32: + l, err := rd.dec.DecodeArrayLen() + if err != nil { + return errors.WithStack(err) + } + rd.stack = append(rd.stack, rd.curr) + rd.curr = container{ + started: true, + length: l, + assoc: false, + reader: rd.dec.Buffered(), + } + + default: + return errors.Errorf("there is no container for field `%s`", field) + } + return nil +} + +func (rd *msgpackStatefulReader) EndContainer() error { + if len(rd.stack) == 0 { + return errors.New("there is no open containers") + } + rd.curr = rd.stack[len(rd.stack)-1] + rd.stack = rd.stack[:len(rd.stack)-1] + return nil +} + +func (rd *msgpackStatefulReader) ContainerSize() (int, error) { + return rd.curr.length, nil +} + +func (rd *msgpackStatefulReader) IsContainerAssoc() (bool, error) { + return rd.curr.assoc, nil +} + +func (rd *msgpackStatefulReader) Skip() error { + return errors.WithStack(rd.dec.Skip()) +} + +func (rd *msgpackStatefulReader) TryReadMask() (bool, FieldsMask, error) { + code, err := rd.dec.PeekCode() + if err != nil { + return false, FieldsMask{}, errors.WithStack(err) + } + + if code != msgpcode.Nil { + return false, FieldsMask{}, nil + } + + if err := rd.dec.Skip(); err != nil { + return false, FieldsMask{}, errors.WithStack(err) + } + + var mask FieldsMask + maskLen, err := rd.dec.DecodeArrayLen() + if err != nil { + return false, FieldsMask{}, errors.WithStack(err) + } + + for i := 0; i < maskLen; i++ { + maskPart, err := rd.dec.DecodeUint64() + if err != nil { + return false, FieldsMask{}, errors.WithStack(err) + } + mask.SetPartFromUint64(i, maskPart) + } + + return true, mask, nil +} + +func decodeUint8(dec *msgpack.Decoder, v *uint8) error { + tmp, err := dec.DecodeUint8() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeUint16(dec *msgpack.Decoder, v *uint16) error { + tmp, err := dec.DecodeUint16() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeUint32(dec *msgpack.Decoder, v *uint32) error { + tmp, err := dec.DecodeUint32() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeUint64(dec *msgpack.Decoder, v *uint64) error { + tmp, err := dec.DecodeUint64() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeInt8(dec *msgpack.Decoder, v *int8) error { + tmp, err := dec.DecodeInt8() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeInt16(dec *msgpack.Decoder, v *int16) error { + tmp, err := dec.DecodeInt16() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeInt32(dec *msgpack.Decoder, v *int32) error { + tmp, err := dec.DecodeInt32() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeInt64(dec *msgpack.Decoder, v *int64) error { + tmp, err := dec.DecodeInt64() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeBool(dec *msgpack.Decoder, v *bool) error { + tmp, err := dec.DecodeBool() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeFloat32(dec *msgpack.Decoder, v *float32) error { + tmp, err := dec.DecodeFloat32() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeFloat64(dec *msgpack.Decoder, v *float64) error { + tmp, err := dec.DecodeFloat64() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeString(dec *msgpack.Decoder, v *string) error { + tmp, err := dec.DecodeString() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeBytes(dec *msgpack.Decoder, v *[]byte) error { + tmp, err := dec.DecodeBytes() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} diff --git a/msgpack_stateful_reader_test.go b/msgpack_stateful_reader_test.go new file mode 100644 index 0000000..719ed94 --- /dev/null +++ b/msgpack_stateful_reader_test.go @@ -0,0 +1,124 @@ +package meta_test + +import ( + "bytes" + "encoding/hex" + "testing" + + "git.bit5.ru/backend/meta" + "github.com/stretchr/testify/require" +) + +type foo struct { + field1 uint32 + field2 uint32 + field3 []uint32 + field4 bar +} + +func (f *foo) Read(reader meta.Reader) error { + if err := reader.BeginContainer(""); err != nil { + return err + } + if err := f.ReadFields(reader); err != nil { + return err + } + if err := reader.EndContainer(); err != nil { + return err + } + return nil +} + +func (f *foo) ReadFields(reader meta.Reader) error { + if err := reader.ReadUint32(&f.field1, "f1"); err != nil { + return err + } + + if err := reader.ReadUint32(&f.field2, "f2"); err != nil { + return err + } + + if err := reader.BeginContainer("f3"); err != nil { + return err + } + size, err := reader.ContainerSize() + if err != nil { + return err + } + for ; size > 0; size-- { + var tmp uint32 + if err := reader.ReadUint32(&tmp, ""); err != nil { + return err + } + f.field3 = append(f.field3, tmp) + } + if err := reader.EndContainer(); err != nil { + return err + } + + if err := reader.BeginContainer("f4"); err != nil { + return err + } + if err := f.field4.ReadFields(reader); err != nil { + return err + } + if err := reader.EndContainer(); err != nil { + return err + } + + return nil +} + +type bar struct { + a uint32 + b uint32 +} + +func (b *bar) Read(reader meta.Reader) error { + if err := reader.BeginContainer(""); err != nil { + return err + } + if err := b.ReadFields(reader); err != nil { + return err + } + if err := reader.EndContainer(); err != nil { + return err + } + return nil +} + +func (b *bar) ReadFields(reader meta.Reader) error { + if err := reader.ReadUint32(&b.a, "a"); err != nil { + return err + } + + if err := reader.ReadUint32(&b.b, "b"); err != nil { + return err + } + + return nil +} + +func TestMsgprdr(t *testing.T) { + str := `84A266310AA2663393010203A2663482A16104A16205A266320F` + b, err := hex.DecodeString(str) + require.NoError(t, err) + + rdr := meta.NewMsgpackStatefulReader(bytes.NewReader(b)) + + var actual foo + readErr := actual.Read(rdr) + require.NoError(t, readErr) + + expected := foo{ + field1: 10, + field2: 15, + field3: []uint32{1, 2, 3}, + field4: bar{ + a: 4, + b: 5, + }, + } + + require.EqualValues(t, expected, actual) +}