meta/msgpack_reader.go

236 lines
4.8 KiB
Go

package meta
import (
"io"
"github.com/pkg/errors"
"github.com/vmihailenco/msgpack/v5"
"github.com/vmihailenco/msgpack/v5/msgpcode"
)
type (
	// msgpackReader decodes values from a MessagePack stream and keeps a stack
	// of the maps/arrays that are currently open (innermost last).
	msgpackReader struct {
		dec        *msgpack.Decoder
		containers []msgpackContainer
	}

	// msgpackContainer describes one open map or array.
	msgpackContainer struct {
		length int  // element count (pair count for maps) read from the header
		assoc  bool // true for a map, false for an array
	}
)
// NewMsgpackReader returns a Reader that decodes MessagePack data read from r.
func NewMsgpackReader(r io.Reader) Reader {
	rd := new(msgpackReader)
	rd.dec = msgpack.NewDecoder(r)
	// Start with room for two nested containers; deeper nesting grows the slice.
	rd.containers = make([]msgpackContainer, 0, 2)
	return rd
}
// currentContainer returns the innermost open container. If the reader is nil
// or no container is open, it returns the zero msgpackContainer
// (length 0, not associative).
func (rd *msgpackReader) currentContainer() msgpackContainer {
	if rd != nil {
		if depth := len(rd.containers); depth > 0 {
			return rd.containers[depth-1]
		}
	}
	return msgpackContainer{}
}
// ReadInt8 decodes the next value as an int8 and stores it in *v.
// On failure *v is left unchanged. The returned error names field and keeps
// the decoder error as its cause.
func (rd *msgpackReader) ReadInt8(v *int8, field string) error {
	tmp, err := rd.dec.DecodeInt8()
	if err != nil {
		return errors.Wrapf(err, "decoding int8 field `%s`", field)
	}
	*v = tmp
	return nil
}
// ReadInt16 decodes the next value as an int16 and stores it in *v.
// On failure *v is left unchanged. The returned error names field and keeps
// the decoder error as its cause.
func (rd *msgpackReader) ReadInt16(v *int16, field string) error {
	tmp, err := rd.dec.DecodeInt16()
	if err != nil {
		return errors.Wrapf(err, "decoding int16 field `%s`", field)
	}
	*v = tmp
	return nil
}
// ReadInt32 decodes the next value as an int32 and stores it in *v.
// On failure *v is left unchanged. The returned error names field and keeps
// the decoder error as its cause.
func (rd *msgpackReader) ReadInt32(v *int32, field string) error {
	tmp, err := rd.dec.DecodeInt32()
	if err != nil {
		return errors.Wrapf(err, "decoding int32 field `%s`", field)
	}
	*v = tmp
	return nil
}
// ReadInt64 decodes the next value as an int64 and stores it in *v.
// On failure *v is left unchanged. The returned error names field and keeps
// the decoder error as its cause.
func (rd *msgpackReader) ReadInt64(v *int64, field string) error {
	tmp, err := rd.dec.DecodeInt64()
	if err != nil {
		return errors.Wrapf(err, "decoding int64 field `%s`", field)
	}
	*v = tmp
	return nil
}
// ReadUint8 decodes the next value as a uint8 and stores it in *v.
// On failure *v is left unchanged. The returned error names field and keeps
// the decoder error as its cause.
func (rd *msgpackReader) ReadUint8(v *uint8, field string) error {
	tmp, err := rd.dec.DecodeUint8()
	if err != nil {
		return errors.Wrapf(err, "decoding uint8 field `%s`", field)
	}
	*v = tmp
	return nil
}
// ReadUint16 decodes the next value as a uint16 and stores it in *v.
// On failure *v is left unchanged. The returned error names field and keeps
// the decoder error as its cause.
func (rd *msgpackReader) ReadUint16(v *uint16, field string) error {
	tmp, err := rd.dec.DecodeUint16()
	if err != nil {
		return errors.Wrapf(err, "decoding uint16 field `%s`", field)
	}
	*v = tmp
	return nil
}
// ReadUint32 decodes the next value as a uint32 and stores it in *v.
// On failure *v is left unchanged. The returned error names field and keeps
// the decoder error as its cause.
func (rd *msgpackReader) ReadUint32(v *uint32, field string) error {
	tmp, err := rd.dec.DecodeUint32()
	if err != nil {
		return errors.Wrapf(err, "decoding uint32 field `%s`", field)
	}
	*v = tmp
	return nil
}
// ReadUint64 decodes the next value as a uint64 and stores it in *v.
// On failure *v is left unchanged. The returned error names field and keeps
// the decoder error as its cause.
func (rd *msgpackReader) ReadUint64(v *uint64, field string) error {
	tmp, err := rd.dec.DecodeUint64()
	if err != nil {
		return errors.Wrapf(err, "decoding uint64 field `%s`", field)
	}
	*v = tmp
	return nil
}
// ReadBool decodes the next value as a bool and stores it in *v.
// On failure *v is left unchanged. The returned error names field and keeps
// the decoder error as its cause.
func (rd *msgpackReader) ReadBool(v *bool, field string) error {
	tmp, err := rd.dec.DecodeBool()
	if err != nil {
		return errors.Wrapf(err, "decoding bool field `%s`", field)
	}
	*v = tmp
	return nil
}
// ReadFloat32 decodes the next value as a float32 and stores it in *v.
// On failure *v is left unchanged. The returned error names field and keeps
// the decoder error as its cause.
func (rd *msgpackReader) ReadFloat32(v *float32, field string) error {
	tmp, err := rd.dec.DecodeFloat32()
	if err != nil {
		return errors.Wrapf(err, "decoding float32 field `%s`", field)
	}
	*v = tmp
	return nil
}
// ReadFloat64 decodes the next value as a float64 and stores it in *v.
// On failure *v is left unchanged. The returned error names field and keeps
// the decoder error as its cause.
func (rd *msgpackReader) ReadFloat64(v *float64, field string) error {
	tmp, err := rd.dec.DecodeFloat64()
	if err != nil {
		return errors.Wrapf(err, "decoding float64 field `%s`", field)
	}
	*v = tmp
	return nil
}
// ReadString decodes the next value as a string and stores it in *v.
// On failure *v is left unchanged. The returned error names field and keeps
// the decoder error as its cause.
func (rd *msgpackReader) ReadString(v *string, field string) error {
	tmp, err := rd.dec.DecodeString()
	if err != nil {
		return errors.Wrapf(err, "decoding string field `%s`", field)
	}
	*v = tmp
	return nil
}
// ReadBytes decodes the next value as a byte slice and stores it in *v.
// On failure *v is left unchanged. The returned error names field and keeps
// the decoder error as its cause.
func (rd *msgpackReader) ReadBytes(v *[]byte, field string) error {
	tmp, err := rd.dec.DecodeBytes()
	if err != nil {
		return errors.Wrapf(err, "decoding bytes field `%s`", field)
	}
	*v = tmp
	return nil
}
// BeginContainer consumes the header of the next map or array and pushes its
// length onto the container stack. Maps are recorded as associative, arrays
// as not. If the next value is neither a map nor an array, an error naming
// field is returned and nothing is pushed. Decoder errors are wrapped with
// the field name.
func (rd *msgpackReader) BeginContainer(field string) error {
	code, err := rd.dec.PeekCode()
	if err != nil {
		return errors.Wrapf(err, "peeking container for field `%s`", field)
	}

	var c msgpackContainer
	switch {
	case msgpcode.IsFixedMap(code), code == msgpcode.Map16, code == msgpcode.Map32:
		c.assoc = true
		c.length, err = rd.dec.DecodeMapLen()
	case msgpcode.IsFixedArray(code), code == msgpcode.Array16, code == msgpcode.Array32:
		c.length, err = rd.dec.DecodeArrayLen()
	default:
		return errors.Errorf("there is no container for field `%s`", field)
	}
	if err != nil {
		return errors.Wrapf(err, "decoding container length for field `%s`", field)
	}

	rd.containers = append(rd.containers, c)
	return nil
}
// EndContainer pops the innermost open container. It returns an error if no
// container is open.
func (rd *msgpackReader) EndContainer() error {
	depth := len(rd.containers)
	if depth < 1 {
		return errors.New("there is no open containers")
	}
	rd.containers = rd.containers[:depth-1]
	return nil
}
// ContainerSize reports the length recorded for the innermost open container.
// It returns 0 when no container is open, and it never returns an error.
func (rd *msgpackReader) ContainerSize() (int, error) {
	top := rd.currentContainer()
	return top.length, nil
}
// IsContainerAssoc reports whether the innermost open container is a map.
// It returns false when no container is open, and it never returns an error.
func (rd *msgpackReader) IsContainerAssoc() (bool, error) {
	top := rd.currentContainer()
	return top.assoc, nil
}
// Skip discards the next encoded value using the decoder's Skip.
func (rd *msgpackReader) Skip() error {
	if err := rd.dec.Skip(); err != nil {
		return errors.WithStack(err)
	}
	return nil
}
// TryReadMask checks whether a fields mask comes next in the stream.
//
// A mask is encoded as a nil marker followed by an array. Each array element
// is decoded as a uint64 and stored as mask part i. If the next value is not
// nil, the function only peeks: it reports false and consumes nothing. On any
// decode error the zero mask is returned.
func (rd *msgpackReader) TryReadMask() (bool, FieldsMask, error) {
	next, err := rd.dec.PeekCode()
	switch {
	case err != nil:
		return false, FieldsMask{}, errors.WithStack(err)
	case next != msgpcode.Nil:
		return false, FieldsMask{}, nil
	}

	// Consume the nil marker that announces the mask.
	if err = rd.dec.Skip(); err != nil {
		return false, FieldsMask{}, errors.WithStack(err)
	}

	var mask FieldsMask
	parts, err := rd.dec.DecodeArrayLen()
	if err != nil {
		return false, FieldsMask{}, errors.WithStack(err)
	}
	for idx := 0; idx < parts; idx++ {
		word, decErr := rd.dec.DecodeUint64()
		if decErr != nil {
			return false, FieldsMask{}, errors.WithStack(decErr)
		}
		mask.SetPartFromUint64(idx, word)
	}
	return true, mask, nil
}