Sfoglia il codice sorgente

Limit nesting depth

Jordan Liggitt 6 anni fa
parent
commit
eec24895fe
7 file modificati con 361 aggiunte e 6 eliminazioni
  1. 27 0
      iter.go
  2. 8 2
      iter_array.go
  3. 20 4
      iter_object.go
  4. 19 0
      iter_skip_sloppy.go
  5. 238 0
      misc_tests/jsoniter_nested_test.go
  6. 5 0
      reflect.go
  7. 44 0
      reflect_struct_decoder.go

+ 27 - 0
iter.go

@@ -74,6 +74,7 @@ type Iterator struct {
 	buf              []byte
 	head             int
 	tail             int
+	depth            int
 	captureStartedAt int
 	captured         []byte
 	Error            error
@@ -88,6 +89,7 @@ func NewIterator(cfg API) *Iterator {
 		buf:    nil,
 		head:   0,
 		tail:   0,
+		depth:  0,
 	}
 }
 
@@ -99,6 +101,7 @@ func Parse(cfg API, reader io.Reader, bufSize int) *Iterator {
 		buf:    make([]byte, bufSize),
 		head:   0,
 		tail:   0,
+		depth:  0,
 	}
 }
 
@@ -110,6 +113,7 @@ func ParseBytes(cfg API, input []byte) *Iterator {
 		buf:    input,
 		head:   0,
 		tail:   len(input),
+		depth:  0,
 	}
 }
 
@@ -128,6 +132,7 @@ func (iter *Iterator) Reset(reader io.Reader) *Iterator {
 	iter.reader = reader
 	iter.head = 0
 	iter.tail = 0
+	iter.depth = 0
 	return iter
 }
 
@@ -137,6 +142,7 @@ func (iter *Iterator) ResetBytes(input []byte) *Iterator {
 	iter.buf = input
 	iter.head = 0
 	iter.tail = len(input)
+	iter.depth = 0
 	return iter
 }
 
@@ -320,3 +326,24 @@ func (iter *Iterator) Read() interface{} {
 		return nil
 	}
 }
+
+// maxDepth is the greatest array/object nesting depth an Iterator accepts before reporting an error; RFC 7159 section 9 allows implementations to set such a limit (https://tools.ietf.org/html/rfc7159#section-9)
+const maxDepth = 10000
+
+func (iter *Iterator) incrementDepth() (success bool) { // incrementDepth enters one nesting level; it returns false, after reporting an error, once depth exceeds maxDepth
+	iter.depth++ // counted before the check, so depth stays incremented even when the limit is exceeded
+	if iter.depth <= maxDepth {
+		return true
+	}
+	iter.ReportError("incrementDepth", "exceeded max depth")
+	return false
+}
+
+func (iter *Iterator) decrementDepth() (success bool) { // decrementDepth leaves one nesting level; it returns false, after reporting an error, if depth would drop below zero
+	iter.depth-- // applied before the check, so depth is left negative on the failure path
+	if iter.depth >= 0 {
+		return true
+	}
+	iter.ReportError("decrementDepth", "unexpected negative nesting") // more levels closed than were opened
+	return false
+}

+ 8 - 2
iter_array.go

@@ -28,26 +28,32 @@ func (iter *Iterator) ReadArray() (ret bool) {
 func (iter *Iterator) ReadArrayCB(callback func(*Iterator) bool) (ret bool) {
 	c := iter.nextToken()
 	if c == '[' {
+		if !iter.incrementDepth() {
+			return false
+		}
 		c = iter.nextToken()
 		if c != ']' {
 			iter.unreadByte()
 			if !callback(iter) {
+				iter.decrementDepth()
 				return false
 			}
 			c = iter.nextToken()
 			for c == ',' {
 				if !callback(iter) {
+					iter.decrementDepth()
 					return false
 				}
 				c = iter.nextToken()
 			}
 			if c != ']' {
 				iter.ReportError("ReadArrayCB", "expect ] in the end, but found "+string([]byte{c}))
+				iter.decrementDepth()
 				return false
 			}
-			return true
+			return iter.decrementDepth()
 		}
-		return true
+		return iter.decrementDepth()
 	}
 	if c == 'n' {
 		iter.skipThreeBytes('u', 'l', 'l')

+ 20 - 4
iter_object.go

@@ -112,6 +112,9 @@ func (iter *Iterator) ReadObjectCB(callback func(*Iterator, string) bool) bool {
 	c := iter.nextToken()
 	var field string
 	if c == '{' {
+		if !iter.incrementDepth() {
+			return false
+		}
 		c = iter.nextToken()
 		if c == '"' {
 			iter.unreadByte()
@@ -121,6 +124,7 @@ func (iter *Iterator) ReadObjectCB(callback func(*Iterator, string) bool) bool {
 				iter.ReportError("ReadObject", "expect : after object field, but found "+string([]byte{c}))
 			}
 			if !callback(iter, field) {
+				iter.decrementDepth()
 				return false
 			}
 			c = iter.nextToken()
@@ -131,20 +135,23 @@ func (iter *Iterator) ReadObjectCB(callback func(*Iterator, string) bool) bool {
 					iter.ReportError("ReadObject", "expect : after object field, but found "+string([]byte{c}))
 				}
 				if !callback(iter, field) {
+					iter.decrementDepth()
 					return false
 				}
 				c = iter.nextToken()
 			}
 			if c != '}' {
 				iter.ReportError("ReadObjectCB", `object not ended with }`)
+				iter.decrementDepth()
 				return false
 			}
-			return true
+			return iter.decrementDepth()
 		}
 		if c == '}' {
-			return true
+			return iter.decrementDepth()
 		}
 		iter.ReportError("ReadObjectCB", `expect " after }, but found `+string([]byte{c}))
+		iter.decrementDepth()
 		return false
 	}
 	if c == 'n' {
@@ -159,15 +166,20 @@ func (iter *Iterator) ReadObjectCB(callback func(*Iterator, string) bool) bool {
 func (iter *Iterator) ReadMapCB(callback func(*Iterator, string) bool) bool {
 	c := iter.nextToken()
 	if c == '{' {
+		if !iter.incrementDepth() {
+			return false
+		}
 		c = iter.nextToken()
 		if c == '"' {
 			iter.unreadByte()
 			field := iter.ReadString()
 			if iter.nextToken() != ':' {
 				iter.ReportError("ReadMapCB", "expect : after object field, but found "+string([]byte{c}))
+				iter.decrementDepth()
 				return false
 			}
 			if !callback(iter, field) {
+				iter.decrementDepth()
 				return false
 			}
 			c = iter.nextToken()
@@ -175,23 +187,27 @@ func (iter *Iterator) ReadMapCB(callback func(*Iterator, string) bool) bool {
 				field = iter.ReadString()
 				if iter.nextToken() != ':' {
 					iter.ReportError("ReadMapCB", "expect : after object field, but found "+string([]byte{c}))
+					iter.decrementDepth()
 					return false
 				}
 				if !callback(iter, field) {
+					iter.decrementDepth()
 					return false
 				}
 				c = iter.nextToken()
 			}
 			if c != '}' {
 				iter.ReportError("ReadMapCB", `object not ended with }`)
+				iter.decrementDepth()
 				return false
 			}
-			return true
+			return iter.decrementDepth()
 		}
 		if c == '}' {
-			return true
+			return iter.decrementDepth()
 		}
 		iter.ReportError("ReadMapCB", `expect " after }, but found `+string([]byte{c}))
+		iter.decrementDepth()
 		return false
 	}
 	if c == 'n' {

+ 19 - 0
iter_skip_sloppy.go

@@ -22,6 +22,9 @@ func (iter *Iterator) skipNumber() {
 
 func (iter *Iterator) skipArray() {
 	level := 1
+	if !iter.incrementDepth() {
+		return
+	}
 	for {
 		for i := iter.head; i < iter.tail; i++ {
 			switch iter.buf[i] {
@@ -31,8 +34,14 @@ func (iter *Iterator) skipArray() {
 				i = iter.head - 1 // it will be i++ soon
 			case '[': // If open symbol, increase level
 				level++
+				if !iter.incrementDepth() {
+					return
+				}
 			case ']': // If close symbol, decrease level
 				level--
+				if !iter.decrementDepth() {
+					return
+				}
 
 				// If we have returned to the original level, we're done
 				if level == 0 {
@@ -50,6 +59,10 @@ func (iter *Iterator) skipArray() {
 
 func (iter *Iterator) skipObject() {
 	level := 1
+	if !iter.incrementDepth() {
+		return
+	}
+
 	for {
 		for i := iter.head; i < iter.tail; i++ {
 			switch iter.buf[i] {
@@ -59,8 +72,14 @@ func (iter *Iterator) skipObject() {
 				i = iter.head - 1 // it will be i++ soon
 			case '{': // If open symbol, increase level
 				level++
+				if !iter.incrementDepth() {
+					return
+				}
 			case '}': // If close symbol, decrease level
 				level--
+				if !iter.decrementDepth() {
+					return
+				}
 
 				// If we have returned to the original level, we're done
 				if level == 0 {

+ 238 - 0
misc_tests/jsoniter_nested_test.go

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"github.com/json-iterator/go"
 	"reflect"
+	"strings"
 	"testing"
 )
 
@@ -15,6 +16,243 @@ type Level2 struct {
 	World string
 }
 
+func Test_deep_nested(t *testing.T) { // Test_deep_nested checks that Unmarshal succeeds at 10000 nesting levels and reports "max depth" at 10001, for every target below
+	type unstructured interface{} // NOTE(review): not referenced anywhere in this function
+
+	testcases := []struct {
+		name        string
+		data        []byte
+		expectError string
+	}{
+		{
+			name:        "array under maxDepth",
+			data:        []byte(`{"a":` + strings.Repeat(`[`, 10000-1) + strings.Repeat(`]`, 10000-1) + `}`), // 1 enclosing object + 9999 arrays = 10000 levels
+			expectError: "",
+		},
+		{
+			name:        "array over maxDepth",
+			data:        []byte(`{"a":` + strings.Repeat(`[`, 10000) + strings.Repeat(`]`, 10000) + `}`), // 1 enclosing object + 10000 arrays = 10001 levels
+			expectError: "max depth",
+		},
+		{
+			name:        "object under maxDepth",
+			data:        []byte(`{"a":` + strings.Repeat(`{"a":`, 10000-1) + `0` + strings.Repeat(`}`, 10000-1) + `}`), // 1 enclosing object + 9999 nested objects = 10000 levels
+			expectError: "",
+		},
+		{
+			name:        "object over maxDepth",
+			data:        []byte(`{"a":` + strings.Repeat(`{"a":`, 10000) + `0` + strings.Repeat(`}`, 10000) + `}`), // 1 enclosing object + 10000 nested objects = 10001 levels
+			expectError: "max depth",
+		},
+	}
+
+	targets := []struct { // decode destinations; every test case is run against each of them
+		name string
+		new  func() interface{} // returns a pointer to a fresh value to unmarshal into
+	}{
+		{
+			name: "unstructured",
+			new: func() interface{} {
+				var v interface{}
+				return &v
+			},
+		},
+		{
+			name: "typed named field", // NOTE(review): same struct shape as "typed 1 field" below
+			new: func() interface{} {
+				v := struct {
+					A interface{} `json:"a"`
+				}{}
+				return &v
+			},
+		},
+		{
+			name: "typed missing field", // the input's "a" key matches no field of this struct
+			new: func() interface{} {
+				v := struct {
+					B interface{} `json:"b"`
+				}{}
+				return &v
+			},
+		},
+		{
+			name: "typed 1 field", // the "typed N field" targets grow from 1 to 11 fields; presumably one per struct-decoder variant — confirm
+			new: func() interface{} {
+				v := struct {
+					A interface{} `json:"a"`
+				}{}
+				return &v
+			},
+		},
+		{
+			name: "typed 2 field",
+			new: func() interface{} {
+				v := struct {
+					A interface{} `json:"a"`
+					B interface{} `json:"b"`
+				}{}
+				return &v
+			},
+		},
+		{
+			name: "typed 3 field",
+			new: func() interface{} {
+				v := struct {
+					A interface{} `json:"a"`
+					B interface{} `json:"b"`
+					C interface{} `json:"c"`
+				}{}
+				return &v
+			},
+		},
+		{
+			name: "typed 4 field",
+			new: func() interface{} {
+				v := struct {
+					A interface{} `json:"a"`
+					B interface{} `json:"b"`
+					C interface{} `json:"c"`
+					D interface{} `json:"d"`
+				}{}
+				return &v
+			},
+		},
+		{
+			name: "typed 5 field",
+			new: func() interface{} {
+				v := struct {
+					A interface{} `json:"a"`
+					B interface{} `json:"b"`
+					C interface{} `json:"c"`
+					D interface{} `json:"d"`
+					E interface{} `json:"e"`
+				}{}
+				return &v
+			},
+		},
+		{
+			name: "typed 6 field",
+			new: func() interface{} {
+				v := struct {
+					A interface{} `json:"a"`
+					B interface{} `json:"b"`
+					C interface{} `json:"c"`
+					D interface{} `json:"d"`
+					E interface{} `json:"e"`
+					F interface{} `json:"f"`
+				}{}
+				return &v
+			},
+		},
+		{
+			name: "typed 7 field",
+			new: func() interface{} {
+				v := struct {
+					A interface{} `json:"a"`
+					B interface{} `json:"b"`
+					C interface{} `json:"c"`
+					D interface{} `json:"d"`
+					E interface{} `json:"e"`
+					F interface{} `json:"f"`
+					G interface{} `json:"g"`
+				}{}
+				return &v
+			},
+		},
+		{
+			name: "typed 8 field",
+			new: func() interface{} {
+				v := struct {
+					A interface{} `json:"a"`
+					B interface{} `json:"b"`
+					C interface{} `json:"c"`
+					D interface{} `json:"d"`
+					E interface{} `json:"e"`
+					F interface{} `json:"f"`
+					G interface{} `json:"g"`
+					H interface{} `json:"h"`
+				}{}
+				return &v
+			},
+		},
+		{
+			name: "typed 9 field",
+			new: func() interface{} {
+				v := struct {
+					A interface{} `json:"a"`
+					B interface{} `json:"b"`
+					C interface{} `json:"c"`
+					D interface{} `json:"d"`
+					E interface{} `json:"e"`
+					F interface{} `json:"f"`
+					G interface{} `json:"g"`
+					H interface{} `json:"h"`
+					I interface{} `json:"i"`
+				}{}
+				return &v
+			},
+		},
+		{
+			name: "typed 10 field",
+			new: func() interface{} {
+				v := struct {
+					A interface{} `json:"a"`
+					B interface{} `json:"b"`
+					C interface{} `json:"c"`
+					D interface{} `json:"d"`
+					E interface{} `json:"e"`
+					F interface{} `json:"f"`
+					G interface{} `json:"g"`
+					H interface{} `json:"h"`
+					I interface{} `json:"i"`
+					J interface{} `json:"j"`
+				}{}
+				return &v
+			},
+		},
+		{
+			name: "typed 11 field",
+			new: func() interface{} {
+				v := struct {
+					A interface{} `json:"a"`
+					B interface{} `json:"b"`
+					C interface{} `json:"c"`
+					D interface{} `json:"d"`
+					E interface{} `json:"e"`
+					F interface{} `json:"f"`
+					G interface{} `json:"g"`
+					H interface{} `json:"h"`
+					I interface{} `json:"i"`
+					J interface{} `json:"j"`
+					K interface{} `json:"k"`
+				}{}
+				return &v
+			},
+		},
+	}
+
+	for _, tc := range testcases { // one subtest per input, with a nested subtest per target
+		t.Run(tc.name, func(t *testing.T) {
+			for _, target := range targets {
+				t.Run(target.name, func(t *testing.T) {
+					err := jsoniter.Unmarshal(tc.data, target.new())
+					if len(tc.expectError) == 0 { // an empty expectError means the decode must succeed
+						if err != nil {
+							t.Errorf("unexpected error: %v", err)
+						}
+					} else {
+						if err == nil {
+							t.Errorf("expected error, got none")
+						} else if !strings.Contains(err.Error(), tc.expectError) { // substring match, so wrapped error text still passes
+							t.Errorf("expected error containing '%s', got: %v", tc.expectError, err)
+						}
+					}
+				})
+			}
+		})
+	}
+}
+
 func Test_nested(t *testing.T) {
 	iter := jsoniter.ParseString(jsoniter.ConfigDefault, `{"hello": [{"world": "value1"}, {"world": "value2"}]}`)
 	l1 := Level1{}

+ 5 - 0
reflect.go

@@ -60,6 +60,7 @@ func (b *ctx) append(prefix string) *ctx {
 
 // ReadVal copy the underlying JSON into go interface, same as json.Unmarshal
 func (iter *Iterator) ReadVal(obj interface{}) {
+	depth := iter.depth
 	cacheKey := reflect2.RTypeOf(obj)
 	decoder := iter.cfg.getDecoderFromCache(cacheKey)
 	if decoder == nil {
@@ -76,6 +77,10 @@ func (iter *Iterator) ReadVal(obj interface{}) {
 		return
 	}
 	decoder.Decode(ptr, iter)
+	if iter.depth != depth {
+		iter.ReportError("ReadVal", "unexpected mismatched nesting")
+		return
+	}
 }
 
 // WriteVal copy the go interface into underlying JSON, same as json.Marshal

+ 44 - 0
reflect_struct_decoder.go

@@ -500,6 +500,9 @@ func (decoder *generalStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator)
 	if !iter.readObjectStart() {
 		return
 	}
+	if !iter.incrementDepth() {
+		return
+	}
 	var c byte
 	for c = ','; c == ','; c = iter.nextToken() {
 		decoder.decodeOneField(ptr, iter)
@@ -510,6 +513,7 @@ func (decoder *generalStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator)
 	if c != '}' {
 		iter.ReportError("struct Decode", `expect }, but found `+string([]byte{c}))
 	}
+	iter.decrementDepth()
 }
 
 func (decoder *generalStructDecoder) decodeOneField(ptr unsafe.Pointer, iter *Iterator) {
@@ -571,6 +575,9 @@ func (decoder *oneFieldStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator)
 	if !iter.readObjectStart() {
 		return
 	}
+	if !iter.incrementDepth() {
+		return
+	}
 	for {
 		if iter.readFieldHash() == decoder.fieldHash {
 			decoder.fieldDecoder.Decode(ptr, iter)
@@ -584,6 +591,7 @@ func (decoder *oneFieldStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator)
 	if iter.Error != nil && iter.Error != io.EOF {
 		iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error())
 	}
+	iter.decrementDepth()
 }
 
 type twoFieldsStructDecoder struct {
@@ -598,6 +606,9 @@ func (decoder *twoFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator
 	if !iter.readObjectStart() {
 		return
 	}
+	if !iter.incrementDepth() {
+		return
+	}
 	for {
 		switch iter.readFieldHash() {
 		case decoder.fieldHash1:
@@ -614,6 +625,7 @@ func (decoder *twoFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator
 	if iter.Error != nil && iter.Error != io.EOF {
 		iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error())
 	}
+	iter.decrementDepth()
 }
 
 type threeFieldsStructDecoder struct {
@@ -630,6 +642,9 @@ func (decoder *threeFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterat
 	if !iter.readObjectStart() {
 		return
 	}
+	if !iter.incrementDepth() {
+		return
+	}
 	for {
 		switch iter.readFieldHash() {
 		case decoder.fieldHash1:
@@ -648,6 +663,7 @@ func (decoder *threeFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterat
 	if iter.Error != nil && iter.Error != io.EOF {
 		iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error())
 	}
+	iter.decrementDepth()
 }
 
 type fourFieldsStructDecoder struct {
@@ -666,6 +682,9 @@ func (decoder *fourFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterato
 	if !iter.readObjectStart() {
 		return
 	}
+	if !iter.incrementDepth() {
+		return
+	}
 	for {
 		switch iter.readFieldHash() {
 		case decoder.fieldHash1:
@@ -686,6 +705,7 @@ func (decoder *fourFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterato
 	if iter.Error != nil && iter.Error != io.EOF {
 		iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error())
 	}
+	iter.decrementDepth()
 }
 
 type fiveFieldsStructDecoder struct {
@@ -706,6 +726,9 @@ func (decoder *fiveFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterato
 	if !iter.readObjectStart() {
 		return
 	}
+	if !iter.incrementDepth() {
+		return
+	}
 	for {
 		switch iter.readFieldHash() {
 		case decoder.fieldHash1:
@@ -728,6 +751,7 @@ func (decoder *fiveFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterato
 	if iter.Error != nil && iter.Error != io.EOF {
 		iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error())
 	}
+	iter.decrementDepth()
 }
 
 type sixFieldsStructDecoder struct {
@@ -750,6 +774,9 @@ func (decoder *sixFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator
 	if !iter.readObjectStart() {
 		return
 	}
+	if !iter.incrementDepth() {
+		return
+	}
 	for {
 		switch iter.readFieldHash() {
 		case decoder.fieldHash1:
@@ -774,6 +801,7 @@ func (decoder *sixFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator
 	if iter.Error != nil && iter.Error != io.EOF {
 		iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error())
 	}
+	iter.decrementDepth()
 }
 
 type sevenFieldsStructDecoder struct {
@@ -798,6 +826,9 @@ func (decoder *sevenFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterat
 	if !iter.readObjectStart() {
 		return
 	}
+	if !iter.incrementDepth() {
+		return
+	}
 	for {
 		switch iter.readFieldHash() {
 		case decoder.fieldHash1:
@@ -824,6 +855,7 @@ func (decoder *sevenFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterat
 	if iter.Error != nil && iter.Error != io.EOF {
 		iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error())
 	}
+	iter.decrementDepth()
 }
 
 type eightFieldsStructDecoder struct {
@@ -850,6 +882,9 @@ func (decoder *eightFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterat
 	if !iter.readObjectStart() {
 		return
 	}
+	if !iter.incrementDepth() {
+		return
+	}
 	for {
 		switch iter.readFieldHash() {
 		case decoder.fieldHash1:
@@ -878,6 +913,7 @@ func (decoder *eightFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterat
 	if iter.Error != nil && iter.Error != io.EOF {
 		iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error())
 	}
+	iter.decrementDepth()
 }
 
 type nineFieldsStructDecoder struct {
@@ -906,6 +942,9 @@ func (decoder *nineFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterato
 	if !iter.readObjectStart() {
 		return
 	}
+	if !iter.incrementDepth() {
+		return
+	}
 	for {
 		switch iter.readFieldHash() {
 		case decoder.fieldHash1:
@@ -936,6 +975,7 @@ func (decoder *nineFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterato
 	if iter.Error != nil && iter.Error != io.EOF {
 		iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error())
 	}
+	iter.decrementDepth()
 }
 
 type tenFieldsStructDecoder struct {
@@ -966,6 +1006,9 @@ func (decoder *tenFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator
 	if !iter.readObjectStart() {
 		return
 	}
+	if !iter.incrementDepth() {
+		return
+	}
 	for {
 		switch iter.readFieldHash() {
 		case decoder.fieldHash1:
@@ -998,6 +1041,7 @@ func (decoder *tenFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator
 	if iter.Error != nil && iter.Error != io.EOF {
 		iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error())
 	}
+	iter.decrementDepth()
 }
 
 type structFieldDecoder struct {