From 9e0e7b8bd1dcde6473b263ab7e3e5e31836808aa Mon Sep 17 00:00:00 2001 From: Pavel Merzlyakov Date: Fri, 16 Jun 2023 13:52:56 +0300 Subject: [PATCH] fix reading assoc containers --- msgpack_reader.go | 648 +++++++++++++++++++++++++----- msgpack_stateful_reader.go | 672 -------------------------------- msgpack_stateful_reader_test.go | 124 ------ msgpack_writer.go | 17 +- structs_test.go | 280 +------------ 5 files changed, 568 insertions(+), 1173 deletions(-) delete mode 100644 msgpack_stateful_reader.go delete mode 100644 msgpack_stateful_reader_test.go diff --git a/msgpack_reader.go b/msgpack_reader.go index 13ce160..0461c87 100644 --- a/msgpack_reader.go +++ b/msgpack_reader.go @@ -1,6 +1,7 @@ package meta import ( + "bytes" "io" "github.com/pkg/errors" @@ -9,148 +10,459 @@ import ( ) type msgpackReader struct { - dec *msgpack.Decoder - containers []msgpackContainer + dec *msgpack.Decoder + stack []readContainer + curr readContainer } -type msgpackContainer struct { - length int - assoc bool +type readContainer struct { + started bool + length int + assoc bool + values map[string]msgpack.RawMessage + reader io.Reader + readCnt int } func NewMsgpackReader(r io.Reader) Reader { return &msgpackReader{ - dec: msgpack.NewDecoder(r), - containers: make([]msgpackContainer, 0, 2), + dec: msgpack.NewDecoder(r), + stack: make([]readContainer, 0, 2), + curr: readContainer{ + reader: r, + }, } } -func (rd *msgpackReader) currentContainer() msgpackContainer { - if rd == nil || len(rd.containers) == 0 { - return msgpackContainer{} - } - - return rd.containers[len(rd.containers)-1] -} - -func (rd *msgpackReader) ReadInt8(v *int8, field string) error { - tmp, err := rd.dec.DecodeInt8() +func (rd *msgpackReader) readField() (string, error) { + field, err := rd.dec.DecodeString() if err != nil { - return errors.WithStack(err) + return "", errors.WithStack(err) } - *v = tmp - return nil + return field, nil } -func (rd *msgpackReader) ReadInt16(v *int16, field string) error { - tmp, err := rd.dec.DecodeInt16() - if err != nil { - return errors.WithStack(err) +func (rd *msgpackReader) ReadInt8(v *int8, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeInt8(rd.dec, v) } - *v = tmp - return nil + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeInt8(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeInt8(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) } -func (rd *msgpackReader) ReadInt32(v *int32, field string) error { - tmp, err := rd.dec.DecodeInt32() - if err != nil { - return errors.WithStack(err) +func (rd *msgpackReader) ReadInt16(v *int16, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeInt16(rd.dec, v) } - *v = tmp - return nil + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeInt16(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeInt16(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) } -func (rd *msgpackReader) ReadInt64(v *int64, field string) error { - tmp, err := rd.dec.DecodeInt64() - if err != nil { - return errors.WithStack(err) +func (rd *msgpackReader) ReadInt32(v *int32, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeInt32(rd.dec, v) } - *v = tmp - return nil + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeInt32(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeInt32(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) } -func (rd *msgpackReader) ReadUint8(v *uint8, field string) error { - tmp, err := rd.dec.DecodeUint8() - if err != nil { - return errors.WithStack(err) +func (rd *msgpackReader) ReadInt64(v *int64, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeInt64(rd.dec, v) } - *v = tmp - return nil + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeInt64(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeInt64(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) } -func (rd *msgpackReader) ReadUint16(v *uint16, field string) error { - tmp, err := rd.dec.DecodeUint16() - if err != nil { - return errors.WithStack(err) +func (rd *msgpackReader) ReadUint8(v *uint8, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeUint8(rd.dec, v) } - *v = tmp - return nil + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeUint8(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeUint8(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) } -func (rd *msgpackReader) ReadUint32(v *uint32, field string) error { - tmp, err := rd.dec.DecodeUint32() - if err != nil { - return errors.WithStack(err) +func (rd *msgpackReader) ReadUint16(v *uint16, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeUint16(rd.dec, v) } - *v = tmp - return nil + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeUint16(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeUint16(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) } -func (rd *msgpackReader) ReadUint64(v *uint64, field string) error { - tmp, err := rd.dec.DecodeUint64() - if err != nil { - return errors.WithStack(err) +func (rd *msgpackReader) ReadUint32(v *uint32, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeUint32(rd.dec, v) } - *v = tmp - return nil + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeUint32(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeUint32(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) } -func (rd *msgpackReader) ReadBool(v *bool, field string) error { - tmp, err := rd.dec.DecodeBool() - if err != nil { - return errors.WithStack(err) +func (rd *msgpackReader) ReadUint64(v *uint64, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeUint64(rd.dec, v) } - *v = tmp - return nil + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeUint64(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeUint64(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) } -func (rd *msgpackReader) ReadFloat32(v *float32, field string) error { - tmp, err := rd.dec.DecodeFloat32() - if err != nil { - return errors.WithStack(err) +func (rd *msgpackReader) ReadBool(v *bool, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeBool(rd.dec, v) } - *v = tmp - return nil + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeBool(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeBool(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) } -func (rd *msgpackReader) ReadFloat64(v *float64, field string) error { - tmp, err := rd.dec.DecodeFloat64() - if err != nil { - return errors.WithStack(err) +func (rd *msgpackReader) ReadFloat32(v *float32, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeFloat32(rd.dec, v) } - *v = tmp - return nil + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeFloat32(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeFloat32(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) } -func (rd *msgpackReader) ReadString(v *string, field string) error { - tmp, err := rd.dec.DecodeString() - if err != nil { - return errors.WithStack(err) +func (rd *msgpackReader) ReadFloat64(v *float64, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeFloat64(rd.dec, v) } - *v = tmp - return nil + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeFloat64(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeFloat64(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) } -func (rd *msgpackReader) ReadBytes(v *[]byte, field string) error { - tmp, err := rd.dec.DecodeBytes() - if err != nil { - return errors.WithStack(err) +func (rd *msgpackReader) ReadString(v *string, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeString(rd.dec, v) } - *v = tmp - return nil + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeString(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeString(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) } -func (rd *msgpackReader) BeginContainer(field string) error { +func (rd *msgpackReader) ReadBytes(v *[]byte, targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return decodeBytes(rd.dec, v) + } + + if b, ok := rd.curr.values[targetField]; ok { + dec := msgpack.NewDecoder(bytes.NewReader(b)) + return decodeBytes(dec, v) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return decodeBytes(rd.dec, v) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackReader) BeginContainer(targetField string) error { + if !rd.curr.started || !rd.curr.assoc { + return rd.beginContainer(targetField) + } + + if b, ok := rd.curr.values[targetField]; ok { + rd.dec.Reset(bytes.NewReader(b)) + return rd.beginContainer(targetField) + } + + for i := rd.curr.readCnt; i < rd.curr.length; i++ { + field, err := rd.readField() + if err != nil { + return err + } + if field == targetField { + rd.curr.readCnt = i + 1 + return rd.beginContainer(targetField) + } + + raw, err := rd.dec.DecodeRaw() + if err != nil { + return errors.WithStack(err) + } + rd.curr.values[field] = raw + } + + return errors.Errorf("field `%s` not found", targetField) +} + +func (rd *msgpackReader) beginContainer(field string) error { code, err := rd.dec.PeekCode() if err != nil { return errors.WithStack(err) @@ -162,20 +474,27 @@ func (rd *msgpackReader) BeginContainer(field string) error { if err != nil { return errors.WithStack(err) } - rd.containers = append(rd.containers, msgpackContainer{ - length: l, - assoc: true, - }) + rd.stack = append(rd.stack, rd.curr) + rd.curr = readContainer{ + started: true, + length: l, + assoc: true, + values: make(map[string]msgpack.RawMessage, l), + reader: rd.dec.Buffered(), + } case msgpcode.IsFixedArray(code), code == msgpcode.Array16, code == msgpcode.Array32: l, err := rd.dec.DecodeArrayLen() if err != nil { return errors.WithStack(err) } - rd.containers = append(rd.containers, msgpackContainer{ - length: l, - assoc: false, - }) + rd.stack = append(rd.stack, rd.curr) + rd.curr = readContainer{ + started: true, + length: l, + assoc: false, + reader: rd.dec.Buffered(), + } default: return errors.Errorf("there is no container for field `%s`", field) @@ -184,19 +503,21 @@ func (rd *msgpackReader) BeginContainer(field string) error { } func (rd *msgpackReader) EndContainer() error { - if len(rd.containers) == 0 { + if len(rd.stack) == 0 { return errors.New("there is no open containers") } - rd.containers = rd.containers[:len(rd.containers)-1] + rd.curr = rd.stack[len(rd.stack)-1] + rd.stack = rd.stack[:len(rd.stack)-1] + rd.dec.Reset(rd.curr.reader) return nil } func (rd *msgpackReader) ContainerSize() (int, error) { - return rd.currentContainer().length, nil + return rd.curr.length, nil } func (rd *msgpackReader) IsContainerAssoc() (bool, error) { - return rd.currentContainer().assoc, nil + return rd.curr.assoc, nil } func (rd *msgpackReader) Skip() error { @@ -233,3 +554,120 @@ func (rd *msgpackReader) TryReadMask() (bool, FieldsMask, error) { return true, mask, nil } + +func decodeUint8(dec *msgpack.Decoder, v *uint8) error { + tmp, err := dec.DecodeUint8() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeUint16(dec *msgpack.Decoder, v *uint16) error { + tmp, err := dec.DecodeUint16() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeUint32(dec *msgpack.Decoder, v *uint32) error { + tmp, err := dec.DecodeUint32() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeUint64(dec *msgpack.Decoder, v *uint64) error { + tmp, err := dec.DecodeUint64() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeInt8(dec *msgpack.Decoder, v *int8) error { + tmp, err := dec.DecodeInt8() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeInt16(dec *msgpack.Decoder, v *int16) error { + tmp, err := dec.DecodeInt16() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeInt32(dec *msgpack.Decoder, v *int32) error { + tmp, err := dec.DecodeInt32() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeInt64(dec *msgpack.Decoder, v *int64) error { + tmp, err := dec.DecodeInt64() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeBool(dec *msgpack.Decoder, v *bool) error { + tmp, err := dec.DecodeBool() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeFloat32(dec *msgpack.Decoder, v *float32) error { + tmp, err := dec.DecodeFloat32() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeFloat64(dec *msgpack.Decoder, v *float64) error { + tmp, err := dec.DecodeFloat64() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeString(dec *msgpack.Decoder, v *string) error { + tmp, err := dec.DecodeString() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} + +func decodeBytes(dec *msgpack.Decoder, v *[]byte) error { + tmp, err := dec.DecodeBytes() + if err != nil { + return errors.WithStack(err) + } + *v = tmp + return nil +} diff --git a/msgpack_stateful_reader.go b/msgpack_stateful_reader.go deleted file mode 100644 index 9dd5961..0000000 --- a/msgpack_stateful_reader.go +++ /dev/null @@ -1,672 +0,0 @@ -package meta - -import ( - "bytes" - "io" - - "github.com/pkg/errors" - "github.com/vmihailenco/msgpack/v5" - "github.com/vmihailenco/msgpack/v5/msgpcode" -) - -type msgpackStatefulReader struct { - dec *msgpack.Decoder - stack []container - curr container -} - -type container struct { - started bool - length int - assoc bool - values map[string]msgpack.RawMessage - reader io.Reader - readCnt int -} - -func NewMsgpackStatefulReader(r io.Reader) Reader { - return &msgpackStatefulReader{ - dec: msgpack.NewDecoder(r), - stack: make([]container, 0, 2), - curr: container{ - reader: r, - }, - } -} - -func (rd *msgpackStatefulReader) readField() (string, error) { - field, err := rd.dec.DecodeString() - if err != nil { - return "", errors.WithStack(err) - } - return field, nil -} - -func (rd *msgpackStatefulReader) ReadInt8(v *int8, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeInt8(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeInt8(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeInt8(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - } - - return errors.Errorf("field `%s` not found", targetField) -} - -func (rd *msgpackStatefulReader) ReadInt16(v *int16, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeInt16(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeInt16(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeInt16(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - } - - return errors.Errorf("field `%s` not found", targetField) -} - -func (rd *msgpackStatefulReader) ReadInt32(v *int32, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeInt32(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeInt32(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeInt32(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - } - - return errors.Errorf("field `%s` not found", targetField) -} - -func (rd *msgpackStatefulReader) ReadInt64(v *int64, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeInt64(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeInt64(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeInt64(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - } - - return errors.Errorf("field `%s` not found", targetField) -} - -func (rd *msgpackStatefulReader) ReadUint8(v *uint8, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeUint8(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeUint8(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeUint8(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - } - - return errors.Errorf("field `%s` not found", targetField) -} - -func (rd *msgpackStatefulReader) ReadUint16(v *uint16, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeUint16(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeUint16(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeUint16(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - } - - return errors.Errorf("field `%s` not found", targetField) -} - -func (rd *msgpackStatefulReader) ReadUint32(v *uint32, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeUint32(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeUint32(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeUint32(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - } - - return errors.Errorf("field `%s` not found", targetField) -} - -func (rd *msgpackStatefulReader) ReadUint64(v *uint64, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeUint64(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeUint64(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeUint64(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - } - - return errors.Errorf("field `%s` not found", targetField) -} - -func (rd *msgpackStatefulReader) ReadBool(v *bool, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeBool(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeBool(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeBool(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - } - - return errors.Errorf("field `%s` not found", targetField) -} - -func (rd *msgpackStatefulReader) ReadFloat32(v *float32, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeFloat32(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeFloat32(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeFloat32(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - } - - return errors.Errorf("field `%s` not found", targetField) -} - -func (rd *msgpackStatefulReader) ReadFloat64(v *float64, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeFloat64(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeFloat64(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeFloat64(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - } - - return errors.Errorf("field `%s` not found", targetField) -} - -func (rd *msgpackStatefulReader) ReadString(v *string, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeString(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeString(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeString(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - } - - return errors.Errorf("field `%s` not found", targetField) -} - -func (rd *msgpackStatefulReader) ReadBytes(v *[]byte, targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return decodeBytes(rd.dec, v) - } - - if b, ok := rd.curr.values[targetField]; ok { - dec := msgpack.NewDecoder(bytes.NewReader(b)) - return decodeBytes(dec, v) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return decodeBytes(rd.dec, v) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - } - - return errors.Errorf("field `%s` not found", targetField) -} - -func (rd *msgpackStatefulReader) BeginContainer(targetField string) error { - if !rd.curr.started || !rd.curr.assoc { - return rd.beginContainer(targetField) - } - - if b, ok := rd.curr.values[targetField]; ok { - rd.dec.Reset(bytes.NewReader(b)) - return rd.beginContainer(targetField) - } - - for i := rd.curr.readCnt; i < rd.curr.length; i++ { - field, err := rd.readField() - if err != nil { - return err - } - if field == targetField { - rd.curr.readCnt = i + 1 - return rd.beginContainer(targetField) - } - - raw, err := rd.dec.DecodeRaw() - if err != nil { - return errors.WithStack(err) - } - rd.curr.values[field] = raw - } - - return errors.Errorf("field `%s` not found", targetField) -} - -func (rd *msgpackStatefulReader) beginContainer(field string) error { - code, err := rd.dec.PeekCode() - if err != nil { - return errors.WithStack(err) - } - - switch { - case msgpcode.IsFixedMap(code), code == msgpcode.Map16, code == msgpcode.Map32: - l, err := rd.dec.DecodeMapLen() - if err != nil { - return errors.WithStack(err) - } - rd.stack = append(rd.stack, rd.curr) - rd.curr = container{ - started: true, - length: l, - assoc: true, - values: make(map[string]msgpack.RawMessage, l), - reader: rd.dec.Buffered(), - } - - case msgpcode.IsFixedArray(code), code == msgpcode.Array16, code == msgpcode.Array32: - l, err := rd.dec.DecodeArrayLen() - if err != nil { - return errors.WithStack(err) - } - rd.stack = append(rd.stack, rd.curr) - rd.curr = container{ - started: true, - length: l, - assoc: false, - reader: rd.dec.Buffered(), - } - - default: - return errors.Errorf("there is no container for field `%s`", field) - } - return nil -} - -func (rd *msgpackStatefulReader) EndContainer() error { - if len(rd.stack) == 0 { - return errors.New("there is no open containers") - } - rd.curr = rd.stack[len(rd.stack)-1] - rd.stack = rd.stack[:len(rd.stack)-1] - return nil -} - -func (rd *msgpackStatefulReader) ContainerSize() (int, error) { - return rd.curr.length, nil -} - -func (rd *msgpackStatefulReader) IsContainerAssoc() (bool, error) { - return rd.curr.assoc, nil -} - -func (rd *msgpackStatefulReader) Skip() error { - return errors.WithStack(rd.dec.Skip()) -} - -func (rd *msgpackStatefulReader) TryReadMask() (bool, FieldsMask, error) { - code, err := rd.dec.PeekCode() - if err != nil { - return false, FieldsMask{}, errors.WithStack(err) - } - - if code != msgpcode.Nil { - return false, FieldsMask{}, nil - } - - if err := rd.dec.Skip(); err != nil { - return false, FieldsMask{}, errors.WithStack(err) - } - - var mask FieldsMask - maskLen, err := rd.dec.DecodeArrayLen() - if err != nil { - return false, FieldsMask{}, errors.WithStack(err) - } - - for i := 0; i < maskLen; i++ { - maskPart, err := rd.dec.DecodeUint64() - if err != nil { - return false, FieldsMask{}, errors.WithStack(err) - } - mask.SetPartFromUint64(i, maskPart) - } - - return true, mask, nil -} - -func decodeUint8(dec *msgpack.Decoder, v *uint8) error { - tmp, err := dec.DecodeUint8() - if err != nil { - return errors.WithStack(err) - } - *v = tmp - return nil -} - -func decodeUint16(dec *msgpack.Decoder, v *uint16) error { - tmp, err := dec.DecodeUint16() - if err != nil { - return errors.WithStack(err) - } - *v = tmp - return nil -} - -func decodeUint32(dec *msgpack.Decoder, v *uint32) error { - tmp, err := dec.DecodeUint32() - if err != nil { - return errors.WithStack(err) - } - *v = tmp - return nil -} - -func decodeUint64(dec *msgpack.Decoder, v *uint64) error { - tmp, err := dec.DecodeUint64() - if err != nil { - return errors.WithStack(err) - } - *v = tmp - return nil -} - -func decodeInt8(dec *msgpack.Decoder, v *int8) error { - tmp, err := dec.DecodeInt8() - if err != nil { - return errors.WithStack(err) - } - *v = tmp - return nil -} - -func decodeInt16(dec *msgpack.Decoder, v *int16) error { - tmp, err := dec.DecodeInt16() - if err != nil { - return errors.WithStack(err) - } - *v = tmp - return nil -} - -func decodeInt32(dec *msgpack.Decoder, v *int32) error { - tmp, err := dec.DecodeInt32() - if err != nil { - return errors.WithStack(err) - } - *v = tmp - return nil -} - -func decodeInt64(dec *msgpack.Decoder, v *int64) error { - tmp, err := dec.DecodeInt64() - if err != nil { - return errors.WithStack(err) - } - *v = tmp - return nil -} - -func decodeBool(dec *msgpack.Decoder, v *bool) error { - tmp, err := dec.DecodeBool() - if err != nil { - return errors.WithStack(err) - } - *v = tmp - return nil -} - -func decodeFloat32(dec *msgpack.Decoder, v *float32) error { - tmp, err := dec.DecodeFloat32() - if err != nil { - return errors.WithStack(err) - } - *v = tmp - return nil -} - -func decodeFloat64(dec *msgpack.Decoder, v *float64) error { - tmp, err := dec.DecodeFloat64() - if err != nil { - return errors.WithStack(err) - } - *v = tmp - return nil -} - -func decodeString(dec *msgpack.Decoder, v *string) error { - tmp, err := dec.DecodeString() - if err != nil { - return errors.WithStack(err) - } - *v = tmp - return nil -} - -func decodeBytes(dec *msgpack.Decoder, v *[]byte) error { - tmp, err := dec.DecodeBytes() - if err != nil { - return errors.WithStack(err) - } - *v = tmp - return nil -} diff --git a/msgpack_stateful_reader_test.go b/msgpack_stateful_reader_test.go deleted file mode 100644 index 719ed94..0000000 --- a/msgpack_stateful_reader_test.go +++ /dev/null @@ -1,124 +0,0 @@ -package meta_test - -import ( - "bytes" - "encoding/hex" - "testing" - - "git.bit5.ru/backend/meta" - "github.com/stretchr/testify/require" -) - -type foo struct { - field1 uint32 - field2 uint32 - field3 []uint32 - field4 bar -} - -func (f *foo) Read(reader meta.Reader) error { - if err := reader.BeginContainer(""); err != nil { - return err - } - if err := f.ReadFields(reader); err != nil { - return err - } - if err := reader.EndContainer(); err != nil { - return err - } - return nil -} - -func (f *foo) ReadFields(reader meta.Reader) error { - if err := reader.ReadUint32(&f.field1, "f1"); err != nil { - return err - } - - if err := reader.ReadUint32(&f.field2, "f2"); err != nil { - return err - } - - if err := reader.BeginContainer("f3"); err != nil { - return err - } - size, err := reader.ContainerSize() - if err != nil { - return err - } - for ; size > 0; size-- { - var tmp uint32 - if err := reader.ReadUint32(&tmp, ""); err != nil { - return err - } - f.field3 = append(f.field3, tmp) - } - if err := reader.EndContainer(); err != nil { - return err - } - - if err := reader.BeginContainer("f4"); err != nil { - return err - } - if err := f.field4.ReadFields(reader); err != nil { - return err - } - if err := reader.EndContainer(); err != nil { - return err - } - - return nil -} - -type bar struct { - a uint32 - b uint32 -} - -func (b *bar) Read(reader meta.Reader) error { - if err := reader.BeginContainer(""); err != nil { - return err - } - if err := b.ReadFields(reader); err != nil { - return err - } - if err := reader.EndContainer(); err != nil { - return err - } - return nil -} - -func (b *bar) ReadFields(reader meta.Reader) error { - if err := reader.ReadUint32(&b.a, "a"); err != nil { - return err - } - - if err := reader.ReadUint32(&b.b, "b"); err != nil { - return err - } - - return nil -} - -func TestMsgprdr(t *testing.T) { - str := `84A266310AA2663393010203A2663482A16104A16205A266320F` - b, err := hex.DecodeString(str) - require.NoError(t, err) - - rdr := meta.NewMsgpackStatefulReader(bytes.NewReader(b)) - - var actual foo - readErr := actual.Read(rdr) - require.NoError(t, readErr) - - expected := foo{ - field1: 10, - field2: 15, - field3: []uint32{1, 2, 3}, - field4: bar{ - a: 4, - b: 5, - }, - } - - require.EqualValues(t, expected, actual) -} diff --git a/msgpack_writer.go b/msgpack_writer.go index 90236d3..7c7adc7 100644 --- a/msgpack_writer.go +++ b/msgpack_writer.go @@ -9,19 +9,24 @@ import ( type msgpackWriter struct { enc *msgpack.Encoder - containers []msgpackContainer + containers []writeContainer +} + +type writeContainer struct { + length int + assoc bool } func NewMsgpackWriter(w io.Writer) Writer { return &msgpackWriter{ enc: msgpack.NewEncoder(w), - containers: make([]msgpackContainer, 0, 1), + containers: make([]writeContainer, 0, 1), } } -func (wr *msgpackWriter) currentContainer() msgpackContainer { +func (wr *msgpackWriter) currentContainer() writeContainer { if wr == nil || len(wr.containers) == 0 { - return msgpackContainer{} + return writeContainer{} } return wr.containers[len(wr.containers)-1] } @@ -132,7 +137,7 @@ func (wr *msgpackWriter) BeginContainer(length int, field string) error { if err := wr.enc.EncodeArrayLen(length); err != nil { return errors.WithStack(err) } - wr.containers = append(wr.containers, msgpackContainer{ + wr.containers = append(wr.containers, writeContainer{ length: length, assoc: false, }) @@ -146,7 +151,7 @@ func (wr *msgpackWriter) BeginAssocContainer(length int, field string) error { if err := wr.enc.EncodeMapLen(length); err != nil { return errors.WithStack(err) } - wr.containers = append(wr.containers, msgpackContainer{ + wr.containers = append(wr.containers, writeContainer{ length: length, assoc: true, }) diff --git a/structs_test.go b/structs_test.go index 6af51f5..b02690b 100644 --- a/structs_test.go +++ b/structs_test.go @@ -2,7 +2,6 @@ package meta_test import ( "git.bit5.ru/backend/meta" - "github.com/pkg/errors" ) type TestParent struct { @@ -12,17 +11,6 @@ type TestParent struct { Field4 []TestFoo `json:"f4" msgpack:"f4"` } -var _TestParentRequiredFields map[string]struct{} - -func init() { - _TestParentRequiredFields = map[string]struct{}{ - "f1": {}, - "f2": {}, - "f3": {}, - "f4": {}, - } -} - func (s *TestParent) Reset() { s.Field1 = "" @@ -53,98 +41,6 @@ func (s *TestParent) Read(reader meta.Reader) error { func (s *TestParent) ReadFields(reader meta.Reader) error { s.Reset() - readAsMap, err := reader.IsContainerAssoc() - if err != nil { - return err - } - if readAsMap { - return s.readFieldsAssociative(reader) - } - return s.readFields(reader) -} - -func (s *TestParent) readFieldsAssociative(reader meta.Reader) error { - size, err := reader.ContainerSize() - if err != nil { - return err - } - - readFields := make(map[string]struct{}, 4) - for ; size > 0; size-- { - var field string - if err := reader.ReadString(&field, ""); err != nil { - return err - } - - switch field { - case "f1": - if err := reader.ReadString(&s.Field1, "f1"); err != nil { - return err - } - - case "f2": - if err := s.Field2.Read(reader); err != nil { - return err - } - - case "f3": - if err := reader.BeginContainer("f3"); err != nil { - return err - } - field3Size, err := reader.ContainerSize() - if err != nil { - return err - } - for ; field3Size > 0; field3Size-- { - var tmpField3 int8 - if err := reader.ReadInt8(&tmpField3, ""); err != nil { - return err - } - - s.Field3 = append(s.Field3, tmpField3) - } - if err := reader.EndContainer(); err != nil { - return err - } - - case "f4": - if err := reader.BeginContainer("f4"); err != nil { - return err - } - field4Size, err := reader.ContainerSize() - if err != nil { - return err - } - for ; field4Size > 0; field4Size-- { - var tmpField4 TestFoo - if err := tmpField4.Read(reader); err != nil { - return err - } - - s.Field4 = append(s.Field4, tmpField4) - } - if err := reader.EndContainer(); err != nil { - return err - } - - default: - return errors.Errorf("unexpected field `%s`", field) - } - - readFields[field] = struct{}{} - } - - for field := range _TestParentRequiredFields { - if _, ok := readFields[field]; !ok { - return errors.Errorf("field `%s` is not present", field) - } - } - - return nil -} - -func (s *TestParent) readFields(reader meta.Reader) error { - contSize, err := reader.ContainerSize() if err != nil { return err @@ -168,7 +64,13 @@ func (s *TestParent) readFields(reader meta.Reader) error { } contSize-- - if err := s.Field2.Read(reader); err != nil { + if err := reader.BeginContainer("f2"); err != nil { + return err + } + if err := s.Field2.ReadFields(reader); err != nil { + return err + } + if err := reader.EndContainer(); err != nil { return err } @@ -210,7 +112,13 @@ func (s *TestParent) readFields(reader meta.Reader) error { } for ; field4Size > 0; field4Size-- { var tmpField4 TestFoo - if err := tmpField4.Read(reader); err != nil { + if err := reader.BeginContainer(""); err != nil { + return err + } + if err := tmpField4.ReadFields(reader); err != nil { + return err + } + if err := reader.EndContainer(); err != nil { return err } @@ -290,14 +198,6 @@ type TestChild struct { Field string `json:"f" msgpack:"f"` } -var _TestChildRequiredFields map[string]struct{} - -func init() { - _TestChildRequiredFields = map[string]struct{}{ - "f": {}, - } -} - func (s *TestChild) Reset() { s.TestParent.Reset() @@ -318,103 +218,6 @@ func (s *TestChild) Read(reader meta.Reader) error { func (s *TestChild) ReadFields(reader meta.Reader) error { s.Reset() - readAsMap, err := reader.IsContainerAssoc() - if err != nil { - return err - } - if readAsMap { - return s.readFieldsAssociative(reader) - } - return s.readFields(reader) -} - -func (s *TestChild) readFieldsAssociative(reader meta.Reader) error { - size, err := reader.ContainerSize() - if err != nil { - return err - } - - readFields := make(map[string]struct{}, 5) - for ; size > 0; size-- { - var field string - if err := reader.ReadString(&field, ""); err != nil { - return err - } - - switch field { - case "f1": - if err := reader.ReadString(&s.Field1, "f1"); err != nil { - return err - } - - case "f2": - if err := s.Field2.Read(reader); err != nil { - return err - } - - case "f3": - if err := reader.BeginContainer("f3"); err != nil { - return err - } - field3Size, err := reader.ContainerSize() - if err != nil { - return err - } - for ; field3Size > 0; field3Size-- { - var tmpField3 int8 - if err := reader.ReadInt8(&tmpField3, ""); err != nil { - return err - } - - s.Field3 = append(s.Field3, tmpField3) - } - if err := reader.EndContainer(); err != nil { - return err - } - - case "f4": - if err := reader.BeginContainer("f4"); err != nil { - return err - } - field4Size, err := reader.ContainerSize() - if err != nil { - return err - } - for ; field4Size > 0; field4Size-- { - var tmpField4 TestFoo - if err := tmpField4.Read(reader); err != nil { - return err - } - - s.Field4 = append(s.Field4, tmpField4) - } - if err := reader.EndContainer(); err != nil { - return err - } - - case "f": - if err := reader.ReadString(&s.Field, "f"); err != nil { - return err - } - - default: - return errors.Errorf("unexpected field `%s`", field) - } - - readFields[field] = struct{}{} - } - - for field := range _TestChildRequiredFields { - if _, ok := readFields[field]; !ok { - return errors.Errorf("field `%s` is not present", field) - } - } - - return nil -} - -func (s *TestChild) readFields(reader meta.Reader) error { - contSize, err := reader.ContainerSize() if err != nil { return err @@ -466,14 +269,6 @@ type TestFoo struct { Field int64 `json:"field" msgpack:"field"` } -var _TestFooRequiredFields map[string]struct{} - -func init() { - _TestFooRequiredFields = map[string]struct{}{ - "field": {}, - } -} - func (s *TestFoo) Reset() { s.Field = 0 @@ -493,53 +288,6 @@ func (s *TestFoo) Read(reader meta.Reader) error { func (s *TestFoo) ReadFields(reader meta.Reader) error { s.Reset() - readAsMap, err := reader.IsContainerAssoc() - if err != nil { - return err - } - if readAsMap { - return s.readFieldsAssociative(reader) - } - return s.readFields(reader) -} - -func (s *TestFoo) readFieldsAssociative(reader meta.Reader) error { - size, err := reader.ContainerSize() - if err != nil { - return err - } - - readFields := make(map[string]struct{}, 1) - for ; size > 0; size-- { - var field string - if err := reader.ReadString(&field, ""); err != nil { - return err - } - - switch field { - case "field": - if err := reader.ReadInt64(&s.Field, "field"); err != nil { - return err - } - - default: - return errors.Errorf("unexpected field `%s`", field) - } - - readFields[field] = struct{}{} - } - - for field := range _TestFooRequiredFields { - if _, ok := readFields[field]; !ok { - return errors.Errorf("field `%s` is not present", field) - } - } - - return nil -} - -func (s *TestFoo) readFields(reader meta.Reader) error { - contSize, err := reader.ContainerSize() if err != nil { return err