Browse Source

codec: Support decoding into a struct from an encoded array.

We now keep track of the current encoded type, and when decoding into a struct,
we decode by field name if the encoded value is a map, and by field position if it is an array.

Also, decodeNaked has been simplified so that the handle (decodeHandleI) is no longer passed in
as a parameter. Instead, each decode driver keeps a reference to the *Handle that created it.

Fixes #13.
Ugorji Nwoke 12 years ago
parent
commit
366b4c7347
3 changed files with 180 additions and 68 deletions
  1. 62 15
      codec/binc.go
  2. 65 32
      codec/decode.go
  3. 53 21
      codec/msgpack.go

+ 62 - 15
codec/binc.go

@@ -82,7 +82,9 @@ type bincEncDriver struct {
 
 type bincDecDriver struct {
 	r      decReader
+	h      *BincHandle
 	bdRead bool
+	bdType decodeEncodedType
 	bd     byte
 	vd     byte
 	vs     byte
@@ -94,8 +96,8 @@ func (_ *BincHandle) newEncDriver(w encWriter) encDriver {
 	return &bincEncDriver{w: w}
 }
 
-func (_ *BincHandle) newDecDriver(r decReader) decDriver {
-	return &bincDecDriver{r: r}
+func (h *BincHandle) newDecDriver(r decReader) decDriver {
+	return &bincDecDriver{r: r, h: h}
 }
 
 func (_ *BincHandle) writeExt() bool {
@@ -345,9 +347,53 @@ func (d *bincDecDriver) initReadNext() {
 	d.vd = d.bd >> 4
 	d.vs = d.bd & 0x0f
 	d.bdRead = true
+	d.bdType = detUnset
+}
+
+func (d *bincDecDriver) currentEncodedType() decodeEncodedType {
+	if d.bdType == detUnset {
+		switch d.vd {
+		case bincVdSpecial:
+			switch d.vs {
+			case bincSpNil:
+				d.bdType = detNil
+			case bincSpFalse, bincSpTrue:
+				d.bdType = detBool
+			case bincSpNan, bincSpNegInf, bincSpPosInf, bincSpZeroFloat:
+				d.bdType = detFloat
+			case bincSpZero, bincSpNegOne:
+				d.bdType = detInt
+			default:
+				decErr("currentEncodedType: Unrecognized special value 0x%x", d.vs)
+			}
+		case bincVdSmallInt:
+			d.bdType = detInt
+		case bincVdUint:
+			d.bdType = detUint
+		case bincVdInt:
+			d.bdType = detInt
+		case bincVdFloat:
+			d.bdType = detFloat
+		case bincVdSymbol, bincVdString:
+			d.bdType = detString
+		case bincVdByteArray:
+			d.bdType = detBytes
+		case bincVdTimestamp:
+			d.bdType = detTimestamp
+		case bincVdCustomExt:
+			d.bdType = detExt
+		case bincVdArray:
+			d.bdType = detArray
+		case bincVdMap:
+			d.bdType = detMap
+		default:
+			decErr("currentEncodedType: Unrecognized d.vd: 0x%x", d.vd)
+		}		
+	}
+	return d.bdType
 }
 
-func (d *bincDecDriver) currentIsNil() bool {
+func (d *bincDecDriver) tryDecodeAsNil() bool {
 	if d.bd == bincVdSpecial<<4|bincSpNil {
 		d.bdRead = false
 		return true
@@ -699,7 +745,7 @@ func (d *bincDecDriver) decodeExt(tag byte) (xbs []byte) {
 	return
 }
 
-func (d *bincDecDriver) decodeNaked(h decodeHandleI) (rv reflect.Value, ctx decodeNakedContext) {
+func (d *bincDecDriver) decodeNaked() (rv reflect.Value, ctx decodeNakedContext) {
 	d.initReadNext()
 	var v interface{}
 
@@ -726,7 +772,7 @@ func (d *bincDecDriver) decodeNaked(h decodeHandleI) (rv reflect.Value, ctx deco
 		case bincSpNegOne:
 			v = int8(-1)
 		default:
-			decErr("Unrecognized special value 0x%x", d.vs)
+			decErr("decodeNaked: Unrecognized special value 0x%x", d.vs)
 		}
 	case bincVdSmallInt:
 		v = int8(d.vs) + 1
@@ -752,33 +798,30 @@ func (d *bincDecDriver) decodeNaked(h decodeHandleI) (rv reflect.Value, ctx deco
 		//ctx = dncExt
 		l := d.decLen()
 		xtag := d.r.readn1()
-		opts := h.(*BincHandle)
 		var bfn func(reflect.Value, []byte) error
-		rv, bfn = opts.getDecodeExtForTag(xtag)
+		rv, bfn = d.h.getDecodeExtForTag(xtag)
 		if bfn == nil {
-			decErr("Unable to find type mapped to extension tag: %v", xtag)
+			decErr("decodeNaked: Unable to find type mapped to extension tag: %v", xtag)
 		}
 		if fnerr := bfn(rv, d.r.readn(l)); fnerr != nil {
 			panic(fnerr)
 		}
 	case bincVdArray:
 		ctx = dncContainer
-		opts := h.(*BincHandle)
-		if opts.SliceType == nil {
+		if d.h.SliceType == nil {
 			rv = reflect.New(intfSliceTyp).Elem()
 		} else {
-			rv = reflect.New(opts.SliceType).Elem()
+			rv = reflect.New(d.h.SliceType).Elem()
 		}
 	case bincVdMap:
 		ctx = dncContainer
-		opts := h.(*BincHandle)
-		if opts.MapType == nil {
+		if d.h.MapType == nil {
 			rv = reflect.MakeMap(mapIntfIntfTyp)
 		} else {
-			rv = reflect.MakeMap(opts.MapType)
+			rv = reflect.MakeMap(d.h.MapType)
 		}
 	default:
-		decErr("Unrecognized d.vd: 0x%x", d.vd)
+		decErr("decodeNaked: Unrecognized d.vd: 0x%x", d.vd)
 	}
 
 	if ctx == dncHandled {
@@ -789,3 +832,7 @@ func (d *bincDecDriver) decodeNaked(h decodeHandleI) (rv reflect.Value, ctx deco
 	}
 	return
 }
+
+
+
+

+ 65 - 32
codec/decode.go

@@ -14,6 +14,8 @@ var (
 	msgBadDesc = "Unrecognized descriptor byte"
 )
 
+// when decoding without schema, the nakedContext tells us what 
+// we decoded into, or if decoding has been handled.
 type decodeNakedContext uint8
 
 const (
@@ -23,6 +25,24 @@ const (
 	dncContainer
 )
 
+// decodeEncodedType is the current type in the encoded stream
+type decodeEncodedType uint8
+
+const (
+	detUnset decodeEncodedType = iota
+	detNil
+	detInt
+	detUint
+	detFloat
+	detBool
+	detString
+	detBytes
+	detMap
+	detArray
+	detTimestamp
+	detExt
+)
+	
 // decReader abstracts the reading source, allowing implementations that can
 // read from an io.Reader or directly off a byte slice with zero-copying.
 type decReader interface {
@@ -36,11 +56,12 @@ type decReader interface {
 
 type decDriver interface {
 	initReadNext()
-	currentIsNil() bool
+	tryDecodeAsNil() bool
+	currentEncodedType() decodeEncodedType
 	decodeBuiltinType(rt uintptr, rv reflect.Value) bool
 	//decodeNaked should completely handle extensions, builtins, primitives, etc.
 	//Numbers are decoded as int64, uint64, float64 only (no smaller sized number types).
-	decodeNaked(h decodeHandleI) (rv reflect.Value, ctx decodeNakedContext)
+	decodeNaked() (rv reflect.Value, ctx decodeNakedContext)
 	decodeInt(bitsize uint8) (i int64)
 	decodeUint(bitsize uint8) (ui uint64)
 	decodeFloat(chkOverflow32 bool) (f float64)
@@ -229,7 +250,8 @@ func NewDecoderBytes(in []byte, h Handle) *Decoder {
 //     we reset the destination map/slice to be a zero-length non-nil map/slice.
 //   - Also, if the encoded value is Nil in the stream, then we try to set
 //     the container to its "zero" value (e.g. nil for slice/map).
-// 
+//   - Note that a struct can be decoded from an array in the stream,
+//     by updating fields as they occur in the struct.
 func (d *Decoder) Decode(v interface{}) (err error) {
 	defer panicToErr(&err)
 	d.decode(v)
@@ -303,7 +325,7 @@ func (d *Decoder) decodeValue(rv reflect.Value) {
 	//if nil interface, use some hieristics to set the nil interface to an
 	//appropriate value based on the first byte read (byte descriptor bd)
 	if wasNilIntf {
-		if dd.currentIsNil() {
+		if dd.tryDecodeAsNil() {
 			return
 		}
 		//Prevent from decoding into e.g. error, io.Reader, etc if it's nil and non-nil value in stream.
@@ -311,14 +333,14 @@ func (d *Decoder) decodeValue(rv reflect.Value) {
 		if num := rt.NumMethod(); num > 0 {
 			decErr("decodeValue: Cannot decode non-nil codec value into nil %v (%v methods)", rt, num)
 		} else {
-			rv, ndesc = dd.decodeNaked(d.h)
+			rv, ndesc = dd.decodeNaked()
 			if ndesc == dncHandled {
 				rvOrig.Set(rv)
 				return
 			}
 			rt = rv.Type()
 		}
-	} else if dd.currentIsNil() {
+	} else if dd.tryDecodeAsNil() {
 		// Note: if stream is set to nil, we set the dereferenced value to its "zero" value (if settable).
 		for rv.Kind() == reflect.Ptr {
 			rv = rv.Elem()
@@ -405,36 +427,47 @@ func (d *Decoder) decodeValue(rv reflect.Value) {
 	case reflect.Interface:
 		d.decodeValue(rv.Elem())
 	case reflect.Struct:
-		containerLen := dd.readMapLen()
-
-		if containerLen == 0 {
-			break
-		}
-
-		sfi := getStructFieldInfos(rtid, rt)
-		for j := 0; j < containerLen; j++ {
-			// var rvkencname string
-			// ddecode(&rvkencname)
-			dd.initReadNext()
-			rvkencname := dd.decodeString()
-			// rvksi := sfi.getForEncName(rvkencname)
-			if k := sfi.indexForEncName(rvkencname); k > -1 {
-				sfik := sfi[k]
-				if sfik.i > -1 {
-					d.decodeValue(rv.Field(int(sfik.i)))
-				} else {
-					d.decodeValue(rv.FieldByIndex(sfik.is))
-				}
-				// d.decodeValue(sfi.field(k, rv))
-			} else {
-				if d.h.errorIfNoField() {
-					decErr("No matching struct field found when decoding stream map with key: %v", rvkencname)
+		if currEncodedType := dd.currentEncodedType(); currEncodedType == detMap {
+			containerLen := dd.readMapLen()
+			if containerLen == 0 {
+				break
+			}
+			sfi := getStructFieldInfos(rtid, rt)
+			for j := 0; j < containerLen; j++ {
+				// var rvkencname string
+				// ddecode(&rvkencname)
+				dd.initReadNext()
+				rvkencname := dd.decodeString()
+				// rvksi := sfi.getForEncName(rvkencname)
+				if k := sfi.indexForEncName(rvkencname); k > -1 {
+					sfik := sfi[k]
+					if sfik.i > -1 {
+						d.decodeValue(rv.Field(int(sfik.i)))
+					} else {
+						d.decodeValue(rv.FieldByIndex(sfik.is))
+					}
+					// d.decodeValue(sfi.field(k, rv))
 				} else {
-					var nilintf0 interface{}
-					d.decodeValue(reflect.ValueOf(&nilintf0).Elem())
+					if d.h.errorIfNoField() {
+						decErr("No matching struct field found when decoding stream map with key: %v", rvkencname)
+					} else {
+						var nilintf0 interface{}
+						d.decodeValue(reflect.ValueOf(&nilintf0).Elem())
+					}
 				}
 			}
+		} else if currEncodedType == detArray {
+			containerLen := dd.readMapLen()
+			if containerLen == 0 {
+				break
+			}
+			for j := 0; j < containerLen; j++ {
+				d.decodeValue(rv.Field(j))
+			}
+		} else {
+			decErr("Only encoded map or array can be decoded into a struct")
 		}
+		
 	case reflect.Slice:
 		// Be more careful calling Set() here, because a reflect.Value from an array
 		// may have come in here (which may not be settable).

+ 53 - 21
codec/msgpack.go

@@ -123,6 +123,7 @@ type msgpackDecDriver struct {
 	h      *MsgpackHandle
 	bd     byte
 	bdRead bool
+	bdType decodeEncodedType
 }
 
 func (e *msgpackEncDriver) encodeBuiltinType(rt uintptr, rv reflect.Value) bool {
@@ -277,7 +278,7 @@ func (d *msgpackDecDriver) decodeBuiltinType(rt uintptr, rv reflect.Value) bool
 // It is called when a nil interface{} is passed, leaving it up to the DecDriver
 // to introspect the stream and decide how best to decode.
 // It deciphers the value by looking at the stream first.
-func (d *msgpackDecDriver) decodeNaked(h decodeHandleI) (rv reflect.Value, ctx decodeNakedContext) {
+func (d *msgpackDecDriver) decodeNaked() (rv reflect.Value, ctx decodeNakedContext) {
 	d.initReadNext()
 	bd := d.bd
 
@@ -326,8 +327,7 @@ func (d *msgpackDecDriver) decodeNaked(h decodeHandleI) (rv reflect.Value, ctx d
 		case bd == mpStr8, bd == mpStr16, bd == mpStr32, bd >= mpFixStrMin && bd <= mpFixStrMax:
 			ctx = dncContainer
 			// v = containerRaw
-			opts := h.(*MsgpackHandle)
-			if opts.rawToStringOverride || opts.RawToString {
+			if d.h.rawToStringOverride || d.h.RawToString {
 				var rvm string
 				rv = reflect.ValueOf(&rvm).Elem()
 			} else {
@@ -339,28 +339,25 @@ func (d *msgpackDecDriver) decodeNaked(h decodeHandleI) (rv reflect.Value, ctx d
 		case bd == mpArray16, bd == mpArray32, bd >= mpFixArrayMin && bd <= mpFixArrayMax:
 			ctx = dncContainer
 			// v = containerList
-			opts := h.(*MsgpackHandle)
-			if opts.SliceType == nil {
+			if d.h.SliceType == nil {
 				rv = reflect.New(intfSliceTyp).Elem()
 			} else {
-				rv = reflect.New(opts.SliceType).Elem()
+				rv = reflect.New(d.h.SliceType).Elem()
 			}
 		case bd == mpMap16, bd == mpMap32, bd >= mpFixMapMin && bd <= mpFixMapMax:
 			ctx = dncContainer
 			// v = containerMap
-			opts := h.(*MsgpackHandle)
-			if opts.MapType == nil {
+			if d.h.MapType == nil {
 				rv = reflect.MakeMap(mapIntfIntfTyp)
 			} else {
-				rv = reflect.MakeMap(opts.MapType)
+				rv = reflect.MakeMap(d.h.MapType)
 			}
 		case bd >= mpFixExt1 && bd <= mpFixExt16, bd >= mpExt8 && bd <= mpExt32:
 			//ctx = dncExt
 			clen := d.readExtLen()
 			xtag := d.r.readn1()
-			opts := h.(*MsgpackHandle)
 			var bfn func(reflect.Value, []byte) error
-			rv, bfn = opts.getDecodeExtForTag(xtag)
+			rv, bfn = d.h.getDecodeExtForTag(xtag)
 			if bfn == nil {
 				decErr("Unable to find type mapped to extension tag: %v", xtag)
 			}
@@ -560,9 +557,53 @@ func (d *msgpackDecDriver) initReadNext() {
 	}
 	d.bd = d.r.readn1()
 	d.bdRead = true
+	d.bdType = detUnset
 }
 
-func (d *msgpackDecDriver) currentIsNil() bool {
+func (d *msgpackDecDriver) currentEncodedType() decodeEncodedType {
+	if d.bdType == detUnset {
+	bd := d.bd
+	switch bd {
+	case mpNil:
+		d.bdType = detNil
+	case mpFalse, mpTrue:
+		d.bdType = detBool
+	case mpFloat, mpDouble:
+		d.bdType = detFloat
+	case mpUint8, mpUint16, mpUint32, mpUint64:
+		d.bdType = detUint
+	case mpInt8, mpInt16, mpInt32, mpInt64:
+		d.bdType = detInt
+	default:
+		switch {
+		case bd >= mpPosFixNumMin && bd <= mpPosFixNumMax:
+			d.bdType = detInt
+		case bd >= mpNegFixNumMin && bd <= mpNegFixNumMax:
+			d.bdType = detInt
+		case bd == mpStr8, bd == mpStr16, bd == mpStr32, bd >= mpFixStrMin && bd <= mpFixStrMax:
+			if d.h.rawToStringOverride || d.h.RawToString {
+				d.bdType = detString
+			} else {
+				d.bdType = detBytes
+			}
+		case bd == mpBin8, bd == mpBin16, bd == mpBin32:
+			d.bdType = detBytes
+		case bd == mpArray16, bd == mpArray32, bd >= mpFixArrayMin && bd <= mpFixArrayMax:
+			d.bdType = detArray
+		case bd == mpMap16, bd == mpMap32, bd >= mpFixMapMin && bd <= mpFixMapMax:
+			d.bdType = detMap
+		case bd >= mpFixExt1 && bd <= mpFixExt16, bd >= mpExt8 && bd <= mpExt32:
+			d.bdType = detExt
+		default:
+			decErr("currentEncodedType: Undeciphered descriptor: %s: hex: %x, dec: %d", msgBadDesc, bd, bd)
+		}
+	}
+	}
+	return d.bdType
+}
+
+
+func (d *msgpackDecDriver) tryDecodeAsNil() bool {
 	if d.bd == mpNil {
 		d.bdRead = false
 		return true
@@ -768,12 +809,3 @@ func (h *MsgpackHandle) writeExt() bool {
 	return h.WriteExt
 }
 
-
-
-
-
-
-
-
-
-