improve writer interface

This commit is contained in:
Pavel Merzlyakov 2023-07-03 22:09:39 +03:00
parent abe543a275
commit fe1e485334
9 changed files with 268 additions and 104 deletions

View File

@ -3,7 +3,7 @@ package meta_test
import ( import (
"testing" "testing"
"git.bit5.ru/backend/meta/v2" "git.bit5.ru/backend/meta/v3"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"

2
go.mod
View File

@ -1,4 +1,4 @@
module git.bit5.ru/backend/meta/v2 module git.bit5.ru/backend/meta/v3
go 1.20 go 1.20

View File

@ -48,8 +48,10 @@ type Writer interface {
WriteString(v string, field string) error WriteString(v string, field string) error
WriteBytes(v []byte, field string) error WriteBytes(v []byte, field string) error
BeginCollection(length int, field string) error
EndCollection() error
BeginContainer(length int, field string) error BeginContainer(length int, field string) error
BeginAssocContainer(length int, field string) error
EndContainer() error EndContainer() error
} }

171
msgpack_assoc_writer.go Normal file
View File

@ -0,0 +1,171 @@
package meta
import (
"io"
"github.com/pkg/errors"
"github.com/vmihailenco/msgpack/v5"
)
type msgpackAssocWriter struct {
enc *msgpack.Encoder
containers []writeContainer
}
type writeContainer struct {
length int
assoc bool
}
func NewMsgpackAssocWriter(w io.Writer) Writer {
return &msgpackAssocWriter{
enc: msgpack.NewEncoder(w),
containers: make([]writeContainer, 0, 1),
}
}
func (wr *msgpackAssocWriter) currentContainer() writeContainer {
if wr == nil || len(wr.containers) == 0 {
return writeContainer{}
}
return wr.containers[len(wr.containers)-1]
}
func (wr *msgpackAssocWriter) writeFieldName(field string) error {
if !wr.currentContainer().assoc {
return nil
}
return errors.WithStack(wr.enc.EncodeString(field))
}
func (wr *msgpackAssocWriter) WriteInt8(v int8, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeInt(int64(v)))
}
func (wr *msgpackAssocWriter) WriteInt16(v int16, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeInt(int64(v)))
}
func (wr *msgpackAssocWriter) WriteInt32(v int32, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeInt(int64(v)))
}
func (wr *msgpackAssocWriter) WriteInt64(v int64, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeInt(v))
}
func (wr *msgpackAssocWriter) WriteUint8(v uint8, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeUint(uint64(v)))
}
func (wr *msgpackAssocWriter) WriteUint16(v uint16, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeUint(uint64(v)))
}
func (wr *msgpackAssocWriter) WriteUint32(v uint32, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeUint(uint64(v)))
}
func (wr *msgpackAssocWriter) WriteUint64(v uint64, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeUint(uint64(v)))
}
func (wr *msgpackAssocWriter) WriteBool(v bool, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeBool(v))
}
func (wr *msgpackAssocWriter) WriteFloat32(v float32, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeFloat32(v))
}
func (wr *msgpackAssocWriter) WriteFloat64(v float64, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeFloat64(v))
}
func (wr *msgpackAssocWriter) WriteString(v string, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeString(v))
}
func (wr *msgpackAssocWriter) WriteBytes(v []byte, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeBytes(v))
}
func (wr *msgpackAssocWriter) BeginContainer(length int, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
if err := wr.enc.EncodeMapLen(length); err != nil {
return errors.WithStack(err)
}
wr.containers = append(wr.containers, writeContainer{
length: length,
assoc: true,
})
return nil
}
func (wr *msgpackAssocWriter) EndContainer() error {
if len(wr.containers) == 0 {
return errors.New("there is no open containers")
}
wr.containers = wr.containers[:len(wr.containers)-1]
return nil
}
func (wr *msgpackAssocWriter) BeginCollection(length int, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
if err := wr.enc.EncodeArrayLen(length); err != nil {
return errors.WithStack(err)
}
wr.containers = append(wr.containers, writeContainer{
length: length,
assoc: false,
})
return nil
}
func (wr *msgpackAssocWriter) EndCollection() error {
return wr.EndContainer()
}

View File

@ -0,0 +1,63 @@
package meta_test
import (
"bytes"
"encoding/hex"
"testing"
"git.bit5.ru/backend/meta/v3"
"github.com/stretchr/testify/require"
)
func TestMsgpackAssocWriter(t *testing.T) {
t.Run("write struct", func(t *testing.T) {
var buf bytes.Buffer
wr := meta.NewMsgpackAssocWriter(&buf)
s := TestParent{
Field1: "blabla",
Field2: TestFoo{
Field: 1,
},
Field3: []int8{2, 4, 6},
Field4: []TestFoo{
{Field: 10},
{Field: 1024},
},
}
err := s.Write(wr)
require.NoError(t, err)
expected := "84a26631a6626c61626c61a2663281a56669656c6401a2663393020406a266349281a56669656c640a81a56669656c64cd0400"
actual := hex.EncodeToString(buf.Bytes())
require.EqualValues(t, expected, actual)
})
t.Run("write child struct", func(t *testing.T) {
var buf bytes.Buffer
wr := meta.NewMsgpackAssocWriter(&buf)
s := TestChild{
Field: "qwerty",
TestParent: TestParent{
Field1: "blabla",
Field2: TestFoo{
Field: 1,
},
Field3: []int8{2, 4, 6},
Field4: []TestFoo{
{Field: 10},
{Field: 1024},
},
},
}
err := s.Write(wr)
require.NoError(t, err)
expected := "85a26631a6626c61626c61a2663281a56669656c6401a2663393020406a266349281a56669656c640a81a56669656c64cd0400a166a6717765727479"
actual := hex.EncodeToString(buf.Bytes())
require.EqualValues(t, expected, actual)
})
}

View File

@ -5,7 +5,7 @@ import (
"encoding/hex" "encoding/hex"
"testing" "testing"
"git.bit5.ru/backend/meta/v2" "git.bit5.ru/backend/meta/v3"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )

View File

@ -9,156 +9,85 @@ import (
type msgpackWriter struct { type msgpackWriter struct {
enc *msgpack.Encoder enc *msgpack.Encoder
containers []writeContainer containers []struct{}
}
type writeContainer struct {
length int
assoc bool
} }
func NewMsgpackWriter(w io.Writer) Writer { func NewMsgpackWriter(w io.Writer) Writer {
return &msgpackWriter{ return &msgpackWriter{
enc: msgpack.NewEncoder(w), enc: msgpack.NewEncoder(w),
containers: make([]writeContainer, 0, 1), containers: make([]struct{}, 0, 1),
} }
} }
func (wr *msgpackWriter) currentContainer() writeContainer {
if wr == nil || len(wr.containers) == 0 {
return writeContainer{}
}
return wr.containers[len(wr.containers)-1]
}
func (wr *msgpackWriter) writeFieldName(field string) error {
if !wr.currentContainer().assoc {
return nil
}
return errors.WithStack(wr.enc.EncodeString(field))
}
func (wr *msgpackWriter) WriteInt8(v int8, field string) error { func (wr *msgpackWriter) WriteInt8(v int8, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeInt(int64(v))) return errors.WithStack(wr.enc.EncodeInt(int64(v)))
} }
func (wr *msgpackWriter) WriteInt16(v int16, field string) error { func (wr *msgpackWriter) WriteInt16(v int16, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeInt(int64(v))) return errors.WithStack(wr.enc.EncodeInt(int64(v)))
} }
func (wr *msgpackWriter) WriteInt32(v int32, field string) error { func (wr *msgpackWriter) WriteInt32(v int32, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeInt(int64(v))) return errors.WithStack(wr.enc.EncodeInt(int64(v)))
} }
func (wr *msgpackWriter) WriteInt64(v int64, field string) error { func (wr *msgpackWriter) WriteInt64(v int64, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeInt(v)) return errors.WithStack(wr.enc.EncodeInt(v))
} }
func (wr *msgpackWriter) WriteUint8(v uint8, field string) error { func (wr *msgpackWriter) WriteUint8(v uint8, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeUint(uint64(v))) return errors.WithStack(wr.enc.EncodeUint(uint64(v)))
} }
func (wr *msgpackWriter) WriteUint16(v uint16, field string) error { func (wr *msgpackWriter) WriteUint16(v uint16, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeUint(uint64(v))) return errors.WithStack(wr.enc.EncodeUint(uint64(v)))
} }
func (wr *msgpackWriter) WriteUint32(v uint32, field string) error { func (wr *msgpackWriter) WriteUint32(v uint32, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeUint(uint64(v))) return errors.WithStack(wr.enc.EncodeUint(uint64(v)))
} }
func (wr *msgpackWriter) WriteUint64(v uint64, field string) error { func (wr *msgpackWriter) WriteUint64(v uint64, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeUint(uint64(v))) return errors.WithStack(wr.enc.EncodeUint(uint64(v)))
} }
func (wr *msgpackWriter) WriteBool(v bool, field string) error { func (wr *msgpackWriter) WriteBool(v bool, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeBool(v)) return errors.WithStack(wr.enc.EncodeBool(v))
} }
func (wr *msgpackWriter) WriteFloat32(v float32, field string) error { func (wr *msgpackWriter) WriteFloat32(v float32, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeFloat32(v)) return errors.WithStack(wr.enc.EncodeFloat32(v))
} }
func (wr *msgpackWriter) WriteFloat64(v float64, field string) error { func (wr *msgpackWriter) WriteFloat64(v float64, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeFloat64(v)) return errors.WithStack(wr.enc.EncodeFloat64(v))
} }
func (wr *msgpackWriter) WriteString(v string, field string) error { func (wr *msgpackWriter) WriteString(v string, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeString(v)) return errors.WithStack(wr.enc.EncodeString(v))
} }
func (wr *msgpackWriter) WriteBytes(v []byte, field string) error { func (wr *msgpackWriter) WriteBytes(v []byte, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
return errors.WithStack(wr.enc.EncodeBytes(v)) return errors.WithStack(wr.enc.EncodeBytes(v))
} }
func (wr *msgpackWriter) BeginContainer(length int, field string) error { func (wr *msgpackWriter) BeginContainer(length int, field string) error {
if err := wr.writeFieldName(field); err != nil { return wr.BeginCollection(length, field)
return err
}
if err := wr.enc.EncodeArrayLen(length); err != nil {
return errors.WithStack(err)
}
wr.containers = append(wr.containers, writeContainer{
length: length,
assoc: false,
})
return nil
}
func (wr *msgpackWriter) BeginAssocContainer(length int, field string) error {
if err := wr.writeFieldName(field); err != nil {
return err
}
if err := wr.enc.EncodeMapLen(length); err != nil {
return errors.WithStack(err)
}
wr.containers = append(wr.containers, writeContainer{
length: length,
assoc: true,
})
return nil
} }
func (wr *msgpackWriter) EndContainer() error { func (wr *msgpackWriter) EndContainer() error {
return wr.EndCollection()
}
func (wr *msgpackWriter) BeginCollection(length int, field string) error {
if err := wr.enc.EncodeArrayLen(length); err != nil {
return errors.WithStack(err)
}
wr.containers = append(wr.containers, struct{}{})
return nil
}
func (wr *msgpackWriter) EndCollection() error {
if len(wr.containers) == 0 { if len(wr.containers) == 0 {
return errors.New("there is no open containers") return errors.New("there is no open containers")
} }

View File

@ -5,7 +5,7 @@ import (
"encoding/hex" "encoding/hex"
"testing" "testing"
"git.bit5.ru/backend/meta/v2" "git.bit5.ru/backend/meta/v3"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -29,7 +29,7 @@ func TestMsgpackWriter(t *testing.T) {
err := s.Write(wr) err := s.Write(wr)
require.NoError(t, err) require.NoError(t, err)
expected := "84a26631a6626c61626c61a2663281a56669656c6401a2663393020406a266349281a56669656c640a81a56669656c64cd0400" expected := "94a6626c61626c6191019302040692910a91cd0400"
actual := hex.EncodeToString(buf.Bytes()) actual := hex.EncodeToString(buf.Bytes())
require.EqualValues(t, expected, actual) require.EqualValues(t, expected, actual)
}) })
@ -56,7 +56,7 @@ func TestMsgpackWriter(t *testing.T) {
err := s.Write(wr) err := s.Write(wr)
require.NoError(t, err) require.NoError(t, err)
expected := "85a26631a6626c61626c61a2663281a56669656c6401a2663393020406a266349281a56669656c640a81a56669656c64cd0400a166a6717765727479" expected := "95a6626c61626c6191019302040692910a91cd0400a6717765727479"
actual := hex.EncodeToString(buf.Bytes()) actual := hex.EncodeToString(buf.Bytes())
require.EqualValues(t, expected, actual) require.EqualValues(t, expected, actual)
}) })

View File

@ -1,7 +1,7 @@
package meta_test package meta_test
import ( import (
"git.bit5.ru/backend/meta/v2" "git.bit5.ru/backend/meta/v3"
) )
type TestParent struct { type TestParent struct {
@ -132,7 +132,7 @@ func (s *TestParent) ReadFields(reader meta.Reader) error {
} }
func (s *TestParent) Write(writer meta.Writer) error { func (s *TestParent) Write(writer meta.Writer) error {
if err := writer.BeginAssocContainer(4, ""); err != nil { if err := writer.BeginContainer(4, ""); err != nil {
return err return err
} }
if err := s.WriteFields(writer); err != nil { if err := s.WriteFields(writer); err != nil {
@ -147,7 +147,7 @@ func (s *TestParent) WriteFields(writer meta.Writer) error {
return err return err
} }
if err := writer.BeginAssocContainer(1, "f2"); err != nil { if err := writer.BeginContainer(1, "f2"); err != nil {
return err return err
} }
if err := s.Field2.WriteFields(writer); err != nil { if err := s.Field2.WriteFields(writer); err != nil {
@ -157,24 +157,23 @@ func (s *TestParent) WriteFields(writer meta.Writer) error {
return err return err
} }
if err := writer.BeginContainer(len(s.Field3), "f3"); err != nil { if err := writer.BeginCollection(len(s.Field3), "f3"); err != nil {
return err return err
} }
for _, v := range s.Field3 { for _, v := range s.Field3 {
if err := writer.WriteInt8(v, ""); err != nil { if err := writer.WriteInt8(v, ""); err != nil {
return err return err
} }
} }
if err := writer.EndContainer(); err != nil { if err := writer.EndCollection(); err != nil {
return err return err
} }
if err := writer.BeginContainer(len(s.Field4), "f4"); err != nil { if err := writer.BeginCollection(len(s.Field4), "f4"); err != nil {
return err return err
} }
for _, v := range s.Field4 { for _, v := range s.Field4 {
if err := writer.BeginAssocContainer(1, ""); err != nil { if err := writer.BeginContainer(1, ""); err != nil {
return err return err
} }
if err := v.WriteFields(writer); err != nil { if err := v.WriteFields(writer); err != nil {
@ -185,7 +184,7 @@ func (s *TestParent) WriteFields(writer meta.Writer) error {
} }
} }
if err := writer.EndContainer(); err != nil { if err := writer.EndCollection(); err != nil {
return err return err
} }
@ -244,7 +243,7 @@ func (s *TestChild) ReadFields(reader meta.Reader) error {
} }
func (s *TestChild) Write(writer meta.Writer) error { func (s *TestChild) Write(writer meta.Writer) error {
if err := writer.BeginAssocContainer(5, ""); err != nil { if err := writer.BeginContainer(5, ""); err != nil {
return err return err
} }
if err := s.WriteFields(writer); err != nil { if err := s.WriteFields(writer); err != nil {
@ -310,7 +309,7 @@ func (s *TestFoo) ReadFields(reader meta.Reader) error {
} }
func (s *TestFoo) Write(writer meta.Writer) error { func (s *TestFoo) Write(writer meta.Writer) error {
if err := writer.BeginAssocContainer(1, ""); err != nil { if err := writer.BeginContainer(1, ""); err != nil {
return err return err
} }
if err := s.WriteFields(writer); err != nil { if err := s.WriteFields(writer); err != nil {