From fe1e485334122c1365b8c4215966735ae99caa8b Mon Sep 17 00:00:00 2001 From: Pavel Merzlyakov Date: Mon, 3 Jul 2023 22:09:39 +0300 Subject: [PATCH] improve writer interface --- fields_mask_test.go | 2 +- go.mod | 2 +- interface.go | 4 +- msgpack_assoc_writer.go | 171 +++++++++++++++++++++++++++++++++++ msgpack_assoc_writer_test.go | 63 +++++++++++++ msgpack_reader_test.go | 2 +- msgpack_writer.go | 101 +++------------------ msgpack_writer_test.go | 6 +- structs_test.go | 21 ++--- 9 files changed, 268 insertions(+), 104 deletions(-) create mode 100644 msgpack_assoc_writer.go create mode 100644 msgpack_assoc_writer_test.go diff --git a/fields_mask_test.go b/fields_mask_test.go index e81a999..7209c2c 100644 --- a/fields_mask_test.go +++ b/fields_mask_test.go @@ -3,7 +3,7 @@ package meta_test import ( "testing" - "git.bit5.ru/backend/meta/v2" + "git.bit5.ru/backend/meta/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/go.mod b/go.mod index 78c16e3..5b286be 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module git.bit5.ru/backend/meta/v2 +module git.bit5.ru/backend/meta/v3 go 1.20 diff --git a/interface.go b/interface.go index 0b2b4ad..952e712 100644 --- a/interface.go +++ b/interface.go @@ -48,8 +48,10 @@ type Writer interface { WriteString(v string, field string) error WriteBytes(v []byte, field string) error + BeginCollection(length int, field string) error + EndCollection() error + BeginContainer(length int, field string) error - BeginAssocContainer(length int, field string) error EndContainer() error } diff --git a/msgpack_assoc_writer.go b/msgpack_assoc_writer.go new file mode 100644 index 0000000..fc4d00d --- /dev/null +++ b/msgpack_assoc_writer.go @@ -0,0 +1,171 @@ +package meta + +import ( + "io" + + "github.com/pkg/errors" + "github.com/vmihailenco/msgpack/v5" +) + +type msgpackAssocWriter struct { + enc *msgpack.Encoder + containers []writeContainer +} + +type writeContainer struct { + length int + assoc bool +} + +func NewMsgpackAssocWriter(w io.Writer) Writer { + return &msgpackAssocWriter{ + enc: msgpack.NewEncoder(w), + containers: make([]writeContainer, 0, 1), + } +} + +func (wr *msgpackAssocWriter) currentContainer() writeContainer { + if wr == nil || len(wr.containers) == 0 { + return writeContainer{} + } + return wr.containers[len(wr.containers)-1] +} + +func (wr *msgpackAssocWriter) writeFieldName(field string) error { + if !wr.currentContainer().assoc { + return nil + } + + return errors.WithStack(wr.enc.EncodeString(field)) +} + +func (wr *msgpackAssocWriter) WriteInt8(v int8, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeInt(int64(v))) +} + +func (wr *msgpackAssocWriter) WriteInt16(v int16, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeInt(int64(v))) +} + +func (wr *msgpackAssocWriter) WriteInt32(v int32, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeInt(int64(v))) +} + +func (wr *msgpackAssocWriter) WriteInt64(v int64, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeInt(v)) +} + +func (wr *msgpackAssocWriter) WriteUint8(v uint8, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeUint(uint64(v))) +} + +func (wr *msgpackAssocWriter) WriteUint16(v uint16, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeUint(uint64(v))) +} + +func (wr *msgpackAssocWriter) WriteUint32(v uint32, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeUint(uint64(v))) +} + +func (wr *msgpackAssocWriter) WriteUint64(v uint64, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeUint(uint64(v))) +} + +func (wr *msgpackAssocWriter) WriteBool(v bool, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeBool(v)) +} + +func (wr *msgpackAssocWriter) WriteFloat32(v float32, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeFloat32(v)) +} + +func (wr *msgpackAssocWriter) WriteFloat64(v float64, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeFloat64(v)) +} + +func (wr *msgpackAssocWriter) WriteString(v string, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeString(v)) +} + +func (wr *msgpackAssocWriter) WriteBytes(v []byte, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + return errors.WithStack(wr.enc.EncodeBytes(v)) +} + +func (wr *msgpackAssocWriter) BeginContainer(length int, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + if err := wr.enc.EncodeMapLen(length); err != nil { + return errors.WithStack(err) + } + wr.containers = append(wr.containers, writeContainer{ + length: length, + assoc: true, + }) + return nil +} + +func (wr *msgpackAssocWriter) EndContainer() error { + if len(wr.containers) == 0 { + return errors.New("there is no open containers") + } + wr.containers = wr.containers[:len(wr.containers)-1] + return nil +} + +func (wr *msgpackAssocWriter) BeginCollection(length int, field string) error { + if err := wr.writeFieldName(field); err != nil { + return err + } + if err := wr.enc.EncodeArrayLen(length); err != nil { + return errors.WithStack(err) + } + wr.containers = append(wr.containers, writeContainer{ + length: length, + assoc: false, + }) + return nil +} + +func (wr *msgpackAssocWriter) EndCollection() error { + return wr.EndContainer() +} diff --git a/msgpack_assoc_writer_test.go b/msgpack_assoc_writer_test.go new file mode 100644 index 0000000..1855ce9 --- /dev/null +++ b/msgpack_assoc_writer_test.go @@ -0,0 +1,63 @@ +package meta_test + +import ( + "bytes" + "encoding/hex" + "testing" + + "git.bit5.ru/backend/meta/v3" + "github.com/stretchr/testify/require" +) + +func TestMsgpackAssocWriter(t *testing.T) { + t.Run("write struct", func(t *testing.T) { + var buf bytes.Buffer + wr := meta.NewMsgpackAssocWriter(&buf) + + s := TestParent{ + Field1: "blabla", + Field2: TestFoo{ + Field: 1, + }, + Field3: []int8{2, 4, 6}, + Field4: []TestFoo{ + {Field: 10}, + {Field: 1024}, + }, + } + + err := s.Write(wr) + require.NoError(t, err) + + expected := "84a26631a6626c61626c61a2663281a56669656c6401a2663393020406a266349281a56669656c640a81a56669656c64cd0400" + actual := hex.EncodeToString(buf.Bytes()) + require.EqualValues(t, expected, actual) + }) + + t.Run("write child struct", func(t *testing.T) { + var buf bytes.Buffer + wr := meta.NewMsgpackAssocWriter(&buf) + + s := TestChild{ + Field: "qwerty", + TestParent: TestParent{ + Field1: "blabla", + Field2: TestFoo{ + Field: 1, + }, + Field3: []int8{2, 4, 6}, + Field4: []TestFoo{ + {Field: 10}, + {Field: 1024}, + }, + }, + } + + err := s.Write(wr) + require.NoError(t, err) + + expected := "85a26631a6626c61626c61a2663281a56669656c6401a2663393020406a266349281a56669656c640a81a56669656c64cd0400a166a6717765727479" + actual := hex.EncodeToString(buf.Bytes()) + require.EqualValues(t, expected, actual) + }) +} diff --git a/msgpack_reader_test.go b/msgpack_reader_test.go index 377c60d..732f194 100644 --- a/msgpack_reader_test.go +++ b/msgpack_reader_test.go @@ -5,7 +5,7 @@ import ( "encoding/hex" "testing" - "git.bit5.ru/backend/meta/v2" + "git.bit5.ru/backend/meta/v3" "github.com/stretchr/testify/require" ) diff --git a/msgpack_writer.go b/msgpack_writer.go index d929153..9ec9172 100644 --- a/msgpack_writer.go +++ b/msgpack_writer.go @@ -9,156 +9,85 @@ import ( type msgpackWriter struct { enc *msgpack.Encoder - containers []writeContainer -} - -type writeContainer struct { - length int - assoc bool + containers []struct{} } func NewMsgpackWriter(w io.Writer) Writer { return &msgpackWriter{ enc: msgpack.NewEncoder(w), - containers: make([]writeContainer, 0, 1), + containers: make([]struct{}, 0, 1), } } -func (wr *msgpackWriter) currentContainer() writeContainer { - if wr == nil || len(wr.containers) == 0 { - return writeContainer{} - } - return wr.containers[len(wr.containers)-1] -} - -func (wr *msgpackWriter) writeFieldName(field string) error { - if !wr.currentContainer().assoc { - return nil - } - - return errors.WithStack(wr.enc.EncodeString(field)) -} - func (wr *msgpackWriter) WriteInt8(v int8, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } return errors.WithStack(wr.enc.EncodeInt(int64(v))) } func (wr *msgpackWriter) WriteInt16(v int16, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } return errors.WithStack(wr.enc.EncodeInt(int64(v))) } func (wr *msgpackWriter) WriteInt32(v int32, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } return errors.WithStack(wr.enc.EncodeInt(int64(v))) } func (wr *msgpackWriter) WriteInt64(v int64, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } return errors.WithStack(wr.enc.EncodeInt(v)) } func (wr *msgpackWriter) WriteUint8(v uint8, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } return errors.WithStack(wr.enc.EncodeUint(uint64(v))) } func (wr *msgpackWriter) WriteUint16(v uint16, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } return errors.WithStack(wr.enc.EncodeUint(uint64(v))) } func (wr *msgpackWriter) WriteUint32(v uint32, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } return errors.WithStack(wr.enc.EncodeUint(uint64(v))) } func (wr *msgpackWriter) WriteUint64(v uint64, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } return errors.WithStack(wr.enc.EncodeUint(uint64(v))) } func (wr *msgpackWriter) WriteBool(v bool, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } return errors.WithStack(wr.enc.EncodeBool(v)) } func (wr *msgpackWriter) WriteFloat32(v float32, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } return errors.WithStack(wr.enc.EncodeFloat32(v)) } func (wr *msgpackWriter) WriteFloat64(v float64, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } return errors.WithStack(wr.enc.EncodeFloat64(v)) } func (wr *msgpackWriter) WriteString(v string, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } return errors.WithStack(wr.enc.EncodeString(v)) } func (wr *msgpackWriter) WriteBytes(v []byte, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } return errors.WithStack(wr.enc.EncodeBytes(v)) } func (wr *msgpackWriter) BeginContainer(length int, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } - if err := wr.enc.EncodeArrayLen(length); err != nil { - return errors.WithStack(err) - } - wr.containers = append(wr.containers, writeContainer{ - length: length, - assoc: false, - }) - return nil -} - -func (wr *msgpackWriter) BeginAssocContainer(length int, field string) error { - if err := wr.writeFieldName(field); err != nil { - return err - } - if err := wr.enc.EncodeMapLen(length); err != nil { - return errors.WithStack(err) - } - wr.containers = append(wr.containers, writeContainer{ - length: length, - assoc: true, - }) - return nil + return wr.BeginCollection(length, field) } func (wr *msgpackWriter) EndContainer() error { + return wr.EndCollection() +} + +func (wr *msgpackWriter) BeginCollection(length int, field string) error { + if err := wr.enc.EncodeArrayLen(length); err != nil { + return errors.WithStack(err) + } + wr.containers = append(wr.containers, struct{}{}) + return nil +} + +func (wr *msgpackWriter) EndCollection() error { if len(wr.containers) == 0 { return errors.New("there is no open containers") } diff --git a/msgpack_writer_test.go b/msgpack_writer_test.go index 783b085..1aaaccb 100644 --- a/msgpack_writer_test.go +++ b/msgpack_writer_test.go @@ -5,7 +5,7 @@ import ( "encoding/hex" "testing" - "git.bit5.ru/backend/meta/v2" + "git.bit5.ru/backend/meta/v3" "github.com/stretchr/testify/require" ) @@ -29,7 +29,7 @@ func TestMsgpackWriter(t *testing.T) { err := s.Write(wr) require.NoError(t, err) - expected := "84a26631a6626c61626c61a2663281a56669656c6401a2663393020406a266349281a56669656c640a81a56669656c64cd0400" + expected := "94a6626c61626c6191019302040692910a91cd0400" actual := hex.EncodeToString(buf.Bytes()) require.EqualValues(t, expected, actual) }) @@ -56,7 +56,7 @@ func TestMsgpackWriter(t *testing.T) { err := s.Write(wr) require.NoError(t, err) - expected := "85a26631a6626c61626c61a2663281a56669656c6401a2663393020406a266349281a56669656c640a81a56669656c64cd0400a166a6717765727479" + expected := "95a6626c61626c6191019302040692910a91cd0400a6717765727479" actual := hex.EncodeToString(buf.Bytes()) require.EqualValues(t, expected, actual) }) diff --git a/structs_test.go b/structs_test.go index 23aea77..18bcb2b 100644 --- a/structs_test.go +++ b/structs_test.go @@ -1,7 +1,7 @@ package meta_test import ( - "git.bit5.ru/backend/meta/v2" + "git.bit5.ru/backend/meta/v3" ) type TestParent struct { @@ -132,7 +132,7 @@ func (s *TestParent) ReadFields(reader meta.Reader) error { } func (s *TestParent) Write(writer meta.Writer) error { - if err := writer.BeginAssocContainer(4, ""); err != nil { + if err := writer.BeginContainer(4, ""); err != nil { return err } if err := s.WriteFields(writer); err != nil { @@ -147,7 +147,7 @@ func (s *TestParent) WriteFields(writer meta.Writer) error { return err } - if err := writer.BeginAssocContainer(1, "f2"); err != nil { + if err := writer.BeginContainer(1, "f2"); err != nil { return err } if err := s.Field2.WriteFields(writer); err != nil { @@ -157,24 +157,23 @@ func (s *TestParent) WriteFields(writer meta.Writer) error { return err } - if err := writer.BeginContainer(len(s.Field3), "f3"); err != nil { + if err := writer.BeginCollection(len(s.Field3), "f3"); err != nil { return err } for _, v := range s.Field3 { if err := writer.WriteInt8(v, ""); err != nil { return err } - } - if err := writer.EndContainer(); err != nil { + if err := writer.EndCollection(); err != nil { return err } - if err := writer.BeginContainer(len(s.Field4), "f4"); err != nil { + if err := writer.BeginCollection(len(s.Field4), "f4"); err != nil { return err } for _, v := range s.Field4 { - if err := writer.BeginAssocContainer(1, ""); err != nil { + if err := writer.BeginContainer(1, ""); err != nil { return err } if err := v.WriteFields(writer); err != nil { @@ -185,7 +184,7 @@ func (s *TestParent) WriteFields(writer meta.Writer) error { } } - if err := writer.EndContainer(); err != nil { + if err := writer.EndCollection(); err != nil { return err } @@ -244,7 +243,7 @@ func (s *TestChild) ReadFields(reader meta.Reader) error { } func (s *TestChild) Write(writer meta.Writer) error { - if err := writer.BeginAssocContainer(5, ""); err != nil { + if err := writer.BeginContainer(5, ""); err != nil { return err } if err := s.WriteFields(writer); err != nil { @@ -310,7 +309,7 @@ func (s *TestFoo) ReadFields(reader meta.Reader) error { } func (s *TestFoo) Write(writer meta.Writer) error { - if err := writer.BeginAssocContainer(1, ""); err != nil { + if err := writer.BeginContainer(1, ""); err != nil { return err } if err := s.WriteFields(writer); err != nil {