瀏覽代碼

fix write string

Tao Wen 9 年之前
父節點
當前提交
1163c348f6
共有 2 個文件被更改,包括 77 次插入17 次删除
  1. 71 11
      feature_stream.go
  2. 6 6
      jsoniter_string_test.go

+ 71 - 11
feature_stream.go

@@ -80,6 +80,19 @@ func (b *Stream) writeByte(c byte) error {
 	return nil
 }
 
+func (b *Stream) writeTwoBytes(c1 byte, c2 byte) error {
+	if b.Error != nil {
+		return b.Error
+	}
+	if b.Available() <= 1 && b.Flush() != nil {
+		return b.Error
+	}
+	b.buf[b.n] = c1
+	b.buf[b.n + 1] = c2
+	b.n += 2
+	return nil
+}
+
 // Flush writes any buffered data to the underlying io.Writer.
 func (b *Stream) Flush() error {
 	if b.Error != nil {
@@ -118,20 +131,67 @@ func (b *Stream) WriteRaw(s string) {
 	b.n += n
 }
 
-func (b *Stream) WriteString(s string) {
-	b.writeByte('"')
-	for len(s) > b.Available() && b.Error == nil {
-		n := copy(b.buf[b.n:], s)
-		b.n += n
-		s = s[n:]
-		b.Flush()
+func (stream *Stream) WriteString(s string) {
+	valLen := len(s)
+	toWriteLen := valLen
+	bufLengthMinusTwo := len(stream.buf) - 2 // make room for the quotes
+	if stream.n + toWriteLen > bufLengthMinusTwo {
+		toWriteLen = bufLengthMinusTwo - stream.n
 	}
-	if b.Error != nil {
+	if toWriteLen < 0 {
+		stream.Flush()
+		if stream.n + toWriteLen > bufLengthMinusTwo {
+			toWriteLen = bufLengthMinusTwo - stream.n
+		}
+	}
+	n := stream.n
+	stream.buf[n] = '"'
+	n++
+	// write string, the fast path, without utf8 and escape support
+	i := 0
+	for ; i < toWriteLen; i++ {
+		c := s[i]
+		if c > 31 && c != '"' && c != '\\' {
+			stream.buf[n] = c
+			n++
+		} else {
+			break;
+		}
+	}
+	if i == valLen {
+		stream.buf[n] = '"'
+		n++
+		stream.n = n
 		return
 	}
-	n := copy(b.buf[b.n:], s)
-	b.n += n
-	b.writeByte('"')
+	stream.n = n
+	// for the remaining parts, we process them char by char
+	stream.writeStringSlowPath(s, i, valLen);
+	stream.writeByte('"')
+}
+
+func (stream *Stream) writeStringSlowPath(s string, i int, valLen int) {
+	for ; i < valLen; i++ {
+		c := s[i]
+		switch (c) {
+		case '"':
+			stream.writeTwoBytes('\\', '"')
+		case '\\':
+			stream.writeTwoBytes('\\', '\\')
+		case '\b':
+			stream.writeTwoBytes('\\', 'b')
+		case '\f':
+			stream.writeTwoBytes('\\', 'f')
+		case '\n':
+			stream.writeTwoBytes('\\', 'n')
+		case '\r':
+			stream.writeTwoBytes('\\', 'r')
+		case '\t':
+			stream.writeTwoBytes('\\', 't')
+		default:
+			stream.writeByte(c);
+		}
+	}
 }
 
 func (stream *Stream) WriteNil() {

+ 6 - 6
jsoniter_string_test.go

@@ -67,12 +67,12 @@ func Test_read_string_via_read(t *testing.T) {
 
 func Test_write_string(t *testing.T) {
 	should := require.New(t)
-	buf := &bytes.Buffer{}
-	stream := NewStream(buf, 4096)
-	stream.WriteString("hello")
-	stream.Flush()
-	should.Nil(stream.Error)
-	should.Equal(`"hello"`, buf.String())
+	str, err := MarshalToString("hello")
+	should.Equal(`"hello"`, str)
+	should.Nil(err)
+	str, err = MarshalToString(`hel"lo`)
+	should.Equal(`"hel\"lo"`, str)
+	should.Nil(err)
 }
 
 func Test_write_val_string(t *testing.T) {