
go.net/html: fix the tokenizer when the underlying io.Reader returns
either (0, nil) or an (n, err) such that n > 0 && err != nil. Both
cases are valid by the io.Reader contract.

R=r
CC=golang-dev
https://golang.org/cl/12513043
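
For context (not part of this change): the io.Reader contract allows a call to return (0, nil), and allows it to return n > 0 bytes together with a non-nil error; callers should consume those bytes before acting on the error. A minimal sketch of that consumption pattern, using a hypothetical drain helper:

package main

import (
	"fmt"
	"io"
	"strings"
)

// drain is a hypothetical helper, shown only to illustrate the io.Reader
// contract: always consume the n > 0 bytes first, then consider err.
// A (0, nil) result is legal and simply means "no data yet, try again"
// (a real caller would cap the retries, as readAtLeastOneByte below does).
func drain(r io.Reader) ([]byte, error) {
	var out []byte
	buf := make([]byte, 4096)
	for {
		n, err := r.Read(buf)
		if n > 0 {
			out = append(out, buf[:n]...) // data before error
		}
		if err == io.EOF {
			return out, nil
		}
		if err != nil {
			return out, err
		}
	}
}

func main() {
	b, err := drain(strings.NewReader("<p>An example.</p>"))
	fmt.Println(string(b), err)
}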

Nigel Tao, 12 years ago
parent
commit
e8489d83dd
2 changed files with 108 additions and 4 deletions
  1. html/token.go: +28 -4
  2. html/token_test.go: +80 -0

+ 28 - 4
html/token.go

@@ -133,6 +133,11 @@ type Tokenizer struct {
 	// subsequent Next calls would return an ErrorToken.
 	// err is never reset. Once it becomes non-nil, it stays non-nil.
 	err error
+	// readErr is the error returned by the io.Reader r. It is separate from
+	// err because it is valid for an io.Reader to return (n int, err1 error)
+	// such that n > 0 && err1 != nil, and callers should always process the
+	// n > 0 bytes before considering the error err1.
+	readErr error
 	// buf[raw.start:raw.end] holds the raw bytes of the current token.
 	// buf[raw.end:] is buffered input that will yield future tokens.
 	raw span
@@ -222,7 +227,12 @@ func (z *Tokenizer) Err() error {
 // Pre-condition: z.err == nil.
 func (z *Tokenizer) readByte() byte {
 	if z.raw.end >= len(z.buf) {
-		// Our buffer is exhausted and we have to read from z.r.
+		// Our buffer is exhausted and we have to read from z.r. Check if the
+		// previous read resulted in an error.
+		if z.readErr != nil {
+			z.err = z.readErr
+			return 0
+		}
 		// We copy z.buf[z.raw.start:z.raw.end] to the beginning of z.buf. If the length
 		// z.raw.end - z.raw.start is more than half the capacity of z.buf, then we
 		// allocate a new buffer before the copy.
@@ -253,9 +263,10 @@ func (z *Tokenizer) readByte() byte {
 		z.raw.start, z.raw.end, z.buf = 0, d, buf1[:d]
 		// Now that we have copied the live bytes to the start of the buffer,
 		// we read from z.r into the remainder.
-		n, err := z.r.Read(buf1[d:cap(buf1)])
-		if err != nil {
-			z.err = err
+		var n int
+		n, z.readErr = readAtLeastOneByte(z.r, buf1[d:cap(buf1)])
+		if n == 0 {
+			z.err = z.readErr
 			return 0
 		}
 		z.buf = buf1[:d+n]
@@ -265,6 +276,19 @@ func (z *Tokenizer) readByte() byte {
 	return x
 }
 
+// readAtLeastOneByte wraps an io.Reader so that reading cannot return (0, nil).
+// It returns io.ErrNoProgress if the underlying r.Read method returns (0, nil)
+// too many times in succession.
+func readAtLeastOneByte(r io.Reader, b []byte) (int, error) {
+	for i := 0; i < 100; i++ {
+		n, err := r.Read(b)
+		if n != 0 || err != nil {
+			return n, err
+		}
+	}
+	return 0, io.ErrNoProgress
+}
+
 // skipWhiteSpace skips past any white space.
 func (z *Tokenizer) skipWhiteSpace() {
 	if z.err != nil {

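A usage note, not part of the diff: with this change the tokenizer remembers a read error in z.readErr and only surfaces it after the buffered bytes have been consumed, so an input whose final Read returns (n > 0, io.EOF) still tokenizes completely. A hedged sketch exercising exactly that case with testing/iotest.DataErrReader, assuming the current golang.org/x/net/html import path:

package main

import (
	"fmt"
	"strings"
	"testing/iotest"

	"golang.org/x/net/html"
)

func main() {
	// DataErrReader makes the last Read return the final data chunk together
	// with io.EOF, i.e. an (n > 0, err != nil) result.
	r := iotest.DataErrReader(strings.NewReader("<p>An example.</p>"))
	z := html.NewTokenizer(r)
	for {
		tt := z.Next()
		if tt == html.ErrorToken {
			fmt.Println("stopped with:", z.Err()) // io.EOF once all tokens are returned
			return
		}
		fmt.Println(tt)
	}
}
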
+ 80 - 0
html/token_test.go

@@ -8,6 +8,7 @@ import (
 	"bytes"
 	"io"
 	"io/ioutil"
+	"reflect"
 	"runtime"
 	"strings"
 	"testing"
@@ -531,6 +532,85 @@ func TestConvertNewlines(t *testing.T) {
 	}
 }
 
+func TestReaderEdgeCases(t *testing.T) {
+	const s = "<p>An io.Reader can return (0, nil) or (n, io.EOF).</p>"
+	testCases := []io.Reader{
+		&zeroOneByteReader{s: s},
+		&eofStringsReader{s: s},
+		&stuckReader{},
+	}
+	for i, tc := range testCases {
+		got := []TokenType{}
+		z := NewTokenizer(tc)
+		for {
+			tt := z.Next()
+			if tt == ErrorToken {
+				break
+			}
+			got = append(got, tt)
+		}
+		if err := z.Err(); err != nil && err != io.EOF {
+			if err != io.ErrNoProgress {
+				t.Errorf("i=%d: %v", i, err)
+			}
+			continue
+		}
+		want := []TokenType{
+			StartTagToken,
+			TextToken,
+			EndTagToken,
+		}
+		if !reflect.DeepEqual(got, want) {
+			t.Errorf("i=%d: got %v, want %v", i, got, want)
+			continue
+		}
+	}
+}
+
+// zeroOneByteReader is like a strings.Reader that alternates between
+// returning 0 bytes and 1 byte at a time.
+type zeroOneByteReader struct {
+	s string
+	n int
+}
+
+func (r *zeroOneByteReader) Read(p []byte) (int, error) {
+	if len(p) == 0 {
+		return 0, nil
+	}
+	if len(r.s) == 0 {
+		return 0, io.EOF
+	}
+	r.n++
+	if r.n%2 != 0 {
+		return 0, nil
+	}
+	p[0], r.s = r.s[0], r.s[1:]
+	return 1, nil
+}
+
+// eofStringsReader is like a strings.Reader but can return an (n, err) where
+// n > 0 && err != nil.
+type eofStringsReader struct {
+	s string
+}
+
+func (r *eofStringsReader) Read(p []byte) (int, error) {
+	n := copy(p, r.s)
+	r.s = r.s[n:]
+	if r.s != "" {
+		return n, nil
+	}
+	return n, io.EOF
+}
+
+// stuckReader is an io.Reader that always returns no data and no error.
+type stuckReader struct{}
+
+func (*stuckReader) Read(p []byte) (int, error) {
+	return 0, nil
+}
+
 const (
 	rawLevel = iota
 	lowLevel