340 lines
6.2 KiB
Go
340 lines
6.2 KiB
Go
package msgpack
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
|
|
"github.com/vmihailenco/msgpack/codes"
|
|
)
|
|
|
|
// mapElemsAllocLimit is the largest capacity hint passed to make when a map
// is allocated from a decoded length; larger maps simply grow as entries are
// inserted.
const mapElemsAllocLimit = 1e4

// Cached reflect types for the map types that have specialised decoders.
var (
	mapStringStringPtrType = reflect.TypeOf(new(map[string]string))
	mapStringStringType    = mapStringStringPtrType.Elem()

	mapStringInterfacePtrType = reflect.TypeOf(new(map[string]interface{}))
	mapStringInterfaceType    = mapStringInterfacePtrType.Elem()
)

// errInvalidCode is the internal sentinel returned when a code does not
// start the expected kind of object; callers expand it into a descriptive error.
var errInvalidCode = errors.New("invalid code")
|
|
|
|
func decodeMapValue(d *Decoder, v reflect.Value) error {
|
|
size, err := d.DecodeMapLen()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
typ := v.Type()
|
|
if size == -1 {
|
|
v.Set(reflect.Zero(typ))
|
|
return nil
|
|
}
|
|
|
|
if v.IsNil() {
|
|
v.Set(reflect.MakeMap(typ))
|
|
}
|
|
if size == 0 {
|
|
return nil
|
|
}
|
|
|
|
return decodeMapValueSize(d, v, size)
|
|
}
|
|
|
|
func decodeMapValueSize(d *Decoder, v reflect.Value, size int) error {
|
|
typ := v.Type()
|
|
keyType := typ.Key()
|
|
valueType := typ.Elem()
|
|
|
|
for i := 0; i < size; i++ {
|
|
mk := reflect.New(keyType).Elem()
|
|
if err := d.DecodeValue(mk); err != nil {
|
|
return err
|
|
}
|
|
|
|
mv := reflect.New(valueType).Elem()
|
|
if err := d.DecodeValue(mv); err != nil {
|
|
return err
|
|
}
|
|
|
|
v.SetMapIndex(mk, mv)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DecodeMapLen decodes map length. Length is -1 when map is nil.
|
|
func (d *Decoder) DecodeMapLen() (int, error) {
|
|
c, err := d.readCode()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if codes.IsExt(c) {
|
|
if err = d.skipExtHeader(c); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
c, err = d.readCode()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
return d.mapLen(c)
|
|
}
|
|
|
|
func (d *Decoder) mapLen(c codes.Code) (int, error) {
|
|
size, err := d._mapLen(c)
|
|
err = expandInvalidCodeMapLenError(c, err)
|
|
return size, err
|
|
}
|
|
|
|
func (d *Decoder) _mapLen(c codes.Code) (int, error) {
|
|
if c == codes.Nil {
|
|
return -1, nil
|
|
}
|
|
if c >= codes.FixedMapLow && c <= codes.FixedMapHigh {
|
|
return int(c & codes.FixedMapMask), nil
|
|
}
|
|
if c == codes.Map16 {
|
|
size, err := d.uint16()
|
|
return int(size), err
|
|
}
|
|
if c == codes.Map32 {
|
|
size, err := d.uint32()
|
|
return int(size), err
|
|
}
|
|
return 0, errInvalidCode
|
|
}
|
|
|
|
func expandInvalidCodeMapLenError(c codes.Code, err error) error {
|
|
if err == errInvalidCode {
|
|
return fmt.Errorf("msgpack: invalid code=%x decoding map length", c)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func decodeMapStringStringValue(d *Decoder, v reflect.Value) error {
|
|
mptr := v.Addr().Convert(mapStringStringPtrType).Interface().(*map[string]string)
|
|
return d.decodeMapStringStringPtr(mptr)
|
|
}
|
|
|
|
func (d *Decoder) decodeMapStringStringPtr(ptr *map[string]string) error {
|
|
size, err := d.DecodeMapLen()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if size == -1 {
|
|
*ptr = nil
|
|
return nil
|
|
}
|
|
|
|
m := *ptr
|
|
if m == nil {
|
|
*ptr = make(map[string]string, min(size, mapElemsAllocLimit))
|
|
m = *ptr
|
|
}
|
|
|
|
for i := 0; i < size; i++ {
|
|
mk, err := d.DecodeString()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mv, err := d.DecodeString()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
m[mk] = mv
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func decodeMapStringInterfaceValue(d *Decoder, v reflect.Value) error {
|
|
ptr := v.Addr().Convert(mapStringInterfacePtrType).Interface().(*map[string]interface{})
|
|
return d.decodeMapStringInterfacePtr(ptr)
|
|
}
|
|
|
|
func (d *Decoder) decodeMapStringInterfacePtr(ptr *map[string]interface{}) error {
|
|
n, err := d.DecodeMapLen()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if n == -1 {
|
|
*ptr = nil
|
|
return nil
|
|
}
|
|
|
|
m := *ptr
|
|
if m == nil {
|
|
*ptr = make(map[string]interface{}, min(n, mapElemsAllocLimit))
|
|
m = *ptr
|
|
}
|
|
|
|
for i := 0; i < n; i++ {
|
|
mk, err := d.DecodeString()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mv, err := d.decodeInterfaceCond()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
m[mk] = mv
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *Decoder) DecodeMap() (interface{}, error) {
|
|
if d.decodeMapFunc != nil {
|
|
return d.decodeMapFunc(d)
|
|
}
|
|
|
|
size, err := d.DecodeMapLen()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if size == -1 {
|
|
return nil, nil
|
|
}
|
|
if size == 0 {
|
|
return make(map[string]interface{}), nil
|
|
}
|
|
|
|
code, err := d.PeekCode()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if codes.IsString(code) || codes.IsBin(code) {
|
|
return d.decodeMapStringInterfaceSize(size)
|
|
}
|
|
|
|
key, err := d.decodeInterfaceCond()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
value, err := d.decodeInterfaceCond()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
keyType := reflect.TypeOf(key)
|
|
valueType := reflect.TypeOf(value)
|
|
|
|
mapType := reflect.MapOf(keyType, valueType)
|
|
mapValue := reflect.MakeMap(mapType)
|
|
|
|
mapValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(value))
|
|
size--
|
|
|
|
err = decodeMapValueSize(d, mapValue, size)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return mapValue.Interface(), nil
|
|
}
|
|
|
|
func (d *Decoder) decodeMapStringInterfaceSize(size int) (map[string]interface{}, error) {
|
|
m := make(map[string]interface{}, min(size, mapElemsAllocLimit))
|
|
for i := 0; i < size; i++ {
|
|
mk, err := d.DecodeString()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
mv, err := d.decodeInterfaceCond()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
m[mk] = mv
|
|
}
|
|
return m, nil
|
|
}
|
|
|
|
func (d *Decoder) skipMap(c codes.Code) error {
|
|
n, err := d.mapLen(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for i := 0; i < n; i++ {
|
|
if err := d.Skip(); err != nil {
|
|
return err
|
|
}
|
|
if err := d.Skip(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func decodeStructValue(d *Decoder, v reflect.Value) error {
|
|
c, err := d.readCode()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var isArray bool
|
|
|
|
n, err := d._mapLen(c)
|
|
if err != nil {
|
|
var err2 error
|
|
n, err2 = d.arrayLen(c)
|
|
if err2 != nil {
|
|
return expandInvalidCodeMapLenError(c, err)
|
|
}
|
|
isArray = true
|
|
}
|
|
if n == -1 {
|
|
if err = mustSet(v); err != nil {
|
|
return err
|
|
}
|
|
v.Set(reflect.Zero(v.Type()))
|
|
return nil
|
|
}
|
|
|
|
var fields *fields
|
|
if d.useJSONTag {
|
|
fields = jsonStructs.Fields(v.Type())
|
|
} else {
|
|
fields = structs.Fields(v.Type())
|
|
}
|
|
|
|
if isArray {
|
|
for i, f := range fields.List {
|
|
if i >= n {
|
|
break
|
|
}
|
|
if err := f.DecodeValue(d, v); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
// Skip extra values.
|
|
for i := len(fields.List); i < n; i++ {
|
|
if err := d.Skip(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
for i := 0; i < n; i++ {
|
|
name, err := d.DecodeString()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if f := fields.Table[name]; f != nil {
|
|
if err := f.DecodeValue(d, v); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if err := d.Skip(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|