Kaynağa Gözat

snappy: add Reset methods to Reader and Writer.

LGTM=bradfitz
R=bradfitz
CC=golang-codereviews
https://codereview.appspot.com/202990043
Nigel Tao 10 yıl önce
ebeveyn
işleme
eaed4addcd
3 değiştirilmiş dosya ile 115 ekleme ve 18 silme
  1. 18 5
      snappy/decode.go
  2. 14 4
      snappy/encode.go
  3. 83 9
      snappy/snappy_test.go

+ 18 - 5
snappy/decode.go

@@ -131,15 +131,16 @@ func Decode(dst, src []byte) ([]byte, error) {
 // NewReader returns a new Reader that decompresses from r, using the framing
 // format described at
 // https://code.google.com/p/snappy/source/browse/trunk/framing_format.txt
-func NewReader(r io.Reader) io.Reader {
-	return &reader{
+func NewReader(r io.Reader) *Reader {
+	return &Reader{
 		r:       r,
 		decoded: make([]byte, maxUncompressedChunkLen),
 		buf:     make([]byte, MaxEncodedLen(maxUncompressedChunkLen)+checksumSize),
 	}
 }
 
-type reader struct {
+// Reader is an io.Reader than can read Snappy-compressed bytes.
+type Reader struct {
 	r       io.Reader
 	err     error
 	decoded []byte
@@ -149,7 +150,18 @@ type reader struct {
 	readHeader bool
 }
 
-func (r *reader) readFull(p []byte) (ok bool) {
+// Reset discards any buffered data, resets all state, and switches the Snappy
+// reader to read from r. This permits reusing a Reader rather than allocating
+// a new one.
+func (r *Reader) Reset(reader io.Reader) {
+	r.r = reader
+	r.err = nil
+	r.i = 0
+	r.j = 0
+	r.readHeader = false
+}
+
+func (r *Reader) readFull(p []byte) (ok bool) {
 	if _, r.err = io.ReadFull(r.r, p); r.err != nil {
 		if r.err == io.ErrUnexpectedEOF {
 			r.err = ErrCorrupt
@@ -159,7 +171,8 @@ func (r *reader) readFull(p []byte) (ok bool) {
 	return true
 }
 
-func (r *reader) Read(p []byte) (int, error) {
+// Read satisfies the io.Reader interface.
+func (r *Reader) Read(p []byte) (int, error) {
 	if r.err != nil {
 		return 0, r.err
 	}

+ 14 - 4
snappy/encode.go

@@ -177,14 +177,15 @@ func MaxEncodedLen(srcLen int) int {
 // NewWriter returns a new Writer that compresses to w, using the framing
 // format described at
 // https://code.google.com/p/snappy/source/browse/trunk/framing_format.txt
-func NewWriter(w io.Writer) io.Writer {
-	return &writer{
+func NewWriter(w io.Writer) *Writer {
+	return &Writer{
 		w:   w,
 		enc: make([]byte, MaxEncodedLen(maxUncompressedChunkLen)),
 	}
 }
 
-type writer struct {
+// Writer is an io.Writer than can write Snappy-compressed bytes.
+type Writer struct {
 	w           io.Writer
 	err         error
 	enc         []byte
@@ -192,7 +193,16 @@ type writer struct {
 	wroteHeader bool
 }
 
-func (w *writer) Write(p []byte) (n int, errRet error) {
+// Reset discards the writer's state and switches the Snappy writer to write to
+// w. This permits reusing a Writer rather than allocating a new one.
+func (w *Writer) Reset(writer io.Writer) {
+	w.w = writer
+	w.err = nil
+	w.wroteHeader = false
+}
+
+// Write satisfies the io.Writer interface.
+func (w *Writer) Write(p []byte) (n int, errRet error) {
 	if w.err != nil {
 		return 0, w.err
 	}

+ 83 - 9
snappy/snappy_test.go

@@ -79,8 +79,19 @@ func TestSmallRegular(t *testing.T) {
 	}
 }
 
+func cmp(a, b []byte) error {
+	if len(a) != len(b) {
+		return fmt.Errorf("got %d bytes, want %d", len(a), len(b))
+	}
+	for i := range a {
+		if a[i] != b[i] {
+			return fmt.Errorf("byte #%d: got 0x%02x, want 0x%02x", i, a[i], b[i])
+		}
+	}
+	return nil
+}
+
 func TestFramingFormat(t *testing.T) {
-loop:
 	for _, tf := range testFiles {
 		if err := downloadTestdata(tf.filename); err != nil {
 			t.Fatalf("failed to download testdata: %s", err)
@@ -96,17 +107,80 @@ loop:
 			t.Errorf("%s: decoding: %v", tf.filename, err)
 			continue
 		}
-		if !bytes.Equal(dst, src) {
-			if len(dst) != len(src) {
-				t.Errorf("%s: got %d bytes, want %d", tf.filename, len(dst), len(src))
+		if err := cmp(dst, src); err != nil {
+			t.Errorf("%s: %v", tf.filename, err)
+			continue
+		}
+	}
+}
+
+func TestReaderReset(t *testing.T) {
+	gold := bytes.Repeat([]byte("All that is gold does not glitter,\n"), 10000)
+	buf := new(bytes.Buffer)
+	if _, err := NewWriter(buf).Write(gold); err != nil {
+		t.Fatalf("Write: %v", err)
+	}
+	encoded, invalid, partial := buf.String(), "invalid", "partial"
+	r := NewReader(nil)
+	for i, s := range []string{encoded, invalid, partial, encoded, partial, invalid, encoded, encoded} {
+		if s == partial {
+			r.Reset(strings.NewReader(encoded))
+			if _, err := r.Read(make([]byte, 101)); err != nil {
+				t.Errorf("#%d: %v", i, err)
 				continue
 			}
-			for i := range dst {
-				if dst[i] != src[i] {
-					t.Errorf("%s: byte #%d: got 0x%02x, want 0x%02x", tf.filename, i, dst[i], src[i])
-					continue loop
-				}
+			continue
+		}
+		r.Reset(strings.NewReader(s))
+		got, err := ioutil.ReadAll(r)
+		switch s {
+		case encoded:
+			if err != nil {
+				t.Errorf("#%d: %v", i, err)
+				continue
 			}
+			if err := cmp(got, gold); err != nil {
+				t.Errorf("%#d: %v", i, err)
+				continue
+			}
+		case invalid:
+			if err == nil {
+				t.Errorf("#%d: got nil error, want non-nil", i)
+				continue
+			}
+		}
+	}
+}
+
+func TestWriterReset(t *testing.T) {
+	gold := bytes.Repeat([]byte("Not all those who wander are lost;\n"), 10000)
+	var gots, wants [][]byte
+	const n = 20
+	w, failed := NewWriter(nil), false
+	for i := 0; i <= n; i++ {
+		buf := new(bytes.Buffer)
+		w.Reset(buf)
+		want := gold[:len(gold)*i/n]
+		if _, err := w.Write(want); err != nil {
+			t.Errorf("#%d: Write: %v", i, err)
+			failed = true
+			continue
+		}
+		got, err := ioutil.ReadAll(NewReader(buf))
+		if err != nil {
+			t.Errorf("#%d: ReadAll: %v", i, err)
+			failed = true
+			continue
+		}
+		gots = append(gots, got)
+		wants = append(wants, want)
+	}
+	if failed {
+		return
+	}
+	for i := range gots {
+		if err := cmp(gots[i], wants[i]); err != nil {
+			t.Errorf("#%d: %v", i, err)
 		}
 	}
 }