Prechádzať zdrojové kódy

codec: cbor: support SkipUnexpectedTags flag and optimize descriptor validation

      leverage major descriptor (instead of checking descriptor range)
      refactor handling of int, float and time.Time
      support SkipUnexpectedTags (so we skip unexpected tags when decoding as per spec)

Fixes #300
Ugorji Nwoke 6 rokov pred
rodič
commit
b7b8e1c6cc
1 zmenil súbory, kde vykonal 228 pridanie a 129 odobranie
  1. 228 129
      codec/cbor.go

+ 228 - 129
codec/cbor.go

@@ -9,17 +9,19 @@ import (
 	"time"
 )
 
+// major
 const (
 	cborMajorUint byte = iota
 	cborMajorNegInt
 	cborMajorBytes
-	cborMajorText
+	cborMajorString
 	cborMajorArray
 	cborMajorMap
 	cborMajorTag
-	cborMajorOther
+	cborMajorSimpleOrFloat
 )
 
+// simple
 const (
 	cborBdFalse byte = 0xf4 + iota
 	cborBdTrue
@@ -31,6 +33,7 @@ const (
 	cborBdFloat64
 )
 
+// indefinite
 const (
 	cborBdIndefiniteBytes  byte = 0x5f
 	cborBdIndefiniteString byte = 0x7f
@@ -49,6 +52,7 @@ const (
 	CborStreamBreak  byte = 0xff
 )
 
+// base values
 const (
 	cborBaseUint   byte = 0x00
 	cborBaseNegInt byte = 0x20
@@ -60,44 +64,51 @@ const (
 	cborBaseSimple byte = 0xe0
 )
 
+// const (
+// 	cborSelfDesrTag  byte = 0xd9
+// 	cborSelfDesrTag2 byte = 0xd9
+// 	cborSelfDesrTag3 byte = 0xf7
+// )
+
 func cbordesc(bd byte) string {
-	switch bd {
-	case cborBdNil:
-		return "nil"
-	case cborBdFalse:
-		return "false"
-	case cborBdTrue:
-		return "true"
-	case cborBdFloat16, cborBdFloat32, cborBdFloat64:
-		return "float"
-	case cborBdIndefiniteBytes:
-		return "bytes*"
-	case cborBdIndefiniteString:
-		return "string*"
-	case cborBdIndefiniteArray:
-		return "array*"
-	case cborBdIndefiniteMap:
-		return "map*"
-	default:
-		switch {
-		case bd >= cborBaseUint && bd < cborBaseNegInt:
-			return "(u)int"
-		case bd >= cborBaseNegInt && bd < cborBaseBytes:
-			return "int"
-		case bd >= cborBaseBytes && bd < cborBaseString:
-			return "bytes"
-		case bd >= cborBaseString && bd < cborBaseArray:
-			return "string"
-		case bd >= cborBaseArray && bd < cborBaseMap:
-			return "array"
-		case bd >= cborBaseMap && bd < cborBaseTag:
-			return "map"
-		case bd >= cborBaseTag && bd < cborBaseSimple:
-			return "ext"
+	switch bd >> 5 {
+	case cborMajorUint:
+		return "(u)int"
+	case cborMajorNegInt:
+		return "int"
+	case cborMajorBytes:
+		return "bytes"
+	case cborMajorString:
+		return "string"
+	case cborMajorArray:
+		return "array"
+	case cborMajorMap:
+		return "map"
+	case cborMajorTag:
+		return "tag"
+	case cborMajorSimpleOrFloat: // default
+		switch bd {
+		case cborBdNil:
+			return "nil"
+		case cborBdFalse:
+			return "false"
+		case cborBdTrue:
+			return "true"
+		case cborBdFloat16, cborBdFloat32, cborBdFloat64:
+			return "float"
+		case cborBdIndefiniteBytes:
+			return "bytes*"
+		case cborBdIndefiniteString:
+			return "string*"
+		case cborBdIndefiniteArray:
+			return "array*"
+		case cborBdIndefiniteMap:
+			return "map*"
 		default:
-			return "unknown"
+			return "unknown(simple)"
 		}
 	}
+	return "unknown"
 }
 
 // -------------------
@@ -293,17 +304,57 @@ type cborDecDriver struct {
 	br     bool // bytes reader
 	bdRead bool
 	bd     byte
+	st     bool // skip tags
 	noBuiltInTypes
 	// decNoSeparator
 	decDriverNoopContainerReader
 	// _ [3]uint64 // padding
 }
 
+// func (d *cborDecDriver) readNextBdSkipTags() {
+// 	d.bd = d.r.readn1()
+// 	if d.h.SkipUnexpectedTags {
+// 		for d.bd >= cborBaseTag && d.bd < cborBaseSimple {
+// 			d.decUint()
+// 			d.bd = d.r.readn1()
+// 		}
+// 	}
+// 	d.bdRead = true
+// }
+
+// func (d *cborDecDriver) readNextBd() {
+// 	d.bd = d.r.readn1()
+// 	if d.handleCborSelfDesc && d.bd == cborSelfDesrTag {
+// 		if x := d.readn1(); x == cborSelfDesrTag2 {
+// 			if x = d.readn1(); x != cborSelfDesrTag3 {
+// 				d.d.errorf("mishandled self desc: expected 0xd9d9f7, got: 0xd9d9%x", x)
+// 			}
+// 		} else {
+// 			d.unreadn1()
+// 		}
+// 	}
+// 	d.bdRead = true
+// }
+
 func (d *cborDecDriver) readNextBd() {
 	d.bd = d.r.readn1()
 	d.bdRead = true
 }
 
+// skipTags is called to skip any tags in the stream.
+//
+// Since any value can be tagged, then we should call skipTags
+// before any value is decoded.
+//
+// By definition, skipTags should not be called before
+// checking for break, or nil or undefined.
+func (d *cborDecDriver) skipTags() {
+	for d.bd>>5 == cborMajorTag {
+		d.decUint()
+		d.bd = d.r.readn1()
+	}
+}
+
 func (d *cborDecDriver) uncacheRead() {
 	if d.bdRead {
 		d.r.unreadn1()
@@ -317,13 +368,13 @@ func (d *cborDecDriver) ContainerType() (vt valueType) {
 	}
 	if d.bd == cborBdNil {
 		return valueTypeNil
-	} else if d.bd == cborBdIndefiniteBytes || (d.bd >= cborBaseBytes && d.bd < cborBaseString) {
+	} else if d.bd == cborBdIndefiniteBytes || (d.bd>>5 == cborMajorBytes) {
 		return valueTypeBytes
-	} else if d.bd == cborBdIndefiniteString || (d.bd >= cborBaseString && d.bd < cborBaseArray) {
+	} else if d.bd == cborBdIndefiniteString || (d.bd>>5 == cborMajorString) {
 		return valueTypeString
-	} else if d.bd == cborBdIndefiniteArray || (d.bd >= cborBaseArray && d.bd < cborBaseMap) {
+	} else if d.bd == cborBdIndefiniteArray || (d.bd>>5 == cborMajorArray) {
 		return valueTypeArray
-	} else if d.bd == cborBdIndefiniteMap || (d.bd >= cborBaseMap && d.bd < cborBaseTag) {
+	} else if d.bd == cborBdIndefiniteMap || (d.bd>>5 == cborMajorMap) {
 		return valueTypeMap
 	}
 	// else {
@@ -380,14 +431,16 @@ func (d *cborDecDriver) decCheckInteger() (neg bool) {
 	if !d.bdRead {
 		d.readNextBd()
 	}
+	if d.st {
+		d.skipTags()
+	}
 	major := d.bd >> 5
 	if major == cborMajorUint {
 	} else if major == cborMajorNegInt {
 		neg = true
 	} else {
-		d.d.errorf("not an integer - invalid major %v from descriptor %x/%s",
-			major, d.bd, cbordesc(d.bd))
-		return
+		d.d.errorf("invalid integer; got major %v from descriptor %x/%s, expected %v or %v",
+			major, d.bd, cbordesc(d.bd), cborMajorUint, cborMajorNegInt)
 	}
 	return
 }
@@ -395,20 +448,23 @@ func (d *cborDecDriver) decCheckInteger() (neg bool) {
 func (d *cborDecDriver) DecodeInt64() (i int64) {
 	neg := d.decCheckInteger()
 	ui := d.decUint()
+	d.bdRead = false
+	return cborDecInt64(ui, neg)
+}
+
+func cborDecInt64(ui uint64, neg bool) (i int64) {
 	// check if this number can be converted to an int without overflow
 	if neg {
 		i = -(chkOvf.SignedIntV(ui + 1))
 	} else {
 		i = chkOvf.SignedIntV(ui)
 	}
-	d.bdRead = false
 	return
 }
 
 func (d *cborDecDriver) DecodeUint64() (ui uint64) {
 	if d.decCheckInteger() {
-		d.d.errorf("assigning negative signed value to unsigned type")
-		return
+		d.d.errorf("cannot assign negative signed value to unsigned type")
 	}
 	ui = d.decUint()
 	d.bdRead = false
@@ -419,17 +475,25 @@ func (d *cborDecDriver) DecodeFloat64() (f float64) {
 	if !d.bdRead {
 		d.readNextBd()
 	}
-	if bd := d.bd; bd == cborBdFloat16 {
+	if d.st {
+		d.skipTags()
+	}
+	switch d.bd {
+	case cborBdFloat16:
 		f = float64(math.Float32frombits(halfFloatToFloatBits(bigen.Uint16(d.r.readx(2)))))
-	} else if bd == cborBdFloat32 {
+	case cborBdFloat32:
 		f = float64(math.Float32frombits(bigen.Uint32(d.r.readx(4))))
-	} else if bd == cborBdFloat64 {
+	case cborBdFloat64:
 		f = math.Float64frombits(bigen.Uint64(d.r.readx(8)))
-	} else if bd >= cborBaseUint && bd < cborBaseBytes {
-		f = float64(d.DecodeInt64())
-	} else {
-		d.d.errorf("float only valid from float16/32/64 - invalid descriptor %x/%s", bd, cbordesc(bd))
-		return
+	default:
+		major := d.bd >> 5
+		if major == cborMajorUint {
+			f = float64(cborDecInt64(d.decUint(), false))
+		} else if major == cborMajorNegInt {
+			f = float64(cborDecInt64(d.decUint(), true))
+		} else {
+			d.d.errorf("float only valid from float16/32/64 - invalid descriptor %x/%s", d.bd, cbordesc(d.bd))
+		}
 	}
 	d.bdRead = false
 	return
@@ -440,9 +504,12 @@ func (d *cborDecDriver) DecodeBool() (b bool) {
 	if !d.bdRead {
 		d.readNextBd()
 	}
-	if bd := d.bd; bd == cborBdTrue {
+	if d.st {
+		d.skipTags()
+	}
+	if d.bd == cborBdTrue {
 		b = true
-	} else if bd == cborBdFalse {
+	} else if d.bd == cborBdFalse {
 	} else {
 		d.d.errorf("not bool - %s %x/%s", msgBadDesc, d.bd, cbordesc(d.bd))
 		return
@@ -455,10 +522,16 @@ func (d *cborDecDriver) ReadMapStart() (length int) {
 	if !d.bdRead {
 		d.readNextBd()
 	}
+	if d.st {
+		d.skipTags()
+	}
 	d.bdRead = false
 	if d.bd == cborBdIndefiniteMap {
 		return -1
 	}
+	if d.bd>>5 != cborMajorMap {
+		d.d.errorf("error reading map; got major type: %x, expected: %x/%s", d.bd>>5, cborMajorMap, cbordesc(d.bd))
+	}
 	return d.decLen()
 }
 
@@ -466,10 +539,16 @@ func (d *cborDecDriver) ReadArrayStart() (length int) {
 	if !d.bdRead {
 		d.readNextBd()
 	}
+	if d.st {
+		d.skipTags()
+	}
 	d.bdRead = false
 	if d.bd == cborBdIndefiniteArray {
 		return -1
 	}
+	if d.bd>>5 != cborMajorArray {
+		d.d.errorf("error reading array; got major type: %x, expect: %x/%s", d.bd>>5, cborMajorArray, cbordesc(d.bd))
+	}
 	return d.decLen()
 }
 
@@ -480,11 +559,9 @@ func (d *cborDecDriver) decLen() int {
 func (d *cborDecDriver) decAppendIndefiniteBytes(bs []byte) []byte {
 	d.bdRead = false
 	for !d.CheckBreak() {
-		major := d.bd >> 5
-		if major != cborMajorBytes && major != cborMajorText {
-			d.d.errorf("expect bytes/string major type in indefinite string/bytes;"+
-				" got major %v from descriptor %x/%x", major, d.bd, cbordesc(d.bd))
-			return nil
+		if major := d.bd >> 5; major != cborMajorBytes && major != cborMajorString {
+			d.d.errorf("error reading indefinite string/bytes;"+
+				" got major %v, expected %x/%s", major, d.bd, cbordesc(d.bd))
 		}
 		n := d.decLen()
 		oldLen := len(bs)
@@ -508,6 +585,9 @@ func (d *cborDecDriver) DecodeBytes(bs []byte, zerocopy bool) (bsOut []byte) {
 	if !d.bdRead {
 		d.readNextBd()
 	}
+	if d.st {
+		d.skipTags()
+	}
 	if d.bd == cborBdNil || d.bd == cborBdUndefined {
 		d.bdRead = false
 		return nil
@@ -523,7 +603,7 @@ func (d *cborDecDriver) DecodeBytes(bs []byte, zerocopy bool) (bsOut []byte) {
 		return d.decAppendIndefiniteBytes(bs[:0])
 	}
 	// check if an "array" of uint8's (see ContainerType for how to infer if an array)
-	// if d.bd == cborBdIndefiniteArray || (d.bd >= cborBaseArray && d.bd < cborBaseMap) {
+	// if d.bd == cborBdIndefiniteArray || (d.bd >> 5 == cborMajorArray) {
 	// 	bsOut, _ = fastpathTV.DecSliceUint8V(bs, true, d.d)
 	// 	return
 	// }
@@ -542,7 +622,7 @@ func (d *cborDecDriver) DecodeBytes(bs []byte, zerocopy bool) (bsOut []byte) {
 		}
 		return bs
 	}
-	if d.bd >= cborBaseArray && d.bd < cborBaseMap {
+	if d.bd>>5 == cborMajorArray {
 		d.bdRead = false
 		if zerocopy && len(bs) == 0 {
 			bs = d.d.b[:]
@@ -582,15 +662,15 @@ func (d *cborDecDriver) DecodeTime() (t time.Time) {
 		d.bdRead = false
 		return
 	}
+	if d.bd>>5 != cborMajorTag {
+		d.d.errorf("error reading tag; expected major type: %x, got: %x", cborMajorTag, d.bd>>5)
+	}
 	xtag := d.decUint()
 	d.bdRead = false
 	return d.decodeTime(xtag)
 }
 
 func (d *cborDecDriver) decodeTime(xtag uint64) (t time.Time) {
-	if !d.bdRead {
-		d.readNextBd()
-	}
 	switch xtag {
 	case 0:
 		var err error
@@ -598,21 +678,25 @@ func (d *cborDecDriver) decodeTime(xtag uint64) (t time.Time) {
 			d.d.errorv(err)
 		}
 	case 1:
-		// decode an int64 or a float, and infer time.Time from there.
-		// for floats, round to microseconds, as that is what is guaranteed to fit well.
-		switch {
-		case d.bd == cborBdFloat16, d.bd == cborBdFloat32:
-			f1, f2 := math.Modf(d.DecodeFloat64())
-			t = time.Unix(int64(f1), int64(f2*1e9))
-		case d.bd == cborBdFloat64:
-			f1, f2 := math.Modf(d.DecodeFloat64())
-			t = time.Unix(int64(f1), int64(f2*1e9))
-		case d.bd >= cborBaseUint && d.bd < cborBaseNegInt,
-			d.bd >= cborBaseNegInt && d.bd < cborBaseBytes:
-			t = time.Unix(d.DecodeInt64(), 0)
-		default:
-			d.d.errorf("time.Time can only be decoded from a number (or RFC3339 string)")
-		}
+		// if !d.bdRead {
+		// 	d.readNextBd()
+		// }
+		// // decode an int64 or a float, and infer time.Time from there.
+		// // for floats, round to microseconds, as that is what is guaranteed to fit well.
+		// switch {
+		// case d.bd == cborBdFloat16, d.bd == cborBdFloat32:
+		// 	f1, f2 := math.Modf(d.DecodeFloat64())
+		// 	t = time.Unix(int64(f1), int64(f2*1e9))
+		// case d.bd == cborBdFloat64:
+		// 	f1, f2 := math.Modf(d.DecodeFloat64())
+		// 	t = time.Unix(int64(f1), int64(f2*1e9))
+		// case d.bd >= cborBaseUint && d.bd < cborBaseBytes:
+		// 	t = time.Unix(d.DecodeInt64(), 0)
+		// default:
+		// 	d.d.errorf("time.Time can only be decoded from a number (or RFC3339 string)")
+		// }
+		f1, f2 := math.Modf(d.DecodeFloat64())
+		t = time.Unix(int64(f1), int64(f2*1e9))
 	default:
 		d.d.errorf("invalid tag for time.Time - expecting 0 or 1, got 0x%x", xtag)
 	}
@@ -625,6 +709,9 @@ func (d *cborDecDriver) DecodeExt(rv interface{}, xtag uint64, ext Ext) (realxta
 	if !d.bdRead {
 		d.readNextBd()
 	}
+	if d.bd>>5 != cborMajorTag {
+		d.d.errorf("error reading tag; expected major type: %x, got: %x", cborMajorTag, d.bd>>5)
+	}
 	u := d.decUint()
 	d.bdRead = false
 	realxtag = u
@@ -653,71 +740,76 @@ func (d *cborDecDriver) DecodeNaked() {
 	n := d.d.naked()
 	var decodeFurther bool
 
-	switch d.bd {
-	case cborBdNil:
-		n.v = valueTypeNil
-	case cborBdFalse:
-		n.v = valueTypeBool
-		n.b = false
-	case cborBdTrue:
-		n.v = valueTypeBool
-		n.b = true
-	case cborBdFloat16, cborBdFloat32, cborBdFloat64:
-		n.v = valueTypeFloat
-		n.f = d.DecodeFloat64()
-	case cborBdIndefiniteBytes:
+	switch d.bd >> 5 {
+	case cborMajorUint:
+		if d.h.SignedInteger {
+			n.v = valueTypeInt
+			n.i = d.DecodeInt64()
+		} else {
+			n.v = valueTypeUint
+			n.u = d.DecodeUint64()
+		}
+	case cborMajorNegInt:
+		n.v = valueTypeInt
+		n.i = d.DecodeInt64()
+	case cborMajorBytes:
 		decNakedReadRawBytes(d, d.d, n, d.h.RawToString)
-	case cborBdIndefiniteString:
+	case cborMajorString:
 		n.v = valueTypeString
 		n.s = d.DecodeString()
-	case cborBdIndefiniteArray:
+	case cborMajorArray:
 		n.v = valueTypeArray
 		decodeFurther = true
-	case cborBdIndefiniteMap:
+	case cborMajorMap:
 		n.v = valueTypeMap
 		decodeFurther = true
-	default:
-		switch {
-		case d.bd >= cborBaseUint && d.bd < cborBaseNegInt:
-			if d.h.SignedInteger {
-				n.v = valueTypeInt
-				n.i = d.DecodeInt64()
-			} else {
-				n.v = valueTypeUint
-				n.u = d.DecodeUint64()
-			}
-		case d.bd >= cborBaseNegInt && d.bd < cborBaseBytes:
-			n.v = valueTypeInt
-			n.i = d.DecodeInt64()
-		case d.bd >= cborBaseBytes && d.bd < cborBaseString:
+	case cborMajorTag:
+		n.v = valueTypeExt
+		n.u = d.decUint()
+		n.l = nil
+		if n.u == 0 || n.u == 1 {
+			d.bdRead = false
+			n.v = valueTypeTime
+			n.t = d.decodeTime(n.u)
+		} else if d.st && d.h.getExtForTag(n.u) == nil {
+			d.skipTags()
+			d.DecodeNaked()
+			goto FINISH
+		}
+		// d.bdRead = false
+		// d.d.decode(&re.Value) // handled by decode itself.
+		// decodeFurther = true
+	case cborMajorSimpleOrFloat:
+		switch d.bd {
+		case cborBdNil:
+			n.v = valueTypeNil
+		case cborBdFalse:
+			n.v = valueTypeBool
+			n.b = false
+		case cborBdTrue:
+			n.v = valueTypeBool
+			n.b = true
+		case cborBdFloat16, cborBdFloat32, cborBdFloat64:
+			n.v = valueTypeFloat
+			n.f = d.DecodeFloat64()
+		case cborBdIndefiniteBytes:
 			decNakedReadRawBytes(d, d.d, n, d.h.RawToString)
-		case d.bd >= cborBaseString && d.bd < cborBaseArray:
+		case cborBdIndefiniteString:
 			n.v = valueTypeString
 			n.s = d.DecodeString()
-		case d.bd >= cborBaseArray && d.bd < cborBaseMap:
+		case cborBdIndefiniteArray:
 			n.v = valueTypeArray
 			decodeFurther = true
-		case d.bd >= cborBaseMap && d.bd < cborBaseTag:
+		case cborBdIndefiniteMap:
 			n.v = valueTypeMap
 			decodeFurther = true
-		case d.bd >= cborBaseTag && d.bd < cborBaseSimple:
-			n.v = valueTypeExt
-			n.u = d.decUint()
-			n.l = nil
-			if n.u == 0 || n.u == 1 {
-				d.bdRead = false
-				n.v = valueTypeTime
-				n.t = d.decodeTime(n.u)
-			}
-			// d.bdRead = false
-			// d.d.decode(&re.Value) // handled by decode itself.
-			// decodeFurther = true
 		default:
 			d.d.errorf("decodeNaked: Unrecognized d.bd: 0x%x", d.bd)
-			return
 		}
+	default: // should never happen
+		d.d.errorf("decodeNaked: Unrecognized d.bd: 0x%x", d.bd)
 	}
-
+FINISH:
 	if !decodeFurther {
 		d.bdRead = false
 	}
@@ -752,6 +844,12 @@ type CborHandle struct {
 	// If unset, we encode time.Time using seconds past epoch.
 	TimeRFC3339 bool
 
+	// SkipUnexpectedTags says to skip over any tags for which extensions are
+	// not defined. This is in keeping with the cbor spec on "Optional Tagging of Items".
+	//
+	// Furthermore, this allows the skipping over of the Self Describing Tag 0xd9d9f7.
+	SkipUnexpectedTags bool
+
 	_ [1]uint64 // padding (cache-aligned)
 }
 
@@ -768,7 +866,7 @@ func (h *CborHandle) newEncDriver(e *Encoder) encDriver {
 }
 
 func (h *CborHandle) newDecDriver(d *Decoder) decDriver {
-	return &cborDecDriver{d: d, h: h, r: d.r(), br: d.bytes}
+	return &cborDecDriver{d: d, h: h, r: d.r(), br: d.bytes, st: h.SkipUnexpectedTags}
 }
 
 func (e *cborEncDriver) reset() {
@@ -778,6 +876,7 @@ func (e *cborEncDriver) reset() {
 func (d *cborDecDriver) reset() {
 	d.r, d.br = d.d.r(), d.d.bytes
 	d.bd, d.bdRead = 0, false
+	d.st = d.h.SkipUnexpectedTags
 }
 
 var _ decDriver = (*cborDecDriver)(nil)