Browse Source

codec: honor MaxInitLen when decoding strings or []byte.

This will prevent "out-of-memory" errors when decoding very large
strings or corrupt streams with a false length that is very large.

Fixes #187
Ugorji Nwoke 9 years ago
parent
commit
d23841a297
8 changed files with 36 additions and 15 deletions
  1. 5 5
      codec/binc.go
  2. 1 1
      codec/cbor.go
  3. 2 0
      codec/codec_test.go
  4. 18 5
      codec/decode.go
  5. 1 0
      codec/helper_test.go
  6. 1 1
      codec/msgpack.go
  7. 1 1
      codec/simple.go
  8. 7 2
      codec/tests.sh

+ 5 - 5
codec/binc.go

@@ -639,12 +639,12 @@ func (d *bincDecDriver) decStringAndBytes(bs []byte, withString, zerocopy bool)
 			if d.br {
 			if d.br {
 				bs2 = d.r.readx(slen)
 				bs2 = d.r.readx(slen)
 			} else if len(bs) == 0 {
 			} else if len(bs) == 0 {
-				bs2 = decByteSlice(d.r, slen, d.b[:])
+				bs2 = decByteSlice(d.r, slen, d.d.h.MaxInitLen, d.b[:])
 			} else {
 			} else {
-				bs2 = decByteSlice(d.r, slen, bs)
+				bs2 = decByteSlice(d.r, slen, d.d.h.MaxInitLen, bs)
 			}
 			}
 		} else {
 		} else {
-			bs2 = decByteSlice(d.r, slen, bs)
+			bs2 = decByteSlice(d.r, slen, d.d.h.MaxInitLen, bs)
 		}
 		}
 		if withString {
 		if withString {
 			s = string(bs2)
 			s = string(bs2)
@@ -696,7 +696,7 @@ func (d *bincDecDriver) decStringAndBytes(bs []byte, withString, zerocopy bool)
 			// since using symbols, do not store any part of
 			// since using symbols, do not store any part of
 			// the parameter bs in the map, as it might be a shared buffer.
 			// the parameter bs in the map, as it might be a shared buffer.
 			// bs2 = decByteSlice(d.r, slen, bs)
 			// bs2 = decByteSlice(d.r, slen, bs)
-			bs2 = decByteSlice(d.r, slen, nil)
+			bs2 = decByteSlice(d.r, slen, d.d.h.MaxInitLen, nil)
 			if withString {
 			if withString {
 				s = string(bs2)
 				s = string(bs2)
 			}
 			}
@@ -747,7 +747,7 @@ func (d *bincDecDriver) DecodeBytes(bs []byte, isstring, zerocopy bool) (bsOut [
 			bs = d.b[:]
 			bs = d.b[:]
 		}
 		}
 	}
 	}
-	return decByteSlice(d.r, clen, bs)
+	return decByteSlice(d.r, clen, d.d.h.MaxInitLen, bs)
 }
 }
 
 
 func (d *bincDecDriver) DecodeExt(rv interface{}, xtag uint64, ext Ext) (realxtag uint64) {
 func (d *bincDecDriver) DecodeExt(rv interface{}, xtag uint64, ext Ext) (realxtag uint64) {

+ 1 - 1
codec/cbor.go

@@ -421,7 +421,7 @@ func (d *cborDecDriver) DecodeBytes(bs []byte, isstring, zerocopy bool) (bsOut [
 			bs = d.b[:]
 			bs = d.b[:]
 		}
 		}
 	}
 	}
-	return decByteSlice(d.r, clen, bs)
+	return decByteSlice(d.r, clen, d.d.h.MaxInitLen, bs)
 }
 }
 
 
 func (d *cborDecDriver) DecodeString() (s string) {
 func (d *cborDecDriver) DecodeString() (s string) {

+ 2 - 0
codec/codec_test.go

@@ -104,6 +104,7 @@ func testInitFlags() {
 	flag.BoolVar(&testSkipIntf, "tf", false, "Skip Interfaces")
 	flag.BoolVar(&testSkipIntf, "tf", false, "Skip Interfaces")
 	flag.BoolVar(&testUseReset, "tr", false, "Use Reset")
 	flag.BoolVar(&testUseReset, "tr", false, "Use Reset")
 	flag.IntVar(&testJsonIndent, "td", 0, "Use JSON Indent")
 	flag.IntVar(&testJsonIndent, "td", 0, "Use JSON Indent")
+	flag.IntVar(&testMaxInitLen, "tx", 0, "Max Init Len")
 	flag.BoolVar(&testUseMust, "tm", true, "Use Must(En|De)code")
 	flag.BoolVar(&testUseMust, "tm", true, "Use Must(En|De)code")
 	flag.BoolVar(&testCheckCircRef, "tl", false, "Use Check Circular Ref")
 	flag.BoolVar(&testCheckCircRef, "tl", false, "Use Check Circular Ref")
 }
 }
@@ -346,6 +347,7 @@ func testInit() {
 		bh.Canonical = testCanonical
 		bh.Canonical = testCanonical
 		bh.CheckCircularRef = testCheckCircRef
 		bh.CheckCircularRef = testCheckCircRef
 		bh.StructToArray = testStructToArray
 		bh.StructToArray = testStructToArray
+		bh.MaxInitLen = testMaxInitLen
 		// mostly doing this for binc
 		// mostly doing this for binc
 		if testWriteNoSymbols {
 		if testWriteNoSymbols {
 			bh.AsSymbols = AsSymbolNone
 			bh.AsSymbols = AsSymbolNone

+ 18 - 5
codec/decode.go

@@ -107,10 +107,10 @@ type DecodeOptions struct {
 	// If nil, we use []interface{}
 	// If nil, we use []interface{}
 	SliceType reflect.Type
 	SliceType reflect.Type
 
 
-	// MaxInitLen defines the initial length that we "make" a collection (slice, chan or map) with.
+	// MaxInitLen defines the maximum initial length that we "make" a collection (string, slice, map, chan).
 	// If 0 or negative, we default to a sensible value based on the size of an element in the collection.
 	// If 0 or negative, we default to a sensible value based on the size of an element in the collection.
 	//
 	//
-	// For example, when decoding, a stream may say that it has MAX_UINT elements.
+	// For example, when decoding, a stream may say that it has 2^64 elements.
 	// We should not automatically provision a slice of that length, to prevent Out-Of-Memory crash.
 	// We should not automatically provision a slice of that length, to prevent Out-Of-Memory crash.
 	// Instead, we provision up to MaxInitLen, fill that up, and start appending after that.
 	// Instead, we provision up to MaxInitLen, fill that up, and start appending after that.
 	MaxInitLen int
 	MaxInitLen int
@@ -1961,18 +1961,31 @@ func (x decSliceHelper) ElemContainerState(index int) {
 	}
 	}
 }
 }
 
 
-func decByteSlice(r decReader, clen int, bs []byte) (bsOut []byte) {
+func decByteSlice(r decReader, clen, maxInitLen int, bs []byte) (bsOut []byte) {
 	if clen == 0 {
 	if clen == 0 {
 		return zeroByteSlice
 		return zeroByteSlice
 	}
 	}
 	if len(bs) == clen {
 	if len(bs) == clen {
 		bsOut = bs
 		bsOut = bs
+		r.readb(bsOut)
 	} else if cap(bs) >= clen {
 	} else if cap(bs) >= clen {
 		bsOut = bs[:clen]
 		bsOut = bs[:clen]
+		r.readb(bsOut)
 	} else {
 	} else {
-		bsOut = make([]byte, clen)
+		// bsOut = make([]byte, clen)
+		len2, _ := decInferLen(clen, maxInitLen, 1)
+		bsOut = make([]byte, len2)
+		r.readb(bsOut)
+		for len2 < clen {
+			len3, _ := decInferLen(clen-len2, maxInitLen, 1)
+			// fmt.Printf(">>>>> TESTING: in loop: clen: %v, maxInitLen: %v, len2: %v, len3: %v\n", clen, maxInitLen, len2, len3)
+			bs3 := bsOut
+			bsOut = make([]byte, len2+len3)
+			copy(bsOut, bs3)
+			r.readb(bsOut[len2:])
+			len2 += len3
+		}
 	}
 	}
-	r.readb(bsOut)
 	return
 	return
 }
 }
 
 

+ 1 - 0
codec/helper_test.go

@@ -90,6 +90,7 @@ var (
 	testUseMust        bool
 	testUseMust        bool
 	testCheckCircRef   bool
 	testCheckCircRef   bool
 	testJsonIndent     int
 	testJsonIndent     int
+	testMaxInitLen     int
 )
 )
 
 
 func init() {
 func init() {

+ 1 - 1
codec/msgpack.go

@@ -549,7 +549,7 @@ func (d *msgpackDecDriver) DecodeBytes(bs []byte, isstring, zerocopy bool) (bsOu
 			bs = d.b[:]
 			bs = d.b[:]
 		}
 		}
 	}
 	}
-	return decByteSlice(d.r, clen, bs)
+	return decByteSlice(d.r, clen, d.d.h.MaxInitLen, bs)
 }
 }
 
 
 func (d *msgpackDecDriver) DecodeString() (s string) {
 func (d *msgpackDecDriver) DecodeString() (s string) {

+ 1 - 1
codec/simple.go

@@ -372,7 +372,7 @@ func (d *simpleDecDriver) DecodeBytes(bs []byte, isstring, zerocopy bool) (bsOut
 			bs = d.b[:]
 			bs = d.b[:]
 		}
 		}
 	}
 	}
-	return decByteSlice(d.r, clen, bs)
+	return decByteSlice(d.r, clen, d.d.h.MaxInitLen, bs)
 }
 }
 
 
 func (d *simpleDecDriver) DecodeExt(rv interface{}, xtag uint64, ext Ext) (realxtag uint64) {
 func (d *simpleDecDriver) DecodeExt(rv interface{}, xtag uint64, ext Ext) (realxtag uint64) {

+ 7 - 2
codec/tests.sh

@@ -17,7 +17,8 @@ _run() {
     zargs=""
     zargs=""
     local OPTIND 
     local OPTIND 
     OPTIND=1
     OPTIND=1
-    while getopts "_xurtcinsvgzmefdl" flag
+    # "_xurtcinsvgzmefdl" ===  "_cdefgilmnrtsuvxz"
+    while getopts "_cdefgilmnrtsuvwxz" flag
     do
     do
         case "x$flag" in 
         case "x$flag" in 
             'xr')  ;;
             'xr')  ;;
@@ -29,6 +30,7 @@ _run() {
             'xz') zargs="$zargs -tr" ;;
             'xz') zargs="$zargs -tr" ;;
             'xm') zargs="$zargs -tm" ;;
             'xm') zargs="$zargs -tm" ;;
             'xl') zargs="$zargs -tl" ;;
             'xl') zargs="$zargs -tl" ;;
+            'xw') zargs="$zargs -tx=10" ;;
             *) ;;
             *) ;;
         esac
         esac
     done
     done
@@ -37,7 +39,7 @@ _run() {
     # echo ">>>>>>> TAGS: $ztags"
     # echo ">>>>>>> TAGS: $ztags"
     
     
     OPTIND=1
     OPTIND=1
-    while getopts "_xurtcinsvgzmefdl" flag
+    while getopts "_cdefgilmnrtsuvwxz" flag
     do
     do
         case "x$flag" in 
         case "x$flag" in 
             'xt') printf ">>>>>>> REGULAR    : "; go test "-tags=$ztags" $zargs ; sleep 2 ;;
             'xt') printf ">>>>>>> REGULAR    : "; go test "-tags=$ztags" $zargs ; sleep 2 ;;
@@ -63,6 +65,7 @@ if [[ "x$@" = "x"  || "x$@" = "x-A" ]]; then
     # All: r, x, g, gu
     # All: r, x, g, gu
     _run "-_tcinsed_ml"  # regular
     _run "-_tcinsed_ml"  # regular
     _run "-_tcinsed_ml_z" # regular with reset
     _run "-_tcinsed_ml_z" # regular with reset
+    _run "-w_tcinsed_ml"  # regular with max init len
     _run "-_tcinsed_ml_f" # regular with no fastpath (notfastpath)
     _run "-_tcinsed_ml_f" # regular with no fastpath (notfastpath)
     _run "-x_tcinsed_ml" # external
     _run "-x_tcinsed_ml" # external
     _run "-gx_tcinsed_ml" # codecgen: requires external
     _run "-gx_tcinsed_ml" # codecgen: requires external
@@ -79,6 +82,7 @@ elif [[ "x$@" = "x-C" ]]; then
     # codecgen
     # codecgen
     _run "-gx_tcinsed_ml" # codecgen: requires external
     _run "-gx_tcinsed_ml" # codecgen: requires external
     _run "-gxu_tcinsed_ml" # codecgen + unsafe
     _run "-gxu_tcinsed_ml" # codecgen + unsafe
+    _run "-gxuw_tcinsed_ml" # codecgen + unsafe + maxinitlen
 elif [[ "x$@" = "x-X" ]]; then
 elif [[ "x$@" = "x-X" ]]; then
     # external
     # external
     _run "-x_tcinsed_ml" # external
     _run "-x_tcinsed_ml" # external
@@ -98,5 +102,6 @@ Usage: tests.sh [options...]
       just pass on the options from the command line 
       just pass on the options from the command line 
 EOF
 EOF
 else
 else
+    # e.g. ./tests.sh "-w_tcinsed_ml"
     _run "$@"
     _run "$@"
 fi
 fi