msgpack/decode.go

808 lines
16 KiB
Go

package msgpack
import (
"errors"
"fmt"
"io"
"math"
"reflect"
"time"
"github.com/vmihailenco/bufio"
)
// bufReader is the reader contract the Decoder depends on. On top of a plain
// io.Reader-style Read it needs single-byte reads with one byte of pushback
// (ReadByte/UnreadByte), non-consuming lookahead (Peek) and ReadN, whose
// result callers index up to n-1 without a length check — so ReadN is
// expected to return exactly n bytes on success.
type bufReader interface {
	// Single-byte access with one byte of pushback.
	ReadByte() (byte, error)
	UnreadByte() error

	// Lookahead and fixed-size reads.
	Peek(int) ([]byte, error)
	ReadN(int) ([]byte, error)

	// Plain stream read, as in io.Reader.
	Read([]byte) (int, error)
}
func Unmarshal(data []byte, v ...interface{}) error {
buf := bufio.NewBuffer(data)
return NewDecoder(buf).Decode(v...)
}
// Decoder reads MessagePack-encoded values from a buffered reader.
type Decoder struct {
	// R is the source all decoding reads from. NewDecoder sets it to the
	// supplied reader when that already satisfies bufReader, and otherwise to a
	// bufio.NewReader wrapper around it.
	R bufReader
	// DecodeMapFunc is a pluggable map-decoding hook; NewDecoder initializes it
	// to decodeMap. NOTE(review): its call site is not visible here —
	// presumably DecodeMap dispatches through it; confirm.
	DecodeMapFunc func(*Decoder) (interface{}, error)
}
// NewDecoder returns a Decoder that reads from rd. If rd already provides the
// bufReader methods it is used as-is; otherwise it is wrapped in a
// bufio.NewReader. Maps are decoded with decodeMap by default.
func NewDecoder(rd io.Reader) *Decoder {
	d := &Decoder{DecodeMapFunc: decodeMap}
	if buffered, ok := rd.(bufReader); ok {
		d.R = buffered
	} else {
		d.R = bufio.NewReader(rd)
	}
	return d
}
// Decode reads consecutive values from the stream into each element of v, in
// order, and returns the first error encountered (later targets are left
// untouched).
func (d *Decoder) Decode(v ...interface{}) error {
	for i := 0; i < len(v); i++ {
		if err := d.decode(v[i]); err != nil {
			return err
		}
	}
	return nil
}
// decode decodes the next value into iv, which must be a non-nil pointer.
//
// Frequently used pointer types are handled by a type switch without
// reflection. Anything else goes to DecodeValue, and so does a nil pointer of
// a switched type.
//
// A nil interface, a non-pointer, or a nil pointer is rejected with an error.
// A nil pointer has nowhere to store the result, and DecodeValue would
// otherwise panic trying to allocate through an unaddressable nil pointer.
func (d *Decoder) decode(iv interface{}) error {
	var err error
	switch v := iv.(type) {
	case *string:
		if v != nil {
			*v, err = d.DecodeString()
			return err
		}
	case *[]byte:
		if v != nil {
			*v, err = d.DecodeBytes()
			return err
		}
	case *int:
		if v != nil {
			*v, err = d.DecodeInt()
			return err
		}
	case *int8:
		if v != nil {
			*v, err = d.DecodeInt8()
			return err
		}
	case *int16:
		if v != nil {
			*v, err = d.DecodeInt16()
			return err
		}
	case *int32:
		if v != nil {
			*v, err = d.DecodeInt32()
			return err
		}
	case *int64:
		if v != nil {
			*v, err = d.DecodeInt64()
			return err
		}
	case *uint:
		if v != nil {
			*v, err = d.DecodeUint()
			return err
		}
	case *uint8:
		if v != nil {
			*v, err = d.DecodeUint8()
			return err
		}
	case *uint16:
		if v != nil {
			*v, err = d.DecodeUint16()
			return err
		}
	case *uint32:
		if v != nil {
			*v, err = d.DecodeUint32()
			return err
		}
	case *uint64:
		if v != nil {
			*v, err = d.DecodeUint64()
			return err
		}
	case *bool:
		if v != nil {
			*v, err = d.DecodeBool()
			return err
		}
	case *float32:
		if v != nil {
			*v, err = d.DecodeFloat32()
			return err
		}
	case *float64:
		if v != nil {
			*v, err = d.DecodeFloat64()
			return err
		}
	case *[]string:
		return d.decodeIntoStrings(v)
	case *map[string]string:
		return d.decodeIntoMapStringString(v)
	case *time.Duration:
		if v != nil {
			// Durations travel as their int64 nanosecond count.
			var n int64
			n, err = d.DecodeInt64()
			*v = time.Duration(n)
			return err
		}
	case *time.Time:
		if v != nil {
			*v, err = d.DecodeTime()
			return err
		}
	}

	v := reflect.ValueOf(iv)
	if !v.IsValid() {
		return errors.New("msgpack: Decode(" + v.String() + ")")
	}
	if v.Kind() != reflect.Ptr {
		return errors.New("msgpack: pointer expected")
	}
	if v.IsNil() {
		// DecodeValue would try to v.Set a fresh allocation into this
		// unaddressable nil pointer and panic; report it instead.
		return errors.New("msgpack: Decode(nil " + v.Type().String() + ")")
	}
	return d.DecodeValue(v)
}
// DecodeValue decodes the next MessagePack value into v.
//
// A MessagePack nil is consumed and leaves v unchanged. Otherwise the code
// byte is pushed back and decoding dispatches on v's kind. Struct and pointer
// kinds first consult typDecMap and then the decoder interface, and only then
// fall back to generic decoding. A nil pointer is allocated before decoding
// through it, which requires v to be settable.
func (d *Decoder) DecodeValue(v reflect.Value) error {
	c, err := d.R.ReadByte()
	if err != nil {
		return err
	}
	if c == nilCode {
		return nil
	}
	if err := d.R.UnreadByte(); err != nil {
		return err
	}

	switch v.Kind() {
	case reflect.Bool:
		return d.boolValue(v)
	case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
		return d.uint64Value(v)
	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
		return d.int64Value(v)
	case reflect.Float32:
		return d.float32Value(v)
	case reflect.Float64:
		return d.float64Value(v)
	case reflect.String:
		return d.stringValue(v)
	case reflect.Array, reflect.Slice:
		return d.sliceValue(v)
	case reflect.Map:
		return d.mapValue(v)
	case reflect.Struct:
		if dec, ok := typDecMap[v.Type()]; ok {
			return dec(d, v)
		}
		// Probe for the decoder interface through a pointer when v is
		// addressable. Otherwise pointer-receiver DecodeMsgpack methods are
		// missed, and value-receiver ones would only write into a copy.
		target := v
		if v.CanAddr() {
			target = v.Addr()
		}
		if dec, ok := target.Interface().(decoder); ok {
			return dec.DecodeMsgpack(d.R)
		}
		return d.structValue(v)
	case reflect.Ptr:
		typ := v.Type()
		if v.IsNil() {
			v.Set(reflect.New(typ.Elem()))
		}
		if dec, ok := typDecMap[typ]; ok {
			return dec(d, v)
		}
		if dec, ok := v.Interface().(decoder); ok {
			return dec.DecodeMsgpack(d.R)
		}
		return d.DecodeValue(v.Elem())
	case reflect.Interface:
		// The content of an interface is not addressable. Decode in place
		// only through a non-nil pointer it already holds; in every other case
		// replace the interface's content with a freshly decoded value rather
		// than panicking on Set.
		if !v.IsNil() {
			if elem := v.Elem(); elem.Kind() == reflect.Ptr && !elem.IsNil() {
				return d.DecodeValue(elem)
			}
		}
		return d.interfaceValue(v)
	}
	return fmt.Errorf("msgpack: unsupported type %v", v.Type().String())
}
// DecodeBool reads a MessagePack boolean. Any code other than the false/true
// codes is reported as an error.
func (d *Decoder) DecodeBool() (bool, error) {
	code, err := d.R.ReadByte()
	if err != nil {
		return false, err
	}
	if code != falseCode && code != trueCode {
		return false, fmt.Errorf("msgpack: invalid code %x decoding bool", code)
	}
	return code == trueCode, nil
}
// boolValue decodes a boolean and, on success, stores it into value.
func (d *Decoder) boolValue(value reflect.Value) error {
	b, err := d.DecodeBool()
	if err == nil {
		value.SetBool(b)
	}
	return err
}
// uint16 reads two bytes and combines them big-endian (most significant byte
// first).
func (d *Decoder) uint16() (uint16, error) {
	buf, err := d.R.ReadN(2)
	if err != nil {
		return 0, err
	}
	hi, lo := uint16(buf[0]), uint16(buf[1])
	return hi<<8 | lo, nil
}
// uint32 reads four bytes and combines them big-endian (most significant byte
// first).
func (d *Decoder) uint32() (uint32, error) {
	buf, err := d.R.ReadN(4)
	if err != nil {
		return 0, err
	}
	var n uint32
	for i := 0; i < 4; i++ {
		n = n<<8 | uint32(buf[i])
	}
	return n, nil
}
// uint64 reads eight bytes and combines them big-endian (most significant byte
// first).
func (d *Decoder) uint64() (uint64, error) {
	buf, err := d.R.ReadN(8)
	if err != nil {
		return 0, err
	}
	var n uint64
	for i := 0; i < 8; i++ {
		n = n<<8 | uint64(buf[i])
	}
	return n, nil
}
// DecodeUint64 reads an unsigned integer encoded as a positive fixnum or as a
// uint8/uint16/uint32/uint64 payload. Signed encodings (including negative
// fixnums) are rejected with an error.
func (d *Decoder) DecodeUint64() (uint64, error) {
	code, err := d.R.ReadByte()
	if err != nil {
		return 0, err
	}
	switch {
	case code <= posFixNumHighCode:
		// Positive fixnum: the value is the code byte itself.
		return uint64(code), nil
	case code == uint8Code:
		b, err := d.R.ReadByte()
		if err != nil {
			return 0, err
		}
		return uint64(b), nil
	case code == uint16Code:
		n, err := d.uint16()
		return uint64(n), err
	case code == uint32Code:
		n, err := d.uint32()
		return uint64(n), err
	case code == uint64Code:
		return d.uint64()
	}
	return 0, fmt.Errorf("msgpack: invalid code %x decoding uint64", code)
}
// uint64Value decodes an unsigned integer and, on success, stores it into
// value with SetUint.
func (d *Decoder) uint64Value(value reflect.Value) error {
	n, err := d.DecodeUint64()
	if err == nil {
		value.SetUint(n)
	}
	return err
}
// DecodeInt64 reads a signed integer encoded as a positive or negative fixnum
// or as an int8/int16/int32/int64 payload. Unsigned encodings are rejected
// with an error.
func (d *Decoder) DecodeInt64() (int64, error) {
	code, err := d.R.ReadByte()
	if err != nil {
		return 0, err
	}
	switch {
	case code <= posFixNumHighCode, code >= negFixNumLowCode:
		// Fixnums carry the value in the code byte; reading it as int8
		// sign-extends the negative range.
		return int64(int8(code)), nil
	case code == int8Code:
		b, err := d.R.ReadByte()
		if err != nil {
			return 0, err
		}
		return int64(int8(b)), nil
	case code == int16Code:
		n, err := d.uint16()
		return int64(int16(n)), err
	case code == int32Code:
		n, err := d.uint32()
		return int64(int32(n)), err
	case code == int64Code:
		n, err := d.uint64()
		return int64(n), err
	}
	return 0, fmt.Errorf("msgpack: invalid code %x decoding int64", code)
}
// int64Value decodes a signed integer and, on success, stores it into value
// with SetInt.
func (d *Decoder) int64Value(value reflect.Value) error {
	n, err := d.DecodeInt64()
	if err == nil {
		value.SetInt(n)
	}
	return err
}
// DecodeFloat32 reads a value tagged with floatCode and reinterprets the
// following four big-endian bytes as an IEEE-754 single-precision float.
func (d *Decoder) DecodeFloat32() (float32, error) {
	code, err := d.R.ReadByte()
	switch {
	case err != nil:
		return 0, err
	case code != floatCode:
		return 0, fmt.Errorf("msgpack: invalid code %x decoding float32", code)
	}
	bits, err := d.uint32()
	if err != nil {
		return 0, err
	}
	return math.Float32frombits(bits), nil
}
// float32Value decodes a float32 and, on success, stores it (widened) into
// value with SetFloat.
func (d *Decoder) float32Value(value reflect.Value) error {
	f, err := d.DecodeFloat32()
	if err == nil {
		value.SetFloat(float64(f))
	}
	return err
}
// DecodeFloat64 reads a value tagged with doubleCode and reinterprets the
// following eight big-endian bytes as an IEEE-754 double-precision float.
func (d *Decoder) DecodeFloat64() (float64, error) {
	code, err := d.R.ReadByte()
	switch {
	case err != nil:
		return 0, err
	case code != doubleCode:
		return 0, fmt.Errorf("msgpack: invalid code %x decoding float64", code)
	}
	bits, err := d.uint64()
	if err != nil {
		return 0, err
	}
	return math.Float64frombits(bits), nil
}
// float64Value decodes a float64 and, on success, stores it into value with
// SetFloat.
func (d *Decoder) float64Value(value reflect.Value) error {
	f, err := d.DecodeFloat64()
	if err == nil {
		value.SetFloat(f)
	}
	return err
}
// structLen reads a map header (fixmap, map16 or map32) and returns the number
// of key/value pairs that follow; structValue consumes them as
// (field name, value) pairs.
func (d *Decoder) structLen() (int, error) {
	code, err := d.R.ReadByte()
	if err != nil {
		return 0, err
	}
	switch {
	case code >= fixMapLowCode && code <= fixMapHighCode:
		// Fixmap: the pair count lives in the low bits of the code byte.
		return int(code & fixMapMask), nil
	case code == map16Code:
		n, err := d.uint16()
		return int(n), err
	case code == map32Code:
		n, err := d.uint32()
		return int(n), err
	default:
		return 0, fmt.Errorf("msgpack: invalid code %x decoding struct length", code)
	}
}
// structValue decodes a map into the struct v. Each key is read as a string
// and looked up with structs.Field. Known fields decode themselves into v.
// Values for unknown keys are decoded with DecodeInterface and discarded, so
// the stream stays in sync.
func (d *Decoder) structValue(v reflect.Value) error {
	remaining, err := d.structLen()
	if err != nil {
		return err
	}
	typ := v.Type()
	for ; remaining > 0; remaining-- {
		name, err := d.DecodeString()
		if err != nil {
			return err
		}
		if f := structs.Field(typ, name); f != nil {
			err = f.DecodeValue(d, v)
		} else {
			_, err = d.DecodeInterface()
		}
		if err != nil {
			return err
		}
	}
	return nil
}
//------------------------------------------------------------------------------
// DecodeUint reads an unsigned integer (positive fixnum or a
// uint8/uint16/uint32/uint64 payload) into a uint. A uint64 payload that does
// not fit into uint on the current platform (32-bit uint) is reported as an
// error instead of being silently truncated to its low 32 bits.
func (d *Decoder) DecodeUint() (uint, error) {
	c, err := d.R.ReadByte()
	if err != nil {
		return 0, err
	}
	if c <= posFixNumHighCode {
		return uint(c), nil
	}
	switch c {
	case uint8Code:
		b, err := d.R.ReadByte()
		if err != nil {
			return 0, err
		}
		return uint(b), nil
	case uint16Code:
		n, err := d.uint16()
		return uint(n), err
	case uint32Code:
		// uint is at least 32 bits wide, so this always fits.
		n, err := d.uint32()
		return uint(n), err
	case uint64Code:
		n, err := d.uint64()
		if err != nil {
			return 0, err
		}
		// Round-trip check: false only when uint is 32 bits and n exceeds it.
		if uint64(uint(n)) != n {
			return 0, fmt.Errorf("msgpack: value %d overflows uint", n)
		}
		return uint(n), nil
	}
	return 0, fmt.Errorf("msgpack: invalid code %x decoding uint", c)
}
// DecodeUint8 reads an unsigned integer that fits in 8 bits: a positive fixnum
// or a uint8 payload. Wider or signed encodings are rejected with an error.
func (d *Decoder) DecodeUint8() (uint8, error) {
	code, err := d.R.ReadByte()
	switch {
	case err != nil:
		return 0, err
	case code <= posFixNumHighCode:
		return code, nil
	case code != uint8Code:
		return 0, fmt.Errorf("msgpack: invalid code %x decoding uint8", code)
	}
	b, err := d.R.ReadByte()
	if err != nil {
		return 0, err
	}
	return b, nil
}
// DecodeUint16 reads an unsigned integer that fits in 16 bits: a positive
// fixnum or a uint8/uint16 payload. Wider or signed encodings are rejected
// with an error.
func (d *Decoder) DecodeUint16() (uint16, error) {
	code, err := d.R.ReadByte()
	if err != nil {
		return 0, err
	}
	switch {
	case code <= posFixNumHighCode:
		return uint16(code), nil
	case code == uint8Code:
		b, err := d.R.ReadByte()
		if err != nil {
			return 0, err
		}
		return uint16(b), nil
	case code == uint16Code:
		return d.uint16()
	}
	return 0, fmt.Errorf("msgpack: invalid code %x decoding uint16", code)
}
// DecodeUint32 reads an unsigned integer that fits in 32 bits: a positive
// fixnum or a uint8/uint16/uint32 payload. uint64 and signed encodings are
// rejected with an error.
func (d *Decoder) DecodeUint32() (uint32, error) {
	code, err := d.R.ReadByte()
	if err != nil {
		return 0, err
	}
	switch {
	case code <= posFixNumHighCode:
		return uint32(code), nil
	case code == uint8Code:
		b, err := d.R.ReadByte()
		if err != nil {
			return 0, err
		}
		return uint32(b), nil
	case code == uint16Code:
		n, err := d.uint16()
		return uint32(n), err
	case code == uint32Code:
		return d.uint32()
	}
	return 0, fmt.Errorf("msgpack: invalid code %x decoding uint32", code)
}
//------------------------------------------------------------------------------
// DecodeInt reads a signed integer (positive/negative fixnum or an
// int8/int16/int32/int64 payload) into an int. An int64 payload that does not
// fit into int on the current platform (32-bit int) is reported as an error
// instead of being silently truncated.
func (d *Decoder) DecodeInt() (int, error) {
	c, err := d.R.ReadByte()
	if err != nil {
		return 0, err
	}
	if c <= posFixNumHighCode || c >= negFixNumLowCode {
		// Fixnums carry the value in the code byte; int8 sign-extends.
		return int(int8(c)), nil
	}
	switch c {
	case int8Code:
		b, err := d.R.ReadByte()
		if err != nil {
			return 0, err
		}
		return int(int8(b)), nil
	case int16Code:
		n, err := d.uint16()
		return int(int16(n)), err
	case int32Code:
		// int is at least 32 bits wide, so this always fits.
		n, err := d.uint32()
		return int(int32(n)), err
	case int64Code:
		n, err := d.uint64()
		if err != nil {
			return 0, err
		}
		v := int64(n)
		// Round-trip check: false only when int is 32 bits and v exceeds it.
		if int64(int(v)) != v {
			return 0, fmt.Errorf("msgpack: value %d overflows int", v)
		}
		return int(v), nil
	}
	return 0, fmt.Errorf("msgpack: invalid code %x decoding int", c)
}
// DecodeInt8 reads a signed integer that fits in 8 bits: a positive or
// negative fixnum or an int8 payload. Wider or unsigned encodings are
// rejected with an error.
func (d *Decoder) DecodeInt8() (int8, error) {
	code, err := d.R.ReadByte()
	if err != nil {
		return 0, err
	}
	switch {
	case code <= posFixNumHighCode || code >= negFixNumLowCode:
		return int8(code), nil
	case code == int8Code:
		b, err := d.R.ReadByte()
		if err != nil {
			return 0, err
		}
		return int8(b), nil
	default:
		return 0, fmt.Errorf("msgpack: invalid code %x decoding int8", code)
	}
}
// DecodeInt16 reads a signed integer that fits in 16 bits: a fixnum or an
// int8/int16 payload. Wider or unsigned encodings are rejected with an error.
func (d *Decoder) DecodeInt16() (int16, error) {
	code, err := d.R.ReadByte()
	if err != nil {
		return 0, err
	}
	switch {
	case code <= posFixNumHighCode || code >= negFixNumLowCode:
		return int16(int8(code)), nil
	case code == int8Code:
		b, err := d.R.ReadByte()
		if err != nil {
			return 0, err
		}
		return int16(int8(b)), nil
	case code == int16Code:
		n, err := d.uint16()
		return int16(n), err
	}
	return 0, fmt.Errorf("msgpack: invalid code %x decoding int16", code)
}
// DecodeInt32 reads a signed integer that fits in 32 bits: a fixnum or an
// int8/int16/int32 payload. int64 and unsigned encodings are rejected with an
// error.
func (d *Decoder) DecodeInt32() (int32, error) {
	code, err := d.R.ReadByte()
	if err != nil {
		return 0, err
	}
	switch {
	case code <= posFixNumHighCode || code >= negFixNumLowCode:
		return int32(int8(code)), nil
	case code == int8Code:
		b, err := d.R.ReadByte()
		if err != nil {
			return 0, err
		}
		return int32(int8(b)), nil
	case code == int16Code:
		n, err := d.uint16()
		return int32(int16(n)), err
	case code == int32Code:
		n, err := d.uint32()
		return int32(n), err
	}
	return 0, fmt.Errorf("msgpack: invalid code %x decoding int32", code)
}
//------------------------------------------------------------------------------
// interfaceValue decodes the next value with DecodeInterface and stores the
// result into v, a settable interface-kind value.
//
// A nil result resets v to its zero value, because reflect.ValueOf(nil) is an
// invalid Value that Set would panic on. A result whose dynamic type does not
// satisfy v's interface type (e.g. an int64 decoded into an io.Reader field)
// is reported as an error instead of panicking inside reflect.Value.Set.
func (d *Decoder) interfaceValue(v reflect.Value) error {
	iface, err := d.DecodeInterface()
	if err != nil {
		return err
	}
	if iface == nil {
		v.Set(reflect.Zero(v.Type()))
		return nil
	}
	rv := reflect.ValueOf(iface)
	if !rv.Type().AssignableTo(v.Type()) {
		return fmt.Errorf("msgpack: cannot decode %s into %s", rv.Type(), v.Type())
	}
	v.Set(rv)
	return nil
}
// DecodeInterface decodes the next value without a target type, picking the
// Go type from the MessagePack code. Possible result types are:
//   - nil,
//   - int64 (fixnums and int8/16/32/64),
//   - uint64 (uint8/16/32/64),
//   - bool,
//   - float32 and float64,
//   - string (raw codes, via DecodeString),
//   - slices and maps of the above (via DecodeSlice and DecodeMap).
//
// Whenever it returns a non-nil error, the returned value is nil (never a
// typed zero such as int64(0)).
func (d *Decoder) DecodeInterface() (interface{}, error) {
	b, err := d.R.Peek(1)
	if err != nil {
		return nil, err
	}
	c := b[0]

	// The code is only peeked; each branch's decoder consumes it itself.
	var v interface{}
	switch {
	case c <= posFixNumHighCode || c >= negFixNumLowCode:
		v, err = d.DecodeInt64()
	case c >= fixMapLowCode && c <= fixMapHighCode:
		v, err = d.DecodeMap()
	case c >= fixArrayLowCode && c <= fixArrayHighCode:
		v, err = d.DecodeSlice()
	case c >= fixRawLowCode && c <= fixRawHighCode:
		v, err = d.DecodeString()
	case c == nilCode:
		_, err = d.R.ReadByte()
	case c == falseCode, c == trueCode:
		v, err = d.DecodeBool()
	case c == floatCode:
		v, err = d.DecodeFloat32()
	case c == doubleCode:
		v, err = d.DecodeFloat64()
	case c == uint8Code, c == uint16Code, c == uint32Code, c == uint64Code:
		v, err = d.DecodeUint64()
	case c == int8Code, c == int16Code, c == int32Code, c == int64Code:
		v, err = d.DecodeInt64()
	case c == raw16Code, c == raw32Code:
		v, err = d.DecodeString()
	case c == array16Code, c == array32Code:
		v, err = d.DecodeSlice()
	case c == map16Code, c == map32Code:
		v, err = d.DecodeMap()
	default:
		err = fmt.Errorf("msgpack: invalid code %x decoding interface{}", c)
	}
	if err != nil {
		// Sub-decoders return typed zeros on failure; boxing those would
		// make the result a non-nil interface alongside the error.
		return nil, err
	}
	return v, nil
}