Parcourir la source

snappy: implement the framing format.

The format is described at
https://code.google.com/p/snappy/source/browse/trunk/framing_format.txt

LGTM=bradfitz
R=bradfitz
CC=golang-codereviews
https://codereview.appspot.com/199400043
Nigel Tao il y a 10 ans
Parent
commit
4c08685702
4 fichiers modifiés avec 312 ajouts et 21 suppressions
  1. 157 2
      snappy/decode.go
  2. 74 0
      snappy/encode.go
  3. 30 0
      snappy/snappy.go
  4. 51 19
      snappy/snappy_test.go

+ 157 - 2
snappy/decode.go

@@ -7,10 +7,15 @@ package snappy
 import (
 	"encoding/binary"
 	"errors"
+	"io"
 )
 
-// ErrCorrupt reports that the input is invalid.
-var ErrCorrupt = errors.New("snappy: corrupt input")
+var (
+	// ErrCorrupt reports that the input is invalid.
+	ErrCorrupt = errors.New("snappy: corrupt input")
+	// ErrUnsupported reports that the input isn't supported.
+	ErrUnsupported = errors.New("snappy: unsupported input")
+)
 
 // DecodedLen returns the length of the decoded block.
 func DecodedLen(src []byte) (int, error) {
@@ -122,3 +127,153 @@ func Decode(dst, src []byte) ([]byte, error) {
 	}
 	return dst[:d], nil
 }
+
+// NewReader returns a new Reader that decompresses from r, using the framing
+// format described at
+// https://code.google.com/p/snappy/source/browse/trunk/framing_format.txt
+func NewReader(r io.Reader) io.Reader {
+	return &reader{
+		r:       r,
+		decoded: make([]byte, maxUncompressedChunkLen),
+		buf:     make([]byte, MaxEncodedLen(maxUncompressedChunkLen)+checksumSize),
+	}
+}
+
+type reader struct {
+	r       io.Reader
+	err     error
+	decoded []byte
+	buf     []byte
+	// decoded[i:j] contains decoded bytes that have not yet been passed on.
+	i, j       int
+	readHeader bool
+}
+
+func (r *reader) readFull(p []byte) (ok bool) {
+	if _, r.err = io.ReadFull(r.r, p); r.err != nil {
+		if r.err == io.ErrUnexpectedEOF {
+			r.err = ErrCorrupt
+		}
+		return false
+	}
+	return true
+}
+
+func (r *reader) Read(p []byte) (int, error) {
+	if r.err != nil {
+		return 0, r.err
+	}
+	for {
+		if r.i < r.j {
+			n := copy(p, r.decoded[r.i:r.j])
+			r.i += n
+			return n, nil
+		}
+		if !r.readFull(r.buf[:4]) {
+			return 0, r.err
+		}
+		chunkType := r.buf[0]
+		if !r.readHeader {
+			if chunkType != chunkTypeStreamIdentifier {
+				r.err = ErrCorrupt
+				return 0, r.err
+			}
+			r.readHeader = true
+		}
+		chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
+		if chunkLen > len(r.buf) {
+			r.err = ErrUnsupported
+			return 0, r.err
+		}
+
+		// The chunk types are specified at
+		// https://code.google.com/p/snappy/source/browse/trunk/framing_format.txt
+		switch chunkType {
+		case chunkTypeCompressedData:
+			// Section 4.2. Compressed data (chunk type 0x00).
+			if chunkLen < checksumSize {
+				r.err = ErrCorrupt
+				return 0, r.err
+			}
+			buf := r.buf[:chunkLen]
+			if !r.readFull(buf) {
+				return 0, r.err
+			}
+			checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
+			buf = buf[checksumSize:]
+
+			n, err := DecodedLen(buf)
+			if err != nil {
+				r.err = err
+				return 0, r.err
+			}
+			if n > len(r.decoded) {
+				r.err = ErrCorrupt
+				return 0, r.err
+			}
+			if _, err := Decode(r.decoded, buf); err != nil {
+				r.err = err
+				return 0, r.err
+			}
+			if crc(r.decoded[:n]) != checksum {
+				r.err = ErrCorrupt
+				return 0, r.err
+			}
+			r.i, r.j = 0, n
+			continue
+
+		case chunkTypeUncompressedData:
+			// Section 4.3. Uncompressed data (chunk type 0x01).
+			if chunkLen < checksumSize {
+				r.err = ErrCorrupt
+				return 0, r.err
+			}
+			buf := r.buf[:checksumSize]
+			if !r.readFull(buf) {
+				return 0, r.err
+			}
+			checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
+			// Read directly into r.decoded instead of via r.buf.
+			n := chunkLen - checksumSize
+			if !r.readFull(r.decoded[:n]) {
+				return 0, r.err
+			}
+			if crc(r.decoded[:n]) != checksum {
+				r.err = ErrCorrupt
+				return 0, r.err
+			}
+			r.i, r.j = 0, n
+			continue
+
+		case chunkTypeStreamIdentifier:
+			// Section 4.1. Stream identifier (chunk type 0xff).
+			if chunkLen != len(magicBody) {
+				r.err = ErrCorrupt
+				return 0, r.err
+			}
+			if !r.readFull(r.buf[:len(magicBody)]) {
+				return 0, r.err
+			}
+			for i := 0; i < len(magicBody); i++ {
+				if r.buf[i] != magicBody[i] {
+					r.err = ErrCorrupt
+					return 0, r.err
+				}
+			}
+			continue
+		}
+
+		if chunkType <= 0x7f {
+			// Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
+			r.err = ErrUnsupported
+			return 0, r.err
+
+		} else {
+			// Section 4.4 Padding (chunk type 0xfe).
+			// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
+			if !r.readFull(r.buf[:chunkLen]) {
+				return 0, r.err
+			}
+		}
+	}
+}

+ 74 - 0
snappy/encode.go

@@ -6,6 +6,7 @@ package snappy
 
 import (
 	"encoding/binary"
+	"io"
 )
 
 // We limit how far copy back-references can go, the same as the C++ code.
@@ -172,3 +173,76 @@ func MaxEncodedLen(srcLen int) int {
 	// This last factor dominates the blowup, so the final estimate is:
 	return 32 + srcLen + srcLen/6
 }
+
+// NewWriter returns a new Writer that compresses to w, using the framing
+// format described at
+// https://code.google.com/p/snappy/source/browse/trunk/framing_format.txt
+func NewWriter(w io.Writer) io.Writer {
+	return &writer{
+		w:   w,
+		enc: make([]byte, MaxEncodedLen(maxUncompressedChunkLen)),
+	}
+}
+
+type writer struct {
+	w           io.Writer
+	err         error
+	enc         []byte
+	buf         [checksumSize + chunkHeaderSize]byte
+	wroteHeader bool
+}
+
+func (w *writer) Write(p []byte) (n int, errRet error) {
+	if w.err != nil {
+		return 0, w.err
+	}
+	if !w.wroteHeader {
+		copy(w.enc, magicChunk)
+		if _, err := w.w.Write(w.enc[:len(magicChunk)]); err != nil {
+			w.err = err
+			return n, err
+		}
+		w.wroteHeader = true
+	}
+	for len(p) > 0 {
+		var uncompressed []byte
+		if len(p) > maxUncompressedChunkLen {
+			uncompressed, p = p[:maxUncompressedChunkLen], p[maxUncompressedChunkLen:]
+		} else {
+			uncompressed, p = p, nil
+		}
+		checksum := crc(uncompressed)
+
+		// Compress the buffer, discarding the result if the improvement
+		// isn't at least 12.5%.
+		chunkType := uint8(chunkTypeCompressedData)
+		chunkBody, err := Encode(w.enc, uncompressed)
+		if err != nil {
+			w.err = err
+			return n, err
+		}
+		if len(chunkBody) >= len(uncompressed)-len(uncompressed)/8 {
+			chunkType, chunkBody = chunkTypeUncompressedData, uncompressed
+		}
+
+		chunkLen := 4 + len(chunkBody)
+		w.buf[0] = chunkType
+		w.buf[1] = uint8(chunkLen >> 0)
+		w.buf[2] = uint8(chunkLen >> 8)
+		w.buf[3] = uint8(chunkLen >> 16)
+		w.buf[4] = uint8(checksum >> 0)
+		w.buf[5] = uint8(checksum >> 8)
+		w.buf[6] = uint8(checksum >> 16)
+		w.buf[7] = uint8(checksum >> 24)
+		if _, err = w.w.Write(w.buf[:]); err != nil {
+			w.err = err
+			return n, err
+		}
+		if _, err = w.w.Write(chunkBody); err != nil {
+			w.err = err
+			return n, err
+		}
+		n += len(uncompressed)
+	}
+	return n, nil
+}

+ 30 - 0
snappy/snappy.go

@@ -8,6 +8,10 @@
 // The C++ snappy implementation is at http://code.google.com/p/snappy/
 package snappy
 
+import (
+	"hash/crc32"
+)
+
 /*
 Each encoded block begins with the varint-encoded length of the decoded data,
 followed by a sequence of chunks. Chunks begin and end on byte boundaries. The
@@ -36,3 +40,29 @@ const (
 	tagCopy2   = 0x02
 	tagCopy4   = 0x03
 )
+
+const (
+	checksumSize    = 4
+	chunkHeaderSize = 4
+	magicChunk      = "\xff\x06\x00\x00" + magicBody
+	magicBody       = "sNaPpY"
+	// https://code.google.com/p/snappy/source/browse/trunk/framing_format.txt says
+	// that "the uncompressed data in a chunk must be no longer than 65536 bytes".
+	maxUncompressedChunkLen = 65536
+)
+
+const (
+	chunkTypeCompressedData   = 0x00
+	chunkTypeUncompressedData = 0x01
+	chunkTypePadding          = 0xfe
+	chunkTypeStreamIdentifier = 0xff
+)
+
+var crcTable = crc32.MakeTable(crc32.Castagnoli)
+
+// crc implements the checksum specified in section 3 of
+// https://code.google.com/p/snappy/source/browse/trunk/framing_format.txt
+func crc(b []byte) uint32 {
+	c := crc32.Update(0, crcTable, b)
+	return uint32(c>>15|c<<17) + 0xa282ead8
+}

+ 51 - 19
snappy/snappy_test.go

@@ -58,7 +58,7 @@ func TestSmallRand(t *testing.T) {
 	rand.Seed(27354294)
 	for n := 1; n < 20000; n += 23 {
 		b := make([]byte, n)
-		for i, _ := range b {
+		for i := range b {
 			b[i] = uint8(rand.Uint32())
 		}
 		if err := roundtrip(b, nil, nil); err != nil {
@@ -70,7 +70,7 @@ func TestSmallRand(t *testing.T) {
 func TestSmallRegular(t *testing.T) {
 	for n := 1; n < 20000; n += 23 {
 		b := make([]byte, n)
-		for i, _ := range b {
+		for i := range b {
 			b[i] = uint8(i%10 + 'a')
 		}
 		if err := roundtrip(b, nil, nil); err != nil {
@@ -79,6 +79,38 @@ func TestSmallRegular(t *testing.T) {
 	}
 }
 
+func TestFramingFormat(t *testing.T) {
+loop:
+	for _, tf := range testFiles {
+		if err := downloadTestdata(tf.filename); err != nil {
+			t.Fatalf("failed to download testdata: %s", err)
+		}
+		src := readFile(t, filepath.Join("testdata", tf.filename))
+		buf := new(bytes.Buffer)
+		if _, err := NewWriter(buf).Write(src); err != nil {
+			t.Errorf("%s: encoding: %v", tf.filename, err)
+			continue
+		}
+		dst, err := ioutil.ReadAll(NewReader(buf))
+		if err != nil {
+			t.Errorf("%s: decoding: %v", tf.filename, err)
+			continue
+		}
+		if !bytes.Equal(dst, src) {
+			if len(dst) != len(src) {
+				t.Errorf("%s: got %d bytes, want %d", tf.filename, len(dst), len(src))
+				continue
+			}
+			for i := range dst {
+				if dst[i] != src[i] {
+					t.Errorf("%s: byte #%d: got 0x%02x, want 0x%02x", tf.filename, i, dst[i], src[i])
+					continue loop
+				}
+			}
+		}
+	}
+}
+
 func benchDecode(b *testing.B, src []byte) {
 	encoded, err := Encode(nil, src)
 	if err != nil {
@@ -102,7 +134,7 @@ func benchEncode(b *testing.B, src []byte) {
 	}
 }
 
-func readFile(b *testing.B, filename string) []byte {
+func readFile(b testing.TB, filename string) []byte {
 	src, err := ioutil.ReadFile(filename)
 	if err != nil {
 		b.Fatalf("failed reading %s: %s", filename, err)
@@ -175,6 +207,19 @@ const baseURL = "https://snappy.googlecode.com/svn/trunk/testdata/"
 
 func downloadTestdata(basename string) (errRet error) {
 	filename := filepath.Join("testdata", basename)
+	if stat, err := os.Stat(filename); err == nil && stat.Size() != 0 {
+		return nil
+	}
+
+	if !*download {
+		return fmt.Errorf("test data not found; skipping benchmark without the -download flag")
+	}
+	// Download the official snappy C++ implementation reference test data
+	// files for benchmarking.
+	if err := os.Mkdir("testdata", 0777); err != nil && !os.IsExist(err) {
+		return fmt.Errorf("failed to create testdata: %s", err)
+	}
+
 	f, err := os.Create(filename)
 	if err != nil {
 		return fmt.Errorf("failed to create %s: %s", filename, err)
@@ -198,23 +243,10 @@ func downloadTestdata(basename string) (errRet error) {
 }
 
 func benchFile(b *testing.B, n int, decode bool) {
-	filename := filepath.Join("testdata", testFiles[n].filename)
-	if stat, err := os.Stat(filename); err != nil || stat.Size() == 0 {
-		if !*download {
-			b.Fatal("test data not found; skipping benchmark without the -download flag")
-		}
-		// Download the official snappy C++ implementation reference test data
-		// files for benchmarking.
-		if err := os.Mkdir("testdata", 0777); err != nil && !os.IsExist(err) {
-			b.Fatalf("failed to create testdata: %s", err)
-		}
-		for _, tf := range testFiles {
-			if err := downloadTestdata(tf.filename); err != nil {
-				b.Fatalf("failed to download testdata: %s", err)
-			}
-		}
+	if err := downloadTestdata(testFiles[n].filename); err != nil {
+		b.Fatalf("failed to download testdata: %s", err)
 	}
-	data := readFile(b, filename)
+	data := readFile(b, filepath.Join("testdata", testFiles[n].filename))
 	if decode {
 		benchDecode(b, data)
 	} else {