|
@@ -3,11 +3,22 @@ package snappy
|
|
|
import (
|
|
import (
|
|
|
"bytes"
|
|
"bytes"
|
|
|
"encoding/binary"
|
|
"encoding/binary"
|
|
|
|
|
+ "errors"
|
|
|
|
|
|
|
|
master "github.com/golang/snappy"
|
|
master "github.com/golang/snappy"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
-var xerialHeader = []byte{130, 83, 78, 65, 80, 80, 89, 0}
|
|
|
|
|
|
|
+const (
|
|
|
|
|
+ sizeOffset = 16
|
|
|
|
|
+ sizeBytes = 4
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+var (
|
|
|
|
|
+ xerialHeader = []byte{130, 83, 78, 65, 80, 80, 89, 0}
|
|
|
|
|
+ // ErrMalformed is returned by the decoder when the xerial framing
|
|
|
|
|
+ // is malformed
|
|
|
|
|
+ ErrMalformed = errors.New("malformed xerial framing")
|
|
|
|
|
+)
|
|
|
|
|
|
|
|
// Encode encodes data as snappy with no framing header.
|
|
// Encode encodes data as snappy with no framing header.
|
|
|
func Encode(src []byte) []byte {
|
|
func Encode(src []byte) []byte {
|
|
@@ -17,26 +28,43 @@ func Encode(src []byte) []byte {
|
|
|
// Decode decodes snappy data whether it is traditional unframed
|
|
// Decode decodes snappy data whether it is traditional unframed
|
|
|
// or includes the xerial framing format.
|
|
// or includes the xerial framing format.
|
|
|
func Decode(src []byte) ([]byte, error) {
|
|
func Decode(src []byte) ([]byte, error) {
|
|
|
|
|
+ var max = len(src)
|
|
|
|
|
+ if max < len(xerialHeader) {
|
|
|
|
|
+ return nil, ErrMalformed
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
if !bytes.Equal(src[:8], xerialHeader) {
|
|
if !bytes.Equal(src[:8], xerialHeader) {
|
|
|
return master.Decode(nil, src)
|
|
return master.Decode(nil, src)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ if max < sizeOffset+sizeBytes {
|
|
|
|
|
+ return nil, ErrMalformed
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
var (
|
|
var (
|
|
|
- pos = uint32(16)
|
|
|
|
|
- max = uint32(len(src))
|
|
|
|
|
|
|
+ pos = sizeOffset
|
|
|
dst = make([]byte, 0, len(src))
|
|
dst = make([]byte, 0, len(src))
|
|
|
chunk []byte
|
|
chunk []byte
|
|
|
err error
|
|
err error
|
|
|
)
|
|
)
|
|
|
- for pos < max {
|
|
|
|
|
- size := binary.BigEndian.Uint32(src[pos : pos+4])
|
|
|
|
|
- pos += 4
|
|
|
|
|
|
|
|
|
|
- chunk, err = master.Decode(chunk, src[pos:pos+size])
|
|
|
|
|
|
|
+ for pos+sizeBytes <= max {
|
|
|
|
|
+ size := int(binary.BigEndian.Uint32(src[pos : pos+sizeBytes]))
|
|
|
|
|
+ pos += sizeBytes
|
|
|
|
|
+
|
|
|
|
|
+ nextPos := pos + size
|
|
|
|
|
+ // On architectures where int is 32-bytes wide size + pos could
|
|
|
|
|
+ // overflow so we need to check the low bound as well as the
|
|
|
|
|
+ // high
|
|
|
|
|
+ if nextPos < pos || nextPos > max {
|
|
|
|
|
+ return nil, ErrMalformed
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ chunk, err = master.Decode(chunk, src[pos:nextPos])
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
- pos += size
|
|
|
|
|
|
|
+ pos = nextPos
|
|
|
dst = append(dst, chunk...)
|
|
dst = append(dst, chunk...)
|
|
|
}
|
|
}
|
|
|
return dst, nil
|
|
return dst, nil
|