Browse Source

Merge branch 'v4' of https://github.com/pierrec/lz4 into v4

Pierre.Curto 5 years ago
parent
commit
dc85acba08

+ 14 - 13
internal/lz4block/block.go

@@ -26,21 +26,19 @@ const (
 )
 )
 
 
 // Pool of hash tables for CompressBlock.
 // Pool of hash tables for CompressBlock.
-var HashTablePool = hashTablePool{sync.Pool{New: func() interface{} { return make([]int, htSize) }}}
+var HashTablePool = hashTablePool{sync.Pool{New: func() interface{} { return new([htSize]int) }}}
 
 
 type hashTablePool struct {
 type hashTablePool struct {
 	sync.Pool
 	sync.Pool
 }
 }
 
 
-func (p *hashTablePool) Get() []int {
-	return p.Pool.Get().([]int)
+func (p *hashTablePool) Get() *[htSize]int {
+	return p.Pool.Get().(*[htSize]int)
 }
 }
 
 
 // Zero out the table to avoid non-deterministic outputs (see issue#65).
 // Zero out the table to avoid non-deterministic outputs (see issue#65).
-func (p *hashTablePool) Put(t []int) {
-	for i := range t {
-		t[i] = 0
-	}
+func (p *hashTablePool) Put(t *[htSize]int) {
+	*t = [htSize]int{}
 	p.Pool.Put(t)
 	p.Pool.Put(t)
 }
 }
 
 
@@ -90,8 +88,9 @@ func CompressBlock(src, dst []byte, hashTable []int) (_ int, err error) {
 	}
 	}
 
 
 	if cap(hashTable) < htSize {
 	if cap(hashTable) < htSize {
-		hashTable = HashTablePool.Get()
-		defer HashTablePool.Put(hashTable)
+		poolTable := HashTablePool.Get()
+		defer HashTablePool.Put(poolTable)
+		hashTable = poolTable[:]
 	} else {
 	} else {
 		hashTable = hashTable[:htSize]
 		hashTable = hashTable[:htSize]
 	}
 	}
@@ -275,15 +274,17 @@ func CompressBlockHC(src, dst []byte, depth CompressionLevel, hashTable, chainTa
 	// hashTable: stores the last position found for a given hash
 	// hashTable: stores the last position found for a given hash
 	// chainTable: stores previous positions for a given hash
 	// chainTable: stores previous positions for a given hash
 	if cap(hashTable) < htSize {
 	if cap(hashTable) < htSize {
-		hashTable = HashTablePool.Get()
-		defer HashTablePool.Put(hashTable)
+		poolTable := HashTablePool.Get()
+		defer HashTablePool.Put(poolTable)
+		hashTable = poolTable[:]
 	} else {
 	} else {
 		hashTable = hashTable[:htSize]
 		hashTable = hashTable[:htSize]
 	}
 	}
 	_ = hashTable[htSize-1]
 	_ = hashTable[htSize-1]
 	if cap(chainTable) < htSize {
 	if cap(chainTable) < htSize {
-		chainTable = HashTablePool.Get()
-		defer HashTablePool.Put(chainTable)
+		poolTable := HashTablePool.Get()
+		defer HashTablePool.Put(poolTable)
+		chainTable = poolTable[:]
 	} else {
 	} else {
 		chainTable = chainTable[:htSize]
 		chainTable = chainTable[:htSize]
 	}
 	}

+ 8 - 16
internal/lz4block/decode_amd64.s

@@ -109,8 +109,7 @@ loop:
 	MOVW 16(AX), BX
 	MOVW 16(AX), BX
 	MOVW BX, 16(DI)
 	MOVW BX, 16(DI)
 
 
-	ADDQ $4, DI // minmatch
-	ADDQ CX, DI
+	LEAQ 4(DI)(CX*1), DI // minmatch
 
 
 	// shortcut complete, load next token
 	// shortcut complete, load next token
 	JMP loop
 	JMP loop
@@ -128,8 +127,7 @@ lit_len_loop:
 	JNE lit_len_finalise
 	JNE lit_len_finalise
 
 
 	// bounds check src[si+1]
 	// bounds check src[si+1]
-	MOVQ SI, AX
-	ADDQ $1, AX
+	LEAQ 1(SI), AX
 	CMPQ AX, R9
 	CMPQ AX, R9
 	JGT err_short_buf
 	JGT err_short_buf
 
 
@@ -147,13 +145,11 @@ lit_len_finalise:
 
 
 copy_literal:
 copy_literal:
 	// bounds check src and dst
 	// bounds check src and dst
-	MOVQ SI, AX
-	ADDQ CX, AX
+	LEAQ (SI)(CX*1), AX
 	CMPQ AX, R9
 	CMPQ AX, R9
 	JGT err_short_buf
 	JGT err_short_buf
 
 
-	MOVQ DI, AX
-	ADDQ CX, AX
+	LEAQ (DI)(CX*1), AX
 	CMPQ AX, R8
 	CMPQ AX, R8
 	JGT err_short_buf
 	JGT err_short_buf
 
 
@@ -219,8 +215,7 @@ offset:
 	// free up DX to use for offset
 	// free up DX to use for offset
 	MOVQ DX, CX
 	MOVQ DX, CX
 
 
-	MOVQ SI, AX
-	ADDQ $2, AX
+	LEAQ 2(SI), AX
 	CMPQ AX, R9
 	CMPQ AX, R9
 	JGT err_short_buf
 	JGT err_short_buf
 
 
@@ -247,8 +242,7 @@ match_len_loop:
 	JNE match_len_finalise
 	JNE match_len_finalise
 
 
 	// bounds check src[si+1]
 	// bounds check src[si+1]
-	MOVQ SI, AX
-	ADDQ $1, AX
+	LEAQ 1(SI), AX
 	CMPQ AX, R9
 	CMPQ AX, R9
 	JGT err_short_buf
 	JGT err_short_buf
 
 
@@ -269,8 +263,7 @@ copy_match:
 
 
 	// check we have match_len bytes left in dst
 	// check we have match_len bytes left in dst
 	// di+match_len <= len(dst)
 	// di+match_len <= len(dst)
-	MOVQ DI, AX
-	ADDQ CX, AX
+	LEAQ (DI)(CX*1), AX
 	CMPQ AX, R8
 	CMPQ AX, R8
 	JGT err_short_buf
 	JGT err_short_buf
 
 
@@ -286,8 +279,7 @@ copy_match:
 	JLT err_short_buf
 	JLT err_short_buf
 
 
 	// if offset + match_len < di
 	// if offset + match_len < di
-	MOVQ BX, AX
-	ADDQ CX, AX
+	LEAQ (BX)(CX*1), AX
 	CMPQ DI, AX
 	CMPQ DI, AX
 	JGT copy_interior_match
 	JGT copy_interior_match
 
 

+ 12 - 12
internal/lz4block/decode_other.go

@@ -10,16 +10,16 @@ func decodeBlock(dst, src []byte) (ret int) {
 		}
 		}
 	}()
 	}()
 
 
-	var si, di int
+	var si, di uint
 	for {
 	for {
 		// Literals and match lengths (token).
 		// Literals and match lengths (token).
-		b := int(src[si])
+		b := uint(src[si])
 		si++
 		si++
 
 
 		// Literals.
 		// Literals.
 		if lLen := b >> 4; lLen > 0 {
 		if lLen := b >> 4; lLen > 0 {
 			switch {
 			switch {
-			case lLen < 0xF && si+16 < len(src):
+			case lLen < 0xF && si+16 < uint(len(src)):
 				// Shortcut 1
 				// Shortcut 1
 				// if we have enough room in src and dst, and the literals length
 				// if we have enough room in src and dst, and the literals length
 				// is small enough (1..14) then copy all 16 bytes, even if not all
 				// is small enough (1..14) then copy all 16 bytes, even if not all
@@ -32,13 +32,13 @@ func decodeBlock(dst, src []byte) (ret int) {
 					// if the match length (4..18) fits within the literals, then copy
 					// if the match length (4..18) fits within the literals, then copy
 					// all 18 bytes, even if not all are part of the literals.
 					// all 18 bytes, even if not all are part of the literals.
 					mLen += 4
 					mLen += 4
-					if offset := int(src[si]) | int(src[si+1])<<8; mLen <= offset {
+					if offset := uint(src[si]) | uint(src[si+1])<<8; mLen <= offset {
 						i := di - offset
 						i := di - offset
 						end := i + 18
 						end := i + 18
-						if end > len(dst) {
+						if end > uint(len(dst)) {
 							// The remaining buffer may not hold 18 bytes.
 							// The remaining buffer may not hold 18 bytes.
 							// See https://github.com/pierrec/lz4/issues/51.
 							// See https://github.com/pierrec/lz4/issues/51.
-							end = len(dst)
+							end = uint(len(dst))
 						}
 						}
 						copy(dst[di:], dst[i:end])
 						copy(dst[di:], dst[i:end])
 						si += 2
 						si += 2
@@ -51,7 +51,7 @@ func decodeBlock(dst, src []byte) (ret int) {
 					lLen += 0xFF
 					lLen += 0xFF
 					si++
 					si++
 				}
 				}
-				lLen += int(src[si])
+				lLen += uint(src[si])
 				si++
 				si++
 				fallthrough
 				fallthrough
 			default:
 			default:
@@ -60,11 +60,11 @@ func decodeBlock(dst, src []byte) (ret int) {
 				di += lLen
 				di += lLen
 			}
 			}
 		}
 		}
-		if si >= len(src) {
-			return di
+		if si >= uint(len(src)) {
+			return int(di)
 		}
 		}
 
 
-		offset := int(src[si]) | int(src[si+1])<<8
+		offset := uint(src[si]) | uint(src[si+1])<<8
 		if offset == 0 {
 		if offset == 0 {
 			return hasError
 			return hasError
 		}
 		}
@@ -77,7 +77,7 @@ func decodeBlock(dst, src []byte) (ret int) {
 				mLen += 0xFF
 				mLen += 0xFF
 				si++
 				si++
 			}
 			}
-			mLen += int(src[si])
+			mLen += uint(src[si])
 			si++
 			si++
 		}
 		}
 		mLen += minMatch
 		mLen += minMatch
@@ -93,6 +93,6 @@ func decodeBlock(dst, src []byte) (ret int) {
 			di += bytesToCopy
 			di += bytesToCopy
 			mLen -= bytesToCopy
 			mLen -= bytesToCopy
 		}
 		}
-		di += copy(dst[di:di+mLen], expanded[:mLen])
+		di += uint(copy(dst[di:di+mLen], expanded[:mLen]))
 	}
 	}
 }
 }

+ 14 - 12
internal/lz4stream/frame.go

@@ -68,8 +68,10 @@ func (f *Frame) InitR(src io.Reader) error {
 		// Header already read.
 		// Header already read.
 		return nil
 		return nil
 	}
 	}
+
 newFrame:
 newFrame:
-	if err := readUint32(src, f.buf[:], &f.Magic); err != nil {
+	var err error
+	if f.Magic, err = f.readUint32(src); err != nil {
 		return err
 		return err
 	}
 	}
 	switch m := f.Magic; {
 	switch m := f.Magic; {
@@ -95,11 +97,11 @@ newFrame:
 	return nil
 	return nil
 }
 }
 
 
-func (f *Frame) CloseR(src io.Reader) error {
+func (f *Frame) CloseR(src io.Reader) (err error) {
 	if !f.Descriptor.Flags.ContentChecksum() {
 	if !f.Descriptor.Flags.ContentChecksum() {
 		return nil
 		return nil
 	}
 	}
-	if err := readUint32(src, f.buf[:], &f.Checksum); err != nil {
+	if f.Checksum, err = f.readUint32(src); err != nil {
 		return err
 		return err
 	}
 	}
 	if c := f.checksum.Sum32(); c != f.Checksum {
 	if c := f.checksum.Sum32(); c != f.Checksum {
@@ -319,9 +321,8 @@ func (b *FrameDataBlock) Write(f *Frame, dst io.Writer) error {
 }
 }
 
 
 func (b *FrameDataBlock) Uncompress(f *Frame, src io.Reader, dst []byte) (int, error) {
 func (b *FrameDataBlock) Uncompress(f *Frame, src io.Reader, dst []byte) (int, error) {
-	buf := f.buf[:]
-	var x uint32
-	if err := readUint32(src, buf, &x); err != nil {
+	x, err := f.readUint32(src)
+	if err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
 	b.Size = DataBlockSize(x)
 	b.Size = DataBlockSize(x)
@@ -353,7 +354,8 @@ func (b *FrameDataBlock) Uncompress(f *Frame, src io.Reader, dst []byte) (int, e
 	}
 	}
 
 
 	if f.Descriptor.Flags.BlockChecksum() {
 	if f.Descriptor.Flags.BlockChecksum() {
-		if err := readUint32(src, buf, &b.Checksum); err != nil {
+		var err error
+		if b.Checksum, err = f.readUint32(src); err != nil {
 			return 0, err
 			return 0, err
 		}
 		}
 		if c := xxh32.ChecksumZero(data); c != b.Checksum {
 		if c := xxh32.ChecksumZero(data); c != b.Checksum {
@@ -366,10 +368,10 @@ func (b *FrameDataBlock) Uncompress(f *Frame, src io.Reader, dst []byte) (int, e
 	return len(data), nil
 	return len(data), nil
 }
 }
 
 
-func readUint32(r io.Reader, buf []byte, x *uint32) error {
-	if _, err := io.ReadFull(r, buf[:4]); err != nil {
-		return err
+func (f *Frame) readUint32(r io.Reader) (x uint32, err error) {
+	if _, err = io.ReadFull(r, f.buf[:4]); err != nil {
+		return
 	}
 	}
-	*x = binary.LittleEndian.Uint32(buf)
-	return nil
+	x = binary.LittleEndian.Uint32(f.buf[:4])
+	return
 }
 }

+ 3 - 5
internal/xxh32/xxh32zero.go

@@ -79,8 +79,7 @@ func (xxh *XXHZero) Write(input []byte) (int, error) {
 	v1, v2, v3, v4 := xxh.v1, xxh.v2, xxh.v3, xxh.v4
 	v1, v2, v3, v4 := xxh.v1, xxh.v2, xxh.v3, xxh.v4
 	if m > 0 {
 	if m > 0 {
 		// some data left from previous update
 		// some data left from previous update
-		copy(xxh.buf[xxh.bufused:], input[:r])
-		xxh.bufused += len(input) - r
+		copy(xxh.buf[m:], input)
 
 
 		// fast rotl(13)
 		// fast rotl(13)
 		buf := xxh.buf[:16] // BCE hint.
 		buf := xxh.buf[:16] // BCE hint.
@@ -89,7 +88,6 @@ func (xxh *XXHZero) Write(input []byte) (int, error) {
 		v3 = rol13(v3+binary.LittleEndian.Uint32(buf[8:])*prime2) * prime1
 		v3 = rol13(v3+binary.LittleEndian.Uint32(buf[8:])*prime2) * prime1
 		v4 = rol13(v4+binary.LittleEndian.Uint32(buf[12:])*prime2) * prime1
 		v4 = rol13(v4+binary.LittleEndian.Uint32(buf[12:])*prime2) * prime1
 		p = r
 		p = r
-		xxh.bufused = 0
 	}
 	}
 
 
 	for n := n - 16; p <= n; p += 16 {
 	for n := n - 16; p <= n; p += 16 {
@@ -101,8 +99,8 @@ func (xxh *XXHZero) Write(input []byte) (int, error) {
 	}
 	}
 	xxh.v1, xxh.v2, xxh.v3, xxh.v4 = v1, v2, v3, v4
 	xxh.v1, xxh.v2, xxh.v3, xxh.v4 = v1, v2, v3, v4
 
 
-	copy(xxh.buf[xxh.bufused:], input[p:])
-	xxh.bufused += len(input) - p
+	copy(xxh.buf[:], input[p:])
+	xxh.bufused = len(input) - p
 
 
 	return n, nil
 	return n, nil
 }
 }

+ 6 - 6
internal/xxh32/xxh32zero_test.go

@@ -52,10 +52,10 @@ func TestZeroData(t *testing.T) {
 		_, _ = xxh.Write(data)
 		_, _ = xxh.Write(data)
 
 
 		if got, want := xxh.Sum32(), td.sum; got != want {
 		if got, want := xxh.Sum32(), td.sum; got != want {
-			t.Fatalf("got %d; want %d", got, want)
+			t.Fatalf("got %x; want %x", got, want)
 		}
 		}
 		if got, want := xxh32.ChecksumZero(data), td.sum; got != want {
 		if got, want := xxh32.ChecksumZero(data), td.sum; got != want {
-			t.Fatalf("got %d; want %d", got, want)
+			t.Fatalf("got %x; want %x", got, want)
 		}
 		}
 	}
 	}
 }
 }
@@ -69,7 +69,7 @@ func TestZeroSplitData(t *testing.T) {
 		_, _ = xxh.Write(data[l:])
 		_, _ = xxh.Write(data[l:])
 
 
 		if got, want := xxh.Sum32(), td.sum; got != want {
 		if got, want := xxh.Sum32(), td.sum; got != want {
-			t.Fatalf("got %d; want %d", got, want)
+			t.Fatalf("got %x; want %x", got, want)
 		}
 		}
 	}
 	}
 }
 }
@@ -82,7 +82,7 @@ func TestZeroSum(t *testing.T) {
 		b := xxh.Sum(data)
 		b := xxh.Sum(data)
 		h := binary.LittleEndian.Uint32(b[len(data):])
 		h := binary.LittleEndian.Uint32(b[len(data):])
 		if got, want := h, td.sum; got != want {
 		if got, want := h, td.sum; got != want {
-			t.Fatalf("got %d; want %d", got, want)
+			t.Fatalf("got %x; want %x", got, want)
 		}
 		}
 	}
 	}
 }
 }
@@ -92,7 +92,7 @@ func TestZeroChecksum(t *testing.T) {
 		data := []byte(td.data)
 		data := []byte(td.data)
 		h := xxh32.ChecksumZero(data)
 		h := xxh32.ChecksumZero(data)
 		if got, want := h, td.sum; got != want {
 		if got, want := h, td.sum; got != want {
-			t.Fatalf("got %d; want %d", got, want)
+			t.Fatalf("got %x; want %x", got, want)
 		}
 		}
 	}
 	}
 }
 }
@@ -103,7 +103,7 @@ func TestZeroReset(t *testing.T) {
 		_, _ = xxh.Write([]byte(td.data))
 		_, _ = xxh.Write([]byte(td.data))
 		h := xxh.Sum32()
 		h := xxh.Sum32()
 		if got, want := h, td.sum; got != want {
 		if got, want := h, td.sum; got != want {
-			t.Fatalf("got %d; want %d", got, want)
+			t.Fatalf("got %x; want %x", got, want)
 		}
 		}
 		xxh.Reset()
 		xxh.Reset()
 	}
 	}

+ 1 - 1
reader.go

@@ -183,8 +183,8 @@ func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
 			return
 			return
 		}
 		}
 		r.handler(bn)
 		r.handler(bn)
+		bn, err = w.Write(data[:bn])
 		n += int64(bn)
 		n += int64(bn)
-		_, err = w.Write(data[:bn])
 		if err != nil {
 		if err != nil {
 			return
 			return
 		}
 		}

+ 31 - 0
reader_test.go

@@ -2,6 +2,7 @@ package lz4_test
 
 
 import (
 import (
 	"bytes"
 	"bytes"
+	"errors"
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"os"
 	"os"
@@ -101,3 +102,33 @@ func TestReader_Reset(t *testing.T) {
 		t.Fatal("result does not match original")
 		t.Fatal("result does not match original")
 	}
 	}
 }
 }
+
+type brokenWriter int
+
+func (w *brokenWriter) Write(p []byte) (n int, err error) {
+	n = len(p)
+	if n > int(*w) {
+		n = int(*w)
+		err = errors.New("broken")
+	}
+	*w -= brokenWriter(n)
+	return
+}
+
+// WriteTo should report the number of bytes successfully written,
+// not the number successfully decompressed.
+func TestWriteToBrokenWriter(t *testing.T) {
+	const capacity = 10
+	w := brokenWriter(capacity)
+	r := lz4.NewReader(bytes.NewReader(pg1661LZ4))
+
+	n, err := r.WriteTo(&w)
+	switch {
+	case n > capacity:
+		t.Errorf("reported number of bytes written %d too big", n)
+	case err == nil:
+		t.Error("no error from broken Writer")
+	case err.Error() != "broken":
+		t.Errorf("unexpected error %q", err.Error())
+	}
+}