Ver Fonte

codec: limit size of bytes read to MaxInitLen (specifically for extensions)

Previously, we called readx for extensions. However, readx will, as needed,
make a slice with the capacity passed in. A bad length of the extensions
bytes can cause a request for very large memory.

To mitigate this, we now call readx only for bytesDecReaders, else we
delegate to decByteSlice function which is designed to incrementally
create the []byte as needed.
Ugorji Nwoke há 7 anos atrás
pai
commit
10d0049acb
5 ficheiros alterados com 36 adições e 7 exclusões
  1. 10 2
      codec/binc.go
  2. 7 1
      codec/decode.go
  3. 1 0
      codec/json.go
  4. 8 2
      codec/msgpack.go
  5. 10 2
      codec/simple.go

+ 10 - 2
codec/binc.go

@@ -829,7 +829,11 @@ func (d *bincDecDriver) decodeExtV(verifyTag bool, tag byte) (xtag byte, xbs []b
 			d.d.errorf("wrong extension tag - got %b, expecting: %v", xtag, tag)
 			return
 		}
-		xbs = d.r.readx(l)
+		if d.br {
+			xbs = d.r.readx(l)
+		} else {
+			xbs = decByteSlice(d.r, l, d.d.h.MaxInitLen, d.d.b[:])
+		}
 	} else if d.vd == bincVdByteArray {
 		xbs = d.DecodeBytes(nil, true)
 	} else {
@@ -912,7 +916,11 @@ func (d *bincDecDriver) DecodeNaked() {
 		n.v = valueTypeExt
 		l := d.decLen()
 		n.u = uint64(d.r.readn1())
-		n.l = d.r.readx(l)
+		if d.br {
+			n.l = d.r.readx(l)
+		} else {
+			n.l = decByteSlice(d.r, l, d.d.h.MaxInitLen, d.d.b[:])
+		}
 	case bincVdArray:
 		n.v = valueTypeArray
 		decodeFurther = true

+ 7 - 1
codec/decode.go

@@ -44,7 +44,6 @@ var (
 // read from an io.Reader or directly off a byte slice with zero-copying.
 type decReader interface {
 	unreadn1()
-
 	// readx will use the implementation scratch buffer if possible i.e. n < len(scratchbuf), OR
 	// just return a view of the []byte being decoded from.
 	// Ensure you call detachZeroCopyBytes later if this needs to be sent outside codec control.
@@ -2535,6 +2534,13 @@ func decByteSlice(r decReader, clen, maxInitLen int, bs []byte) (bsOut []byte) {
 	return
 }
 
+// func decByteSliceZeroCopy(r decReader, clen, maxInitLen int, bs []byte) (bsOut []byte) {
+// 	if _, ok := r.(*bytesDecReader); ok && clen <= maxInitLen {
+// 		return r.readx(clen)
+// 	}
+// 	return decByteSlice(r, clen, maxInitLen, bs)
+// }
+
 func detachZeroCopyBytes(isBytesReader bool, dest []byte, in []byte) (out []byte) {
 	if xlen := len(in); xlen > 0 {
 		if isBytesReader || xlen <= scratchByteArrayLen {

+ 1 - 0
codec/json.go

@@ -707,6 +707,7 @@ func (d *jsonDecDriver) ReadMapEnd() {
 }
 
 func (d *jsonDecDriver) readLit(length, fromIdx uint8) {
+	// length here is always less than 8 (literals are: null, true, false)
 	bs := d.r.readx(int(length))
 	d.tok = 0
 	if jsonValidateSymbols && !bytes.Equal(bs, jsonLiterals[fromIdx:fromIdx+length]) {

+ 8 - 2
codec/msgpack.go

@@ -517,8 +517,10 @@ func (d *msgpackDecDriver) DecodeNaked() {
 			if n.u == uint64(mpTimeExtTagU) {
 				n.v = valueTypeTime
 				n.t = d.decodeTime(clen)
-			} else {
+			} else if d.br {
 				n.l = d.r.readx(clen)
+			} else {
+				n.l = decByteSlice(d.r, clen, d.d.h.MaxInitLen, d.d.b[:])
 			}
 		default:
 			d.d.errorf("cannot infer value: %s: Ox%x/%d/%s", msgBadDesc, bd, bd, mpdesc(bd))
@@ -911,7 +913,11 @@ func (d *msgpackDecDriver) decodeExtV(verifyTag bool, tag byte) (xtag byte, xbs
 			d.d.errorf("wrong extension tag - got %b, expecting %v", xtag, tag)
 			return
 		}
-		xbs = d.r.readx(clen)
+		if d.br {
+			xbs = d.r.readx(clen)
+		} else {
+			xbs = decByteSlice(d.r, clen, d.d.h.MaxInitLen, d.d.b[:])
+		}
 	}
 	d.bdRead = false
 	return

+ 10 - 2
codec/simple.go

@@ -509,7 +509,11 @@ func (d *simpleDecDriver) decodeExtV(verifyTag bool, tag byte) (xtag byte, xbs [
 			d.d.errorf("wrong extension tag. Got %b. Expecting: %v", xtag, tag)
 			return
 		}
-		xbs = d.r.readx(l)
+		if d.br {
+			xbs = d.r.readx(l)
+		} else {
+			xbs = decByteSlice(d.r, l, d.d.h.MaxInitLen, d.d.b[:])
+		}
 	case simpleVdByteArray, simpleVdByteArray + 1,
 		simpleVdByteArray + 2, simpleVdByteArray + 3, simpleVdByteArray + 4:
 		xbs = d.DecodeBytes(nil, true)
@@ -570,7 +574,11 @@ func (d *simpleDecDriver) DecodeNaked() {
 		n.v = valueTypeExt
 		l := d.decLen()
 		n.u = uint64(d.r.readn1())
-		n.l = d.r.readx(l)
+		if d.br {
+			n.l = d.r.readx(l)
+		} else {
+			n.l = decByteSlice(d.r, l, d.d.h.MaxInitLen, d.d.b[:])
+		}
 	case simpleVdArray, simpleVdArray + 1, simpleVdArray + 2,
 		simpleVdArray + 3, simpleVdArray + 4:
 		n.v = valueTypeArray