package meta import ( "bytes" "io" "github.com/pkg/errors" "github.com/vmihailenco/msgpack/v5" "github.com/vmihailenco/msgpack/v5/msgpcode" ) var FieldNotFound = errors.New("field not found") var NoOpenContainer = errors.New("there is no open container") type msgpackReader struct { dec *msgpack.Decoder stack []readContainer curr readContainer } type readContainer struct { started bool length int assoc bool values map[string]msgpack.RawMessage reader io.Reader readCnt int } func NewMsgpackReader(r io.Reader) Reader { return &msgpackReader{ dec: msgpack.NewDecoder(r), stack: make([]readContainer, 0, 2), curr: readContainer{ reader: r, }, } } func (rd *msgpackReader) readField() (string, error) { field, err := rd.dec.DecodeString() if err != nil { return "", errors.WithStack(err) } return field, nil } func (rd *msgpackReader) ReadInt8(v *int8, targetField string) error { if !rd.curr.started || !rd.curr.assoc { return decodeInt8(rd.dec, v) } if b, ok := rd.curr.values[targetField]; ok { dec := msgpack.NewDecoder(bytes.NewReader(b)) return decodeInt8(dec, v) } for i := rd.curr.readCnt; i < rd.curr.length; i++ { field, err := rd.readField() if err != nil { return err } if field == targetField { rd.curr.readCnt = i + 1 return decodeInt8(rd.dec, v) } raw, err := rd.dec.DecodeRaw() if err != nil { return errors.WithStack(err) } rd.curr.values[field] = raw } return FieldNotFound } func (rd *msgpackReader) ReadInt16(v *int16, targetField string) error { if !rd.curr.started || !rd.curr.assoc { return decodeInt16(rd.dec, v) } if b, ok := rd.curr.values[targetField]; ok { dec := msgpack.NewDecoder(bytes.NewReader(b)) return decodeInt16(dec, v) } for i := rd.curr.readCnt; i < rd.curr.length; i++ { field, err := rd.readField() if err != nil { return err } if field == targetField { rd.curr.readCnt = i + 1 return decodeInt16(rd.dec, v) } raw, err := rd.dec.DecodeRaw() if err != nil { return errors.WithStack(err) } rd.curr.values[field] = raw } return FieldNotFound } func (rd *msgpackReader) ReadInt32(v *int32, targetField string) error { if !rd.curr.started || !rd.curr.assoc { return decodeInt32(rd.dec, v) } if b, ok := rd.curr.values[targetField]; ok { dec := msgpack.NewDecoder(bytes.NewReader(b)) return decodeInt32(dec, v) } for i := rd.curr.readCnt; i < rd.curr.length; i++ { field, err := rd.readField() if err != nil { return err } if field == targetField { rd.curr.readCnt = i + 1 return decodeInt32(rd.dec, v) } raw, err := rd.dec.DecodeRaw() if err != nil { return errors.WithStack(err) } rd.curr.values[field] = raw } return FieldNotFound } func (rd *msgpackReader) ReadInt64(v *int64, targetField string) error { if !rd.curr.started || !rd.curr.assoc { return decodeInt64(rd.dec, v) } if b, ok := rd.curr.values[targetField]; ok { dec := msgpack.NewDecoder(bytes.NewReader(b)) return decodeInt64(dec, v) } for i := rd.curr.readCnt; i < rd.curr.length; i++ { field, err := rd.readField() if err != nil { return err } if field == targetField { rd.curr.readCnt = i + 1 return decodeInt64(rd.dec, v) } raw, err := rd.dec.DecodeRaw() if err != nil { return errors.WithStack(err) } rd.curr.values[field] = raw } return FieldNotFound } func (rd *msgpackReader) ReadUint8(v *uint8, targetField string) error { if !rd.curr.started || !rd.curr.assoc { return decodeUint8(rd.dec, v) } if b, ok := rd.curr.values[targetField]; ok { dec := msgpack.NewDecoder(bytes.NewReader(b)) return decodeUint8(dec, v) } for i := rd.curr.readCnt; i < rd.curr.length; i++ { field, err := rd.readField() if err != nil { return err } if field == targetField { rd.curr.readCnt = i + 1 return decodeUint8(rd.dec, v) } raw, err := rd.dec.DecodeRaw() if err != nil { return errors.WithStack(err) } rd.curr.values[field] = raw } return FieldNotFound } func (rd *msgpackReader) ReadUint16(v *uint16, targetField string) error { if !rd.curr.started || !rd.curr.assoc { return decodeUint16(rd.dec, v) } if b, ok := rd.curr.values[targetField]; ok { dec := msgpack.NewDecoder(bytes.NewReader(b)) return decodeUint16(dec, v) } for i := rd.curr.readCnt; i < rd.curr.length; i++ { field, err := rd.readField() if err != nil { return err } if field == targetField { rd.curr.readCnt = i + 1 return decodeUint16(rd.dec, v) } raw, err := rd.dec.DecodeRaw() if err != nil { return errors.WithStack(err) } rd.curr.values[field] = raw } return FieldNotFound } func (rd *msgpackReader) ReadUint32(v *uint32, targetField string) error { if !rd.curr.started || !rd.curr.assoc { return decodeUint32(rd.dec, v) } if b, ok := rd.curr.values[targetField]; ok { dec := msgpack.NewDecoder(bytes.NewReader(b)) return decodeUint32(dec, v) } for i := rd.curr.readCnt; i < rd.curr.length; i++ { field, err := rd.readField() if err != nil { return err } if field == targetField { rd.curr.readCnt = i + 1 return decodeUint32(rd.dec, v) } raw, err := rd.dec.DecodeRaw() if err != nil { return errors.WithStack(err) } rd.curr.values[field] = raw } return FieldNotFound } func (rd *msgpackReader) ReadUint64(v *uint64, targetField string) error { if !rd.curr.started || !rd.curr.assoc { return decodeUint64(rd.dec, v) } if b, ok := rd.curr.values[targetField]; ok { dec := msgpack.NewDecoder(bytes.NewReader(b)) return decodeUint64(dec, v) } for i := rd.curr.readCnt; i < rd.curr.length; i++ { field, err := rd.readField() if err != nil { return err } if field == targetField { rd.curr.readCnt = i + 1 return decodeUint64(rd.dec, v) } raw, err := rd.dec.DecodeRaw() if err != nil { return errors.WithStack(err) } rd.curr.values[field] = raw } return FieldNotFound } func (rd *msgpackReader) ReadBool(v *bool, targetField string) error { if !rd.curr.started || !rd.curr.assoc { return decodeBool(rd.dec, v) } if b, ok := rd.curr.values[targetField]; ok { dec := msgpack.NewDecoder(bytes.NewReader(b)) return decodeBool(dec, v) } for i := rd.curr.readCnt; i < rd.curr.length; i++ { field, err := rd.readField() if err != nil { return err } if field == targetField { rd.curr.readCnt = i + 1 return decodeBool(rd.dec, v) } raw, err := rd.dec.DecodeRaw() if err != nil { return errors.WithStack(err) } rd.curr.values[field] = raw } return FieldNotFound } func (rd *msgpackReader) ReadFloat32(v *float32, targetField string) error { if !rd.curr.started || !rd.curr.assoc { return decodeFloat32(rd.dec, v) } if b, ok := rd.curr.values[targetField]; ok { dec := msgpack.NewDecoder(bytes.NewReader(b)) return decodeFloat32(dec, v) } for i := rd.curr.readCnt; i < rd.curr.length; i++ { field, err := rd.readField() if err != nil { return err } if field == targetField { rd.curr.readCnt = i + 1 return decodeFloat32(rd.dec, v) } raw, err := rd.dec.DecodeRaw() if err != nil { return errors.WithStack(err) } rd.curr.values[field] = raw } return FieldNotFound } func (rd *msgpackReader) ReadFloat64(v *float64, targetField string) error { if !rd.curr.started || !rd.curr.assoc { return decodeFloat64(rd.dec, v) } if b, ok := rd.curr.values[targetField]; ok { dec := msgpack.NewDecoder(bytes.NewReader(b)) return decodeFloat64(dec, v) } for i := rd.curr.readCnt; i < rd.curr.length; i++ { field, err := rd.readField() if err != nil { return err } if field == targetField { rd.curr.readCnt = i + 1 return decodeFloat64(rd.dec, v) } raw, err := rd.dec.DecodeRaw() if err != nil { return errors.WithStack(err) } rd.curr.values[field] = raw } return FieldNotFound } func (rd *msgpackReader) ReadString(v *string, targetField string) error { if !rd.curr.started || !rd.curr.assoc { return decodeString(rd.dec, v) } if b, ok := rd.curr.values[targetField]; ok { dec := msgpack.NewDecoder(bytes.NewReader(b)) return decodeString(dec, v) } for i := rd.curr.readCnt; i < rd.curr.length; i++ { field, err := rd.readField() if err != nil { return err } if field == targetField { rd.curr.readCnt = i + 1 return decodeString(rd.dec, v) } raw, err := rd.dec.DecodeRaw() if err != nil { return errors.WithStack(err) } rd.curr.values[field] = raw } return FieldNotFound } func (rd *msgpackReader) ReadBytes(v *[]byte, targetField string) error { if !rd.curr.started || !rd.curr.assoc { return decodeBytes(rd.dec, v) } if b, ok := rd.curr.values[targetField]; ok { dec := msgpack.NewDecoder(bytes.NewReader(b)) return decodeBytes(dec, v) } for i := rd.curr.readCnt; i < rd.curr.length; i++ { field, err := rd.readField() if err != nil { return err } if field == targetField { rd.curr.readCnt = i + 1 return decodeBytes(rd.dec, v) } raw, err := rd.dec.DecodeRaw() if err != nil { return errors.WithStack(err) } rd.curr.values[field] = raw } return FieldNotFound } func (rd *msgpackReader) BeginContainer(targetField string) error { if !rd.curr.started || !rd.curr.assoc { return rd.beginContainer(targetField) } if b, ok := rd.curr.values[targetField]; ok { rd.dec.Reset(bytes.NewReader(b)) return rd.beginContainer(targetField) } for i := rd.curr.readCnt; i < rd.curr.length; i++ { field, err := rd.readField() if err != nil { return err } if field == targetField { rd.curr.readCnt = i + 1 return rd.beginContainer(targetField) } raw, err := rd.dec.DecodeRaw() if err != nil { return errors.WithStack(err) } rd.curr.values[field] = raw rd.curr.readCnt = i + 1 } return FieldNotFound } func (rd *msgpackReader) beginContainer(field string) error { code, err := rd.dec.PeekCode() if err != nil { return errors.WithStack(err) } switch { case msgpcode.IsFixedMap(code), code == msgpcode.Map16, code == msgpcode.Map32: l, err := rd.dec.DecodeMapLen() if err != nil { return errors.WithStack(err) } rd.stack = append(rd.stack, rd.curr) rd.curr = readContainer{ started: true, length: l, assoc: true, values: make(map[string]msgpack.RawMessage, l), reader: rd.dec.Buffered(), } case msgpcode.IsFixedArray(code), code == msgpcode.Array16, code == msgpcode.Array32: l, err := rd.dec.DecodeArrayLen() if err != nil { return errors.WithStack(err) } rd.stack = append(rd.stack, rd.curr) rd.curr = readContainer{ started: true, length: l, assoc: false, reader: rd.dec.Buffered(), } default: return errors.Errorf("there is no container for field `%s`", field) } return nil } func (rd *msgpackReader) EndContainer() error { if len(rd.stack) == 0 { return NoOpenContainer } rd.curr = rd.stack[len(rd.stack)-1] rd.stack = rd.stack[:len(rd.stack)-1] rd.dec.Reset(rd.curr.reader) return nil } func (rd *msgpackReader) ContainerSize() (int, error) { return rd.curr.length, nil } func (rd *msgpackReader) IsContainerAssoc() (bool, error) { return rd.curr.assoc, nil } func (rd *msgpackReader) Skip() error { return errors.WithStack(rd.dec.Skip()) } func (rd *msgpackReader) TryReadMask() (bool, FieldsMask, error) { maskLen, err := rd.dec.DecodeArrayLen() if err != nil { if err == io.EOF { return false, FieldsMask{}, nil } return false, FieldsMask{}, errors.WithStack(err) } var mask FieldsMask for i := 0; i < maskLen; i++ { maskPart, err := rd.dec.DecodeUint64() if err != nil { return false, FieldsMask{}, errors.WithStack(err) } mask.SetPartFromUint64(i, maskPart) } return true, mask, nil } func decodeUint8(dec *msgpack.Decoder, v *uint8) error { tmp, err := dec.DecodeUint8() if err != nil { return errors.WithStack(err) } *v = tmp return nil } func decodeUint16(dec *msgpack.Decoder, v *uint16) error { tmp, err := dec.DecodeUint16() if err != nil { return errors.WithStack(err) } *v = tmp return nil } func decodeUint32(dec *msgpack.Decoder, v *uint32) error { tmp, err := dec.DecodeUint32() if err != nil { return errors.WithStack(err) } *v = tmp return nil } func decodeUint64(dec *msgpack.Decoder, v *uint64) error { tmp, err := dec.DecodeUint64() if err != nil { return errors.WithStack(err) } *v = tmp return nil } func decodeInt8(dec *msgpack.Decoder, v *int8) error { tmp, err := dec.DecodeInt8() if err != nil { return errors.WithStack(err) } *v = tmp return nil } func decodeInt16(dec *msgpack.Decoder, v *int16) error { tmp, err := dec.DecodeInt16() if err != nil { return errors.WithStack(err) } *v = tmp return nil } func decodeInt32(dec *msgpack.Decoder, v *int32) error { tmp, err := dec.DecodeInt32() if err != nil { return errors.WithStack(err) } *v = tmp return nil } func decodeInt64(dec *msgpack.Decoder, v *int64) error { tmp, err := dec.DecodeInt64() if err != nil { return errors.WithStack(err) } *v = tmp return nil } func decodeBool(dec *msgpack.Decoder, v *bool) error { tmp, err := dec.DecodeBool() if err != nil { return errors.WithStack(err) } *v = tmp return nil } func decodeFloat32(dec *msgpack.Decoder, v *float32) error { tmp, err := dec.DecodeFloat32() if err != nil { return errors.WithStack(err) } *v = tmp return nil } func decodeFloat64(dec *msgpack.Decoder, v *float64) error { tmp, err := dec.DecodeFloat64() if err != nil { return errors.WithStack(err) } *v = tmp return nil } func decodeString(dec *msgpack.Decoder, v *string) error { tmp, err := dec.DecodeString() if err != nil { return errors.WithStack(err) } *v = tmp return nil } func decodeBytes(dec *msgpack.Decoder, v *[]byte) error { tmp, err := dec.DecodeBytes() if err != nil { return errors.WithStack(err) } *v = tmp return nil }