Переглянути джерело

Merge branch 'liggitt-malformed-string-test'

Tao Wen 8 роки тому
батько
коміт
4a84b0b30e

+ 3 - 3
feature_any.go

@@ -143,13 +143,13 @@ func (iter *Iterator) readAny() Any {
 		iter.unreadByte()
 		return &stringAny{baseAny{}, iter.ReadString()}
 	case 'n':
-		iter.skipFixedBytes(3) // null
+		iter.skipThreeBytes('u', 'l', 'l') // null
 		return &nilAny{}
 	case 't':
-		iter.skipFixedBytes(3) // true
+		iter.skipThreeBytes('r', 'u', 'e') // true
 		return &trueAny{}
 	case 'f':
-		iter.skipFixedBytes(4) // false
+		iter.skipFourBytes('a', 'l', 's', 'e') // false
 		return &falseAny{}
 	case '{':
 		return iter.readObjectAny()

+ 1 - 1
feature_iter.go

@@ -275,7 +275,7 @@ func (iter *Iterator) Read() interface{} {
 	case Number:
 		return iter.ReadFloat64()
 	case Nil:
-		iter.skipFixedBytes(4) // null
+		iter.skipFourBytes('n', 'u', 'l', 'l')
 		return nil
 	case Bool:
 		return iter.ReadBool()

+ 2 - 2
feature_iter_array.go

@@ -5,7 +5,7 @@ func (iter *Iterator) ReadArray() (ret bool) {
 	c := iter.nextToken()
 	switch c {
 	case 'n':
-		iter.skipFixedBytes(3)
+		iter.skipThreeBytes('u', 'l', 'l')
 		return false // null
 	case '[':
 		c = iter.nextToken()
@@ -44,7 +44,7 @@ func (iter *Iterator) ReadArrayCB(callback func(*Iterator) bool) (ret bool) {
 		return true
 	}
 	if c == 'n' {
-		iter.skipFixedBytes(3)
+		iter.skipThreeBytes('u', 'l', 'l')
 		return true // null
 	}
 	iter.ReportError("ReadArrayCB", "expect [ or n, but found: "+string([]byte{c}))

+ 4 - 4
feature_iter_object.go

@@ -13,7 +13,7 @@ func (iter *Iterator) ReadObject() (ret string) {
 	c := iter.nextToken()
 	switch c {
 	case 'n':
-		iter.skipFixedBytes(3)
+		iter.skipThreeBytes('u', 'l', 'l')
 		return "" // null
 	case '{':
 		c = iter.nextToken()
@@ -103,7 +103,7 @@ func (iter *Iterator) ReadObjectCB(callback func(*Iterator, string) bool) bool {
 		return false
 	}
 	if c == 'n' {
-		iter.skipFixedBytes(3)
+		iter.skipThreeBytes('u', 'l', 'l')
 		return true // null
 	}
 	iter.ReportError("ReadObjectCB", `expect { or n`)
@@ -144,7 +144,7 @@ func (iter *Iterator) ReadMapCB(callback func(*Iterator, string) bool) bool {
 		return false
 	}
 	if c == 'n' {
-		iter.skipFixedBytes(3)
+		iter.skipThreeBytes('u', 'l', 'l')
 		return true // null
 	}
 	iter.ReportError("ReadMapCB", `expect { or n`)
@@ -161,7 +161,7 @@ func (iter *Iterator) readObjectStart() bool {
 		iter.unreadByte()
 		return true
 	} else if c == 'n' {
-		iter.skipFixedBytes(3)
+		iter.skipThreeBytes('u', 'l', 'l')
 		return false
 	}
 	iter.ReportError("readObjectStart", "expect { or n")

+ 39 - 17
feature_iter_skip.go

@@ -7,7 +7,7 @@ import "fmt"
 func (iter *Iterator) ReadNil() (ret bool) {
 	c := iter.nextToken()
 	if c == 'n' {
-		iter.skipFixedBytes(3) // null
+		iter.skipThreeBytes('u', 'l', 'l') // null
 		return true
 	}
 	iter.unreadByte()
@@ -18,11 +18,11 @@ func (iter *Iterator) ReadNil() (ret bool) {
 func (iter *Iterator) ReadBool() (ret bool) {
 	c := iter.nextToken()
 	if c == 't' {
-		iter.skipFixedBytes(3)
+		iter.skipThreeBytes('r', 'u', 'e')
 		return true
 	}
 	if c == 'f' {
-		iter.skipFixedBytes(4)
+		iter.skipFourBytes('a', 'l', 's', 'e')
 		return false
 	}
 	iter.ReportError("ReadBool", "expect t or f")
@@ -71,10 +71,12 @@ func (iter *Iterator) Skip() {
 	switch c {
 	case '"':
 		iter.skipString()
-	case 'n', 't':
-		iter.skipFixedBytes(3) // null or true
+	case 'n':
+		iter.skipThreeBytes('u', 'l', 'l') // null
+	case 't':
+		iter.skipThreeBytes('r', 'u', 'e') // true
 	case 'f':
-		iter.skipFixedBytes(4) // false
+		iter.skipFourBytes('a', 'l', 's', 'e') // false
 	case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
 		iter.skipNumber()
 	case '[':
@@ -226,16 +228,36 @@ func (iter *Iterator) skipNumber() {
 	}
 }
 
-func (iter *Iterator) skipFixedBytes(n int) {
-	iter.head += n
-	if iter.head >= iter.tail {
-		more := iter.head - iter.tail
-		if !iter.loadMore() {
-			if more > 0 {
-				iter.ReportError("skipFixedBytes", "unexpected end")
-			}
-			return
-		}
-		iter.head += more
+// skipFourBytes reads up to four bytes through iter.readByte and checks them,
+// in order, against b1, b2, b3 and b4. Reading stops at the first byte that
+// differs from its expected value; in that case iter.ReportError is called
+// with the operation name "skipFourBytes" and a message naming the complete
+// expected sequence. Callers in this change use it to validate the rest of a
+// JSON literal, e.g. 'a', 'l', 's', 'e' after the leading 'f' of false.
+func (iter *Iterator) skipFourBytes(b1, b2, b3, b4 byte) {
+	// || evaluates left to right and short-circuits, so no further byte is
+	// consumed once a mismatch has been seen.
+	if iter.readByte() != b1 || iter.readByte() != b2 ||
+		iter.readByte() != b3 || iter.readByte() != b4 {
+		expected := string([]byte{b1, b2, b3, b4})
+		iter.ReportError("skipFourBytes", fmt.Sprintf("expect %s", expected))
+	}
+}
+
+// skipThreeBytes reads up to three bytes through iter.readByte and checks
+// them, in order, against b1, b2 and b3. Reading stops at the first byte that
+// differs from its expected value; in that case iter.ReportError is called
+// with the operation name "skipThreeBytes" and a message naming the complete
+// expected sequence. Callers in this change use it to validate the rest of a
+// JSON literal, e.g. 'u', 'l', 'l' after the leading 'n' of null.
+func (iter *Iterator) skipThreeBytes(b1, b2, b3 byte) {
+	// || evaluates left to right and short-circuits, so no further byte is
+	// consumed once a mismatch has been seen.
+	if iter.readByte() != b1 || iter.readByte() != b2 || iter.readByte() != b3 {
+		expected := string([]byte{b1, b2, b3})
+		iter.ReportError("skipThreeBytes", fmt.Sprintf("expect %s", expected))
 	}
 }

+ 6 - 1
feature_iter_string.go

@@ -2,6 +2,7 @@ package jsoniter
 
 import (
 	"unicode/utf16"
+	"fmt"
 )
 
 // ReadString read string from iterator
@@ -16,11 +17,15 @@ func (iter *Iterator) ReadString() (ret string) {
 				return ret
 			} else if c == '\\' {
 				break
+			} else if c < ' ' {
+				iter.ReportError("ReadString",
+					fmt.Sprintf(`invalid control character found: %d`, c))
+				return
 			}
 		}
 		return iter.readStringSlowPath()
 	} else if c == 'n' {
-		iter.skipFixedBytes(3)
+		iter.skipThreeBytes('u', 'l', 'l')
 		return ""
 	}
 	iter.ReportError("ReadString", `expects " or n`)

+ 7 - 0
jsoniter_null_test.go

@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"github.com/stretchr/testify/require"
 	"testing"
+	"io"
 )
 
 func Test_read_null(t *testing.T) {
@@ -13,6 +14,12 @@ func Test_read_null(t *testing.T) {
 	should.True(iter.ReadNil())
 	iter = ParseString(ConfigDefault, `null`)
 	should.Nil(iter.Read())
+	iter = ParseString(ConfigDefault, `navy`)
+	iter.Read()
+	should.True(iter.Error != nil && iter.Error != io.EOF)
+	iter = ParseString(ConfigDefault, `navy`)
+	iter.ReadNil()
+	should.True(iter.Error != nil && iter.Error != io.EOF)
 }
 
 func Test_write_null(t *testing.T) {

+ 51 - 1
jsoniter_string_test.go

@@ -6,11 +6,61 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
-	"github.com/stretchr/testify/require"
 	"testing"
 	"unicode/utf8"
+
+	"github.com/stretchr/testify/require"
 )
 
+// Test_read_string decodes a set of malformed and well-formed JSON inputs
+// into a Go string with encoding/json and with two jsoniter entry points,
+// and requires all of them to agree: malformed inputs must produce an error,
+// well-formed inputs must produce the listed value without error.
+func Test_read_string(t *testing.T) {
+	// Every input is run through each of these, in this order.
+	unmarshalers := []struct {
+		name string
+		fn   func([]byte, interface{}) error
+	}{
+		{"json.Unmarshal", json.Unmarshal},
+		{"jsoniter.Unmarshal", Unmarshal},
+		{"jsoniter.ConfigCompatibleWithStandardLibrary.Unmarshal", ConfigCompatibleWithStandardLibrary.Unmarshal},
+	}
+
+	badInputs := []string{
+		``,
+		`"`,
+		`"\"`,
+		`"\\\"`,
+		"\"\n\"",
+	}
+	for i := 0; i < 32; i++ {
+		// control characters are invalid
+		badInputs = append(badInputs, string([]byte{'"', byte(i), '"'}))
+	}
+
+	for _, input := range badInputs {
+		for _, u := range unmarshalers {
+			testReadString(t, input, "", true, u.name, u.fn)
+		}
+	}
+
+	goodInputs := []struct {
+		input       string
+		expectValue string
+	}{
+		{`""`, ""},
+		{`"a"`, "a"},
+		{`null`, ""},
+		{`"Iñtërnâtiônàlizætiøn,💝🐹🌇⛔"`, "Iñtërnâtiônàlizætiøn,💝🐹🌇⛔"},
+	}
+
+	for _, tc := range goodInputs {
+		for _, u := range unmarshalers {
+			testReadString(t, tc.input, tc.expectValue, false, u.name, u.fn)
+		}
+	}
+}
+
+// testReadString calls marshaler on input with a *string destination and
+// reports a test error through t when the outcome is unexpected. The error
+// expectation is checked first: if the presence of an error does not match
+// expectError, only that mismatch is reported. Otherwise the decoded string
+// is compared with expectValue. marshalerName is used only to label the
+// failure message.
+func testReadString(t *testing.T, input string, expectValue string, expectError bool, marshalerName string, marshaler func([]byte, interface{}) error) {
+	var got string
+	err := marshaler([]byte(input), &got)
+	gotError := err != nil
+	// At most one failure is reported per call; the first matching case wins.
+	switch {
+	case gotError != expectError:
+		t.Errorf("%q: %s: expected error %v, got %v", input, marshalerName, expectError, err)
+	case got != expectValue:
+		t.Errorf("%q: %s: expected %q, got %q", input, marshalerName, expectValue, got)
+	}
+}
+
 func Test_read_normal_string(t *testing.T) {
 	cases := map[string]string{
 		`"0123456789012345678901234567890123456789"`: `0123456789012345678901234567890123456789`,