Explorar el Código

snappy: fix (1) encoding a 0-length input returned garbage, and
(2) decoding into an existing buffer returned excess bytes.

R=bradfitz
CC=golang-dev
http://codereview.appspot.com/6294045

Nigel Tao hace 13 años
padre
commit
2d44ef2b92
Se han modificado 3 ficheros con 26 adiciones y 11 borrados
  1. 4 1
      snappy/decode.go
  2. 3 1
      snappy/encode.go
  3. 19 9
      snappy/snappy_test.go

+ 4 - 1
snappy/decode.go

@@ -117,5 +117,8 @@ func Decode(dst, src []byte) ([]byte, error) {
 			dst[d] = dst[d-offset]
 		}
 	}
-	return dst, nil
+	if d != dLen {
+		return nil, ErrCorrupt
+	}
+	return dst[:d], nil
 }

+ 3 - 1
snappy/encode.go

@@ -96,7 +96,9 @@ func Encode(dst, src []byte) ([]byte, error) {
 
 	// Return early if src is short.
 	if len(src) <= 4 {
-		d += emitLiteral(dst[d:], src)
+		if len(src) != 0 {
+			d += emitLiteral(dst[d:], src)
+		}
 		return dst[:d], nil
 	}
 

+ 19 - 9
snappy/snappy_test.go

@@ -13,12 +13,12 @@ import (
 	"testing"
 )
 
-func roundtrip(b []byte) error {
-	e, err := Encode(nil, b)
+func roundtrip(b, ebuf, dbuf []byte) error {
+	e, err := Encode(ebuf, b)
 	if err != nil {
 		return fmt.Errorf("encoding error: %v", err)
 	}
-	d, err := Decode(nil, e)
+	d, err := Decode(dbuf, e)
 	if err != nil {
 		return fmt.Errorf("decoding error: %v", err)
 	}
@@ -28,11 +28,21 @@ func roundtrip(b []byte) error {
 	return nil
 }
 
+func TestEmpty(t *testing.T) {
+	if err := roundtrip(nil, nil, nil); err != nil {
+		t.Fatal(err)
+	}
+}
+
 func TestSmallCopy(t *testing.T) {
-	for i := 0; i < 32; i++ {
-		s := "aaaa" + strings.Repeat("b", i) + "aaaabbbb"
-		if err := roundtrip([]byte(s)); err != nil {
-			t.Fatalf("i=%d: %v", i, err)
+	for _, ebuf := range [][]byte{nil, make([]byte, 20), make([]byte, 64)} {
+		for _, dbuf := range [][]byte{nil, make([]byte, 20), make([]byte, 64)} {
+			for i := 0; i < 32; i++ {
+				s := "aaaa" + strings.Repeat("b", i) + "aaaabbbb"
+				if err := roundtrip([]byte(s), ebuf, dbuf); err != nil {
+					t.Errorf("len(ebuf)=%d, len(dbuf)=%d, i=%d: %v", len(ebuf), len(dbuf), i, err)
+				}
+			}
 		}
 	}
 }
@@ -44,7 +54,7 @@ func TestSmallRand(t *testing.T) {
 		for i, _ := range b {
 			b[i] = uint8(rand.Uint32())
 		}
-		if err := roundtrip(b); err != nil {
+		if err := roundtrip(b, nil, nil); err != nil {
 			t.Fatal(err)
 		}
 	}
@@ -56,7 +66,7 @@ func TestSmallRegular(t *testing.T) {
 		for i, _ := range b {
 			b[i] = uint8(i%10 + 'a')
 		}
-		if err := roundtrip(b); err != nil {
+		if err := roundtrip(b, nil, nil); err != nil {
 			t.Fatal(err)
 		}
 	}