msgpack/encode.go

344 lines
6.4 KiB
Go

package msgpack
import (
"bytes"
"fmt"
"io"
"math"
"reflect"
"time"
)
type writer interface {
io.Writer
WriteByte(byte) error
WriteString(string) (int, error)
}
type writeByte struct {
io.Writer
}
func (w *writeByte) WriteByte(b byte) error {
n, err := w.Write([]byte{b})
if err != nil {
return err
}
if n != 1 {
return io.ErrShortWrite
}
return nil
}
func (w *writeByte) WriteString(s string) (int, error) {
return w.Write([]byte(s))
}
func Marshal(v ...interface{}) ([]byte, error) {
buf := &bytes.Buffer{}
err := NewEncoder(buf).Encode(v...)
return buf.Bytes(), err
}
type Encoder struct {
W writer
}
func NewEncoder(w io.Writer) *Encoder {
ww, ok := w.(writer)
if !ok {
ww = &writeByte{Writer: w}
}
return &Encoder{
W: ww,
}
}
func (e *Encoder) Encode(v ...interface{}) error {
for _, vv := range v {
if err := e.encode(vv); err != nil {
return err
}
}
return nil
}
func (e *Encoder) encode(iv interface{}) error {
if iv == nil {
return e.EncodeNil()
}
switch v := iv.(type) {
case string:
return e.EncodeString(v)
case []byte:
return e.EncodeBytes(v)
case int:
return e.EncodeInt64(int64(v))
case int64:
return e.EncodeInt64(v)
case uint:
return e.EncodeUint64(uint64(v))
case uint64:
return e.EncodeUint64(v)
case bool:
return e.EncodeBool(v)
case float32:
return e.EncodeFloat32(v)
case float64:
return e.EncodeFloat64(v)
case []string:
return e.encodeStringSlice(v)
case map[string]string:
return e.encodeMapStringString(v)
case time.Duration:
return e.EncodeInt64(int64(v))
case time.Time:
return e.EncodeTime(v)
case encoder:
return v.EncodeMsgpack(e.W)
}
return e.EncodeValue(reflect.ValueOf(iv))
}
func (e *Encoder) EncodeValue(v reflect.Value) error {
switch v.Kind() {
case reflect.String:
return e.EncodeString(v.String())
case reflect.Bool:
return e.EncodeBool(v.Bool())
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
return e.EncodeUint64(v.Uint())
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
return e.EncodeInt64(v.Int())
case reflect.Float32:
return e.EncodeFloat32(float32(v.Float()))
case reflect.Float64:
return e.EncodeFloat64(v.Float())
case reflect.Array:
return e.encodeSlice(v)
case reflect.Slice:
if v.IsNil() {
return e.EncodeNil()
}
return e.encodeSlice(v)
case reflect.Map:
return e.encodeMap(v)
case reflect.Interface, reflect.Ptr:
if v.IsNil() {
return e.EncodeNil()
}
if enc, ok := typEncMap[v.Type()]; ok {
return enc(e, v)
}
if enc, ok := v.Interface().(encoder); ok {
return enc.EncodeMsgpack(e.W)
}
return e.EncodeValue(v.Elem())
case reflect.Struct:
typ := v.Type()
if enc, ok := typEncMap[typ]; ok {
return enc(e, v)
}
if enc, ok := v.Interface().(encoder); ok {
return enc.EncodeMsgpack(e.W)
}
return e.encodeStruct(v)
default:
return fmt.Errorf("msgpack: unsupported type %v", v.Type().String())
}
panic("not reached")
}
func (e *Encoder) EncodeNil() error {
return e.W.WriteByte(nilCode)
}
func (e *Encoder) EncodeUint(v uint) error {
return e.EncodeUint64(uint64(v))
}
func (e *Encoder) EncodeUint8(v uint8) error {
return e.EncodeUint64(uint64(v))
}
func (e *Encoder) EncodeUint16(v uint16) error {
return e.EncodeUint64(uint64(v))
}
func (e *Encoder) EncodeUint32(v uint32) error {
return e.EncodeUint64(uint64(v))
}
func (e *Encoder) EncodeUint64(v uint64) error {
switch {
case v < 128:
return e.W.WriteByte(byte(v))
case v < 256:
return e.write([]byte{uint8Code, byte(v)})
case v < 65536:
return e.write([]byte{uint16Code, byte(v >> 8), byte(v)})
case v < 4294967296:
return e.write([]byte{
uint32Code,
byte(v >> 24),
byte(v >> 16),
byte(v >> 8),
byte(v),
})
default:
return e.write([]byte{
uint64Code,
byte(v >> 56),
byte(v >> 48),
byte(v >> 40),
byte(v >> 32),
byte(v >> 24),
byte(v >> 16),
byte(v >> 8),
byte(v),
})
}
panic("not reached")
}
func (e *Encoder) EncodeInt(v int) error {
return e.EncodeInt64(int64(v))
}
func (e *Encoder) EncodeInt8(v int8) error {
return e.EncodeInt64(int64(v))
}
func (e *Encoder) EncodeInt16(v int16) error {
return e.EncodeInt64(int64(v))
}
func (e *Encoder) EncodeInt32(v int32) error {
return e.EncodeInt64(int64(v))
}
func (e *Encoder) EncodeInt64(v int64) error {
switch {
case v < -2147483648 || v >= 2147483648:
return e.write([]byte{
int64Code,
byte(v >> 56),
byte(v >> 48),
byte(v >> 40),
byte(v >> 32),
byte(v >> 24),
byte(v >> 16),
byte(v >> 8),
byte(v),
})
case v < -32768 || v >= 32768:
return e.write([]byte{
int32Code,
byte(v >> 24),
byte(v >> 16),
byte(v >> 8),
byte(v),
})
case v < -128 || v >= 128:
return e.write([]byte{int16Code, byte(v >> 8), byte(v)})
case v < -32:
return e.write([]byte{int8Code, byte(v)})
default:
return e.W.WriteByte(byte(v))
}
panic("not reached")
}
func (e *Encoder) EncodeBool(value bool) error {
if value {
return e.W.WriteByte(trueCode)
}
return e.W.WriteByte(falseCode)
}
func (e *Encoder) EncodeFloat32(value float32) error {
v := math.Float32bits(value)
return e.write([]byte{
floatCode,
byte(v >> 24),
byte(v >> 16),
byte(v >> 8),
byte(v),
})
}
func (e *Encoder) EncodeFloat64(value float64) error {
v := math.Float64bits(value)
return e.write([]byte{
doubleCode,
byte(v >> 56),
byte(v >> 48),
byte(v >> 40),
byte(v >> 32),
byte(v >> 24),
byte(v >> 16),
byte(v >> 8),
byte(v),
})
}
func (e *Encoder) encodeStruct(v reflect.Value) error {
fields := structs.Fields(v.Type())
switch l := len(fields); {
case l < 16:
if err := e.W.WriteByte(fixMapLowCode | byte(l)); err != nil {
return err
}
case l < 65536:
if err := e.write([]byte{
map16Code,
byte(l >> 8),
byte(l),
}); err != nil {
return err
}
default:
if err := e.write([]byte{
map32Code,
byte(l >> 24),
byte(l >> 16),
byte(l >> 8),
byte(l),
}); err != nil {
return err
}
}
for _, f := range fields {
if err := e.EncodeString(f.Name()); err != nil {
return err
}
if err := f.EncodeValue(e, v); err != nil {
return err
}
}
return nil
}
func (e *Encoder) write(data []byte) error {
n, err := e.W.Write(data)
if err != nil {
return err
}
if n < len(data) {
return io.ErrShortWrite
}
return nil
}
func (e *Encoder) writeString(s string) error {
n, err := e.W.WriteString(s)
if err != nil {
return err
}
if n < len(s) {
return io.ErrShortWrite
}
return nil
}