From e330a36104e0c23829ab0d72abfc496d44779b83 Mon Sep 17 00:00:00 2001 From: Pavel Merzlyakov Date: Sat, 1 Oct 2022 21:09:54 +0300 Subject: [PATCH] initial commit --- fields_mask.go | 61 +++++++ fields_mask_test.go | 115 +++++++++++++ go.mod | 16 ++ go.sum | 29 ++++ interface.go | 99 +++++++++++ meta.go | 64 +++++++ msgpack_reader.go | 412 ++++++++++++++++++++++++++++++++++++++++++++ msgpack_writer.go | 129 ++++++++++++++ 8 files changed, 925 insertions(+) create mode 100644 fields_mask.go create mode 100644 fields_mask_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 interface.go create mode 100644 meta.go create mode 100644 msgpack_reader.go create mode 100644 msgpack_writer.go diff --git a/fields_mask.go b/fields_mask.go new file mode 100644 index 0000000..02e0f4d --- /dev/null +++ b/fields_mask.go @@ -0,0 +1,61 @@ +package meta + +const ( + FieldsMaskCapacity = 4 + FieldsMaskItemBitSize = 64 +) + +func MakeFieldsMaskFromInt64(v int64) FieldsMask { + var mask FieldsMask + mask.SetItemFromInt64(0, v) + + return mask +} + +type FieldsMaskItem uint64 + +func (fmi FieldsMaskItem) FieldIsDirty(index uint64) bool { + return (1< 0 { + return true + } + } + return false +} + +func (fm FieldsMask) itemIndex(index uint64) uint64 { + return index / FieldsMaskItemBitSize +} + +func (fm FieldsMask) maskIndex(index uint64) uint64 { + return index % FieldsMaskItemBitSize +} diff --git a/fields_mask_test.go b/fields_mask_test.go new file mode 100644 index 0000000..f56e102 --- /dev/null +++ b/fields_mask_test.go @@ -0,0 +1,115 @@ +package meta_test + +import ( + "testing" + + "git.bit5.ru/backend/meta" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFieldsMask(t *testing.T) { + t.Run("IsFilled", func(t *testing.T) { + t.Run("default value", func(t *testing.T) { + var mask meta.FieldsMask + + actualIsFilled := mask.IsFilled() + require.False(t, actualIsFilled) + }) + t.Run("filled value", func(t *testing.T) { + var mask meta.FieldsMask + mask.SetItemFromInt64(0, 1) + + actualIsFilled := mask.IsFilled() + require.True(t, actualIsFilled) + }) + }) + + t.Run("FieldChanged", func(t *testing.T) { + t.Run("field not changed", func(t *testing.T) { + var mask meta.FieldsMask + var fieldIndex uint64 = 4 + + fieldChanged := mask.FieldChanged(fieldIndex) + require.False(t, fieldChanged) + }) + t.Run("filled changed #1", func(t *testing.T) { + var mask meta.FieldsMask + var fieldIndex uint64 = 4 + + mask.SetItemFromInt64(0, 16) + + fieldChanged := mask.FieldChanged(fieldIndex) + require.True(t, fieldChanged) + }) + t.Run("filled changed #2", func(t *testing.T) { + var mask meta.FieldsMask + var fieldIndex uint64 = 68 + + mask.SetItemFromInt64(1, 16) + + fieldChanged := mask.FieldChanged(fieldIndex) + require.True(t, fieldChanged) + }) + t.Run("filled changed #3", func(t *testing.T) { + var mask meta.FieldsMask + var fieldIndex uint64 = 131 + + mask.SetItemFromInt64(2, 8) + + fieldChanged := mask.FieldChanged(fieldIndex) + require.True(t, fieldChanged) + }) + t.Run("filled changed #4", func(t *testing.T) { + var mask meta.FieldsMask + var fieldIndex uint64 = 194 + + mask.SetItemFromInt64(3, 4) + + fieldChanged := mask.FieldChanged(fieldIndex) + require.True(t, fieldChanged) + }) + }) +} + +func TestFieldsMaskItem(t *testing.T) { + t.Run("FieldIsDirty", func(t *testing.T) { + cases := []struct { + maskItem meta.FieldsMaskItem + expectedDirtyIndexes []uint64 + }{ + { + maskItem: 0, + expectedDirtyIndexes: []uint64{}, + }, + { + maskItem: 1, // 0b0001 + expectedDirtyIndexes: []uint64{0}, + }, + { + maskItem: 2, // 0b0010 + expectedDirtyIndexes: []uint64{1}, + }, + { + maskItem: 10, // 0b1010 + expectedDirtyIndexes: []uint64{1, 3}, + }, + { + maskItem: 11, // 0b1011 + expectedDirtyIndexes: []uint64{0, 1, 3}, + }, + } + + for i, c := range cases { + actualDirtyIndexes := make([]uint64, 0, meta.FieldsMaskItemBitSize) + for j := uint64(0); j < meta.FieldsMaskItemBitSize; j++ { + if c.maskItem.FieldIsDirty(j) { + actualDirtyIndexes = append(actualDirtyIndexes, j) + } + } + + assert.Equalf(t, c.expectedDirtyIndexes, actualDirtyIndexes, "case #%d", i) + } + }) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..a64fdf8 --- /dev/null +++ b/go.mod @@ -0,0 +1,16 @@ +module git.bit5.ru/backend/meta + +go 1.18 + +require ( + git.bit5.ru/backend/msgpack v1.0.0 + github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.7.3 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/vmihailenco/bufio v0.0.0-20140618134113-fe7b595919de // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..3252ab5 --- /dev/null +++ b/go.sum @@ -0,0 +1,29 @@ +git.bit5.ru/backend/msgpack v1.0.0 h1:D7sFCFjSN1ADUaESjrRVIWY9TGVATq5i08eKn0ep6ZE= +git.bit5.ru/backend/msgpack v1.0.0/go.mod h1:Syf8E+3pr9z3TropB/eN4PJUekRg5ZD/0sHydHH17r0= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.3 h1:dAm0YRdRQlWojc3CrCRgPBzG5f941d0zvAKu7qY4e+I= +github.com/stretchr/testify v1.7.3/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= +github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= +github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= +github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= +github.com/vmihailenco/bufio v0.0.0-20140618134113-fe7b595919de h1:U+I4zEVstMdfNES/2UO8iqkIf214SDMRhdaFTE3A5rA= +github.com/vmihailenco/bufio v0.0.0-20140618134113-fe7b595919de/go.mod h1:ghSGoeEoFFkXNguSget72dMA0+OLq3AGZiqRohVojxI= +gopkg.in/bufio.v1 v1.0.0-20140618132640-567b2bfa514e h1:wGA78yza6bu/mWcc4QfBuIEHEtc06xdiU0X8sY36yUU= +gopkg.in/bufio.v1 v1.0.0-20140618132640-567b2bfa514e/go.mod h1:xsQCaysVCudhrYTfzYWe577fCe7Ceci+6qjO2Rdc0Z4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +launchpad.net/gocheck v0.0.0-20140225173054-000000000087 h1:Izowp2XBH6Ya6rv+hqbceQyw/gSGoXfH/UPoTGduL54= +launchpad.net/gocheck v0.0.0-20140225173054-000000000087/go.mod h1:hj7XX3B/0A+80Vse0e+BUHsHMTEhd0O4cpUHr/e/BUM= diff --git a/interface.go b/interface.go new file mode 100644 index 0000000..5d0b03a --- /dev/null +++ b/interface.go @@ -0,0 +1,99 @@ +package meta + +type Reader interface { + ReadI8(v *int8, field string) error + ReadU8(v *uint8, field string) error + ReadI16(v *int16, field string) error + ReadU16(v *uint16, field string) error + ReadI32(v *int32, field string) error + ReadU32(v *uint32, field string) error + ReadI64(v *int64, field string) error + ReadU64(v *uint64, field string) error + ReadBool(v *bool, field string) error + ReadFloat(v *float32, field string) error + ReadDouble(v *float64, field string) error + ReadString(v *string, field string) error + ReadBlob(v *[]byte, field string) error + BeginContainer(field string) error + EndContainer() error + GetContainerSize() (int, error) + Skip() error + TryReadMask() (bool, FieldsMask, error) +} + +type Writer interface { + WriteI8(v int8, field string) error + WriteU8(v uint8, field string) error + WriteI16(v int16, field string) error + WriteU16(v uint16, field string) error + WriteI32(v int32, field string) error + WriteU32(v uint32, field string) error + WriteU64(v uint64, field string) error + WriteI64(v int64, field string) error + WriteBool(v bool, field string) error + WriteFloat(v float32, field string) error + WriteDouble(v float64, field string) error + WriteString(v string, field string) error + WriteBlob(v []byte, field string) error + BeginContainer(field string) + EndContainer() error + GetData() ([]byte, error) +} + +type ClassFieldsProps map[string]map[string]string + +type IClassProps interface { + CLASS_ID() uint32 + CLASS_NAME() string + CLASS_PROPS() *map[string]string + CLASS_FIELDS() []string + CLASS_FIELDS_PROPS() *ClassFieldsProps +} + +type MetaFactory func(classId uint32) (IMetaStruct, error) + +type IMetaStruct interface { + IClassProps + Read(reader Reader) error + Write(writer Writer) error + ReadFields(reader Reader) error + WriteFields(writer Writer) error + Reset() +} + +type IMetaDataItem interface { + IClassProps + GetDbTableName() string + GetDbFields() []string + GetOwnerFieldName() string + GetIdFieldName() string + GetIdValue() uint64 + Import(interface{}) + Export([]interface{}) + NewInstance() IMetaDataItem +} + +type IRemovedIds interface { + GetList(classId uint32) []uint64 + Add(classId uint32, id uint64) + HasList(classId uint32) bool +} + +type IBitmasked interface { + SetFieldChanged(index uint64) + HasValue(index uint64) bool + IsMaskFilled() bool + GetMask() FieldsMask +} + +type RPCFactory func(classId uint32) (IRPC, error) + +type IRPC interface { + GetCode() int32 + GetName() string + GetRequest() IMetaStruct + GetResponse() IMetaStruct + SetError(int32, string) + GetError() (int32, string) + Execute(interface{}) error +} diff --git a/meta.go b/meta.go new file mode 100644 index 0000000..2398539 --- /dev/null +++ b/meta.go @@ -0,0 +1,64 @@ +package meta + +func Read(reader Reader, createById MetaFactory) (v IMetaStruct, err error) { + return ReadStructGeneric(reader, createById, "") +} + +func Write(writer Writer, m IMetaStruct) error { + return WriteStructGeneric(writer, m, "") +} + +func ReadStructGeneric(reader Reader, createById MetaFactory, field string) (v IMetaStruct, err error) { + if err = reader.BeginContainer(field); err != nil { + return + } + var classId uint32 + if err = reader.ReadU32(&classId, ""); err != nil { + return + } + if v, err = createById(classId); err != nil { + return + } + if err = v.ReadFields(reader); err != nil { + return + } + err = reader.EndContainer() + return +} + +func WriteStructGeneric(writer Writer, m IMetaStruct, field string) error { + writer.BeginContainer(field) + if err := writer.WriteU32(m.CLASS_ID(), ""); err != nil { + return err + } + if err := m.WriteFields(writer); err != nil { + return err + } + return writer.EndContainer() +} + +func ReadStruct(reader Reader, m IMetaStruct, field string) error { + err := reader.BeginContainer(field) + if err != nil { + return err + } + err = m.ReadFields(reader) + if err != nil { + return err + } + return reader.EndContainer() +} + +func WriteStruct(writer Writer, m IMetaStruct, field string) error { + writer.BeginContainer(field) + err := m.WriteFields(writer) + if err != nil { + return err + } + return writer.EndContainer() +} + +func IsRPCFailed(rpc IRPC) bool { + code, _ := rpc.GetError() + return code != 0 +} diff --git a/msgpack_reader.go b/msgpack_reader.go new file mode 100644 index 0000000..2ebbebb --- /dev/null +++ b/msgpack_reader.go @@ -0,0 +1,412 @@ +package meta + +import ( + "github.com/pkg/errors" + + "git.bit5.ru/backend/msgpack" +) + +type msgpReader struct { + stack []*msgpReadState + current *msgpReadState +} + +type msgpReadState struct { + data []interface{} + idx int +} + +func NewMsgpackReader(bytes []byte) (Reader, error) { + arr := make([]interface{}, 1) + err := msgpack.Unmarshal(bytes, &arr[0]) + if err != nil { + return nil, errors.WithStack(err) + } + + rd := &msgpReader{} + rd.stack = make([]*msgpReadState, 0, 2) + rd.current = &msgpReadState{arr, 0} + rd.stack = append(rd.stack, rd.current) + + return rd, nil +} + +func (state *msgpReadState) clear() { + state.data = nil +} + +func (state *msgpReadState) read(field string) (v interface{}, err error) { + if len(state.data) <= state.idx { + err = errors.Errorf("No more data for read, field:%s", field) + return + } + v = state.data[state.idx] + state.idx++ + return +} + +func (state *msgpReadState) backCursor() error { + if state.data == nil { + return errors.New("No more data for read") + } + state.idx-- + return nil +} + +func (state *msgpReadState) skip() error { + if state.data == nil || len(state.data) <= state.idx { + return errors.New("No more data for read") + } + state.idx++ + return nil +} + +func (state *msgpReadState) size() (int, error) { + return len(state.data) - state.idx, nil // relative to current idx? +} + +func (rd *msgpReader) ReadI8(v *int8, field string) (err error) { + var c int32 + err = rd.ReadI32(&c, field) + *v = int8(c) + return +} + +func (rd *msgpReader) ReadU8(v *uint8, field string) (err error) { + var c uint32 + err = rd.ReadU32(&c, field) + *v = uint8(c) + return +} + +func (rd *msgpReader) ReadI16(v *int16, field string) (err error) { + var c int32 + err = rd.ReadI32(&c, field) + *v = int16(c) + return +} + +func (rd *msgpReader) ReadU16(v *uint16, field string) (err error) { + var c uint32 + err = rd.ReadU32(&c, field) + *v = uint16(c) + return +} + +func (rd *msgpReader) ReadI32(v *int32, field string) error { + value, err := rd.current.read(field) + if err != nil { + return err + } + switch t := value.(type) { + case float64: + *v = int32(t) + case int8: + *v = int32(t) + case uint8: + *v = int32(t) + case int16: + *v = int32(t) + case uint16: + *v = int32(t) + case int32: + *v = int32(t) + case uint32: + *v = int32(t) + case int64: + *v = int32(t) + case uint64: + *v = int32(t) + case bool: + if t { + *v = int32(1) + } else { + *v = int32(0) + } + default: + return errors.Errorf("Can't convert to int32 %v (%T), field:%s", t, t, field) + } + return nil +} + +func (rd *msgpReader) ReadU32(v *uint32, field string) error { + value, err := rd.current.read(field) + if err != nil { + return err + } + switch t := value.(type) { + case float64: + *v = uint32(t) + case int8: + *v = uint32(t) + case uint8: + *v = uint32(t) + case int16: + *v = uint32(t) + case uint16: + *v = uint32(t) + case int32: + *v = uint32(t) + case uint32: + *v = uint32(t) + case uint64: + *v = uint32(t) + case int64: + *v = uint32(t) + case bool: + if t { + *v = uint32(1) + } else { + *v = uint32(0) + } + default: + return errors.Errorf("Can't convert to uint32 (%T), field:%s", t, field) + } + return nil +} + +func (rd *msgpReader) ReadI64(v *int64, field string) error { + value, err := rd.current.read(field) + if err != nil { + return err + } + switch t := value.(type) { + case float64: + *v = int64(t) + case int8: + *v = int64(t) + case uint8: + *v = int64(t) + case int16: + *v = int64(t) + case uint16: + *v = int64(t) + case int32: + *v = int64(t) + case uint32: + *v = int64(t) + case uint64: + *v = int64(t) + case int64: + *v = int64(t) + default: + return errors.Errorf("Can't convert to uint64 (%T), field:%s", t, field) + } + return nil +} + +func (rd *msgpReader) ReadU64(v *uint64, field string) error { + value, err := rd.current.read(field) + if err != nil { + return err + } + switch t := value.(type) { + case float64: + *v = uint64(t) + case int8: + *v = uint64(t) + case uint8: + *v = uint64(t) + case int16: + *v = uint64(t) + case uint16: + *v = uint64(t) + case int32: + *v = uint64(t) + case uint32: + *v = uint64(t) + case uint64: + *v = uint64(t) + case int64: + *v = uint64(t) + default: + return errors.Errorf("Can't convert to uint64 (%T), field:%s", t, field) + } + return nil +} + +func (rd *msgpReader) ReadBool(v *bool, field string) (err error) { + value, err := rd.current.read(field) + if err != nil { + return err + } + switch t := value.(type) { + case bool: + *v = bool(t) + default: + return errors.Errorf("Can't convert to bool (%T), field:%s", t, field) + } + return nil +} + +func (rd *msgpReader) ReadFloat(v *float32, field string) error { + value, err := rd.current.read(field) + if err != nil { + return err + } + switch t := value.(type) { + case float64: + //NOTE: possible type coersion error! + *v = float32(t) + case float32: + *v = t + case int64: + *v = float32(t) + case uint64: + *v = float32(t) + default: + return errors.Errorf("Can't convert to float32 (%T), field:%s", t, field) + } + return nil +} + +func (rd *msgpReader) ReadDouble(v *float64, field string) error { + value, err := rd.current.read(field) + if err != nil { + return err + } + switch t := value.(type) { + case float64: + *v = t + case float32: + *v = float64(t) + case int64: + *v = float64(t) + case uint64: + *v = float64(t) + default: + return errors.Errorf("Can't convert to float64 (%T), field:%s", t, field) + } + return nil +} + +func (rd *msgpReader) ReadString(v *string, field string) error { + value, err := rd.current.read(field) + if err != nil { + return err + } + switch t := value.(type) { + case string: + *v = t + + default: + return errors.Errorf("Can't convert to string %s (%T)", field, t) + } + return nil +} + +func (rd *msgpReader) ReadBlob(v *[]byte, field string) error { + value, err := rd.current.read(field) + if err != nil { + return err + } + switch t := value.(type) { + case []byte: + *v = t + + default: + return errors.Errorf("Can't convert to []byte %s (%T)", field, t) + } + return nil +} + +func (rd *msgpReader) BeginContainer(field string) error { + v, err := rd.current.read(field) + if err != nil { + return err + } + + if data, ok := v.([]interface{}); ok { + cur := &msgpReadState{} + cur.data = data + rd.stack = append(rd.stack, cur) + rd.current = cur + return nil + } else { + return errors.Errorf("Next value isn't array but:%T, field:%s", v, field) + } +} + +func (rd *msgpReader) EndContainer() error { + rd.current.clear() + rd.stack[len(rd.stack)-1] = nil + rd.stack = rd.stack[0 : len(rd.stack)-1] + if len(rd.stack) > 0 { + rd.current = rd.stack[len(rd.stack)-1] + } else { + rd.current = nil + } + return nil +} + +func (rd *msgpReader) GetContainerSize() (int, error) { + return rd.current.size() +} + +func (rd *msgpReader) Skip() error { + return rd.current.skip() +} + +func (rd *msgpReader) TryReadMask() (bool, FieldsMask, error) { + value, err := rd.current.read("mask") + if err != nil { + return false, FieldsMask{}, err + } + + if value == nil { //mask detected + newMask, err := rd.ReadNewMask() + if err == nil { + return true, newMask, nil + } + + if err := rd.current.backCursor(); err != nil { + return false, FieldsMask{}, err + } + + oldMask, err := rd.ReadOldMask() + if err != nil { + return false, FieldsMask{}, err + } + return true, oldMask, nil + } + + if err := rd.current.backCursor(); err != nil { + return false, FieldsMask{}, err + } + + return false, FieldsMask{}, nil +} + +func (rd *msgpReader) ReadNewMask() (FieldsMask, error) { + if err := rd.BeginContainer("new_mask"); err != nil { + return FieldsMask{}, err + } + maskSize, err := rd.GetContainerSize() + if err != nil { + return FieldsMask{}, err + } + + var mask FieldsMask + for i := 0; i < maskSize; i++ { + var maskItem uint64 + if err := rd.ReadU64(&maskItem, "mask_item"); err != nil { + return FieldsMask{}, err + } + mask.SetItemFromUint64(i, maskItem) + } + + if err := rd.EndContainer(); err != nil { + return FieldsMask{}, err + } + + return mask, nil +} + +func (rd *msgpReader) ReadOldMask() (FieldsMask, error) { + var mask int64 + if err := rd.ReadI64(&mask, "mask"); err != nil { + return FieldsMask{}, err + } + + return MakeFieldsMaskFromInt64(mask), nil +} diff --git a/msgpack_writer.go b/msgpack_writer.go new file mode 100644 index 0000000..b8316bd --- /dev/null +++ b/msgpack_writer.go @@ -0,0 +1,129 @@ +package meta + +import ( + "io" + + "github.com/pkg/errors" + + "git.bit5.ru/backend/msgpack" +) + +type msgpWriter struct { + stack []*msgpWriteContainer + current *msgpWriteContainer +} + +type msgpWriteContainer struct { + arr []interface{} +} + +func NewMsgpackWriter() Writer { + wr := &msgpWriter{} + wr.stack = make([]*msgpWriteContainer, 0, 1) + wr.current = newMsgpWriteContainer() + wr.stack = append(wr.stack, wr.current) + return wr +} + +func newMsgpWriteContainer() *msgpWriteContainer { + return &msgpWriteContainer{make([]interface{}, 0, 1)} +} + +func (state *msgpWriteContainer) add(v interface{}) error { + state.arr = append(state.arr, v) + return nil +} + +func (wr *msgpWriter) WriteI8(v int8, field string) error { + return wr.current.add(v) +} + +func (wr *msgpWriter) WriteU8(v uint8, field string) error { + return wr.current.add(v) +} +func (wr *msgpWriter) WriteI16(v int16, field string) error { + return wr.current.add(v) +} +func (wr *msgpWriter) WriteU16(v uint16, field string) error { + return wr.current.add(v) +} + +func (wr *msgpWriter) WriteI32(v int32, field string) error { + return wr.current.add(v) +} +func (wr *msgpWriter) WriteU32(v uint32, field string) error { + return wr.current.add(v) +} + +func (wr *msgpWriter) WriteI64(v int64, field string) error { + return wr.current.add(v) +} +func (wr *msgpWriter) WriteU64(v uint64, field string) error { + return wr.current.add(v) +} + +func (wr *msgpWriter) WriteBool(v bool, field string) error { + return wr.current.add(v) +} + +func (wr *msgpWriter) WriteFloat(v float32, field string) error { + return wr.current.add(v) +} + +func (wr *msgpWriter) WriteDouble(v float64, field string) error { + return wr.current.add(v) +} + +func (wr *msgpWriter) WriteString(v string, field string) error { + return wr.current.add(v) +} + +func (wr *msgpWriter) WriteBlob(v []byte, field string) error { + return wr.current.add(v) +} + +func (wr *msgpWriter) BeginContainer(field string) { + wr.current = newMsgpWriteContainer() + wr.stack = append(wr.stack, wr.current) +} + +func (wr *msgpWriter) EndContainer() error { + if len(wr.stack) <= 1 { + return errors.New("No open container") + } + last := wr.current + wr.stack[len(wr.stack)-1] = nil + wr.stack = wr.stack[0 : len(wr.stack)-1] + wr.current = wr.stack[len(wr.stack)-1] + + bytes, err := msgpack.Marshal(last.arr) + if err != nil { + return err + } + + //NOTE: using custom msgpack encoder + wr.current.add(&msgpCustomBytes{bytes}) + + return nil +} + +func (wr *msgpWriter) GetData() ([]byte, error) { + if len(wr.stack) != 1 { + return nil, errors.New("Stack isn't empty") + } + + if len(wr.current.arr) != 1 { + return nil, errors.New("Arr size isn't valid") + } + + return msgpack.Marshal(wr.current.arr[0]) +} + +type msgpCustomBytes struct { + v []byte +} + +func (msgp *msgpCustomBytes) EncodeMsgpack(writer io.Writer) error { + _, err := writer.Write(msgp.v) + return errors.WithStack(err) +}