Browse Source

first v4 commit: tests are not ready.

Pierre.Curto 5 years ago
parent
commit
fb4a2ec8e1
24 changed files with 1142 additions and 911 deletions
  1. 3 3
      bench_test.go
  2. 26 16
      block.go
  3. 2 2
      block_test.go
  4. 0 23
      debug.go
  5. 0 7
      debug_stub.go
  6. 0 30
      errors.go
  7. 2 3
      example_test.go
  8. 340 0
      frame.go
  9. 49 0
      frame_gen.go
  10. 51 0
      gen.go
  11. 9 0
      go.mod
  12. 30 0
      go.sum
  13. 24 1
      internal/xxh32/xxh32zero.go
  14. 1 1
      internal/xxh32/xxh32zero_test.go
  15. 27 93
      lz4.go
  16. 0 29
      lz4_go1.10.go
  17. 0 29
      lz4_notgo1.10.go
  18. 165 0
      options.go
  19. 92 0
      options_gen.go
  20. 96 304
      reader.go
  21. 65 0
      state.go
  22. 29 0
      state_gen.go
  23. 119 356
      writer.go
  24. 12 14
      writer_test.go

+ 3 - 3
bench_test.go

@@ -39,7 +39,7 @@ func BenchmarkCompressHC(b *testing.B) {
 	b.ResetTimer()
 
 	for i := 0; i < b.N; i++ {
-		_, _ = lz4.CompressBlockHC(pg1661, buf, 16)
+		_, _ = lz4.CompressBlockHC(pg1661, buf, 16, nil)
 	}
 }
 
@@ -128,7 +128,7 @@ func BenchmarkSkipBytesRand(b *testing.B)   { benchmarkSkipBytes(b, randomLZ4) }
 
 func benchmarkCompress(b *testing.B, uncompressed []byte) {
 	w := bytes.NewBuffer(nil)
-	zw := lz4.NewWriter(w)
+	zw, _ := lz4.NewWriter(w)
 	r := bytes.NewReader(uncompressed)
 
 	// Determine the compressed size of testfile.
@@ -161,7 +161,7 @@ func BenchmarkCompressRand(b *testing.B)   { benchmarkCompress(b, random) }
 func BenchmarkWriterReset(b *testing.B) {
 	b.ReportAllocs()
 
-	zw := lz4.NewWriter(nil)
+	zw, _ := lz4.NewWriter(nil)
 	src := mustLoadFile("testdata/gettysburg.txt")
 	var buf bytes.Buffer
 

+ 26 - 16
block.go

@@ -6,6 +6,15 @@ import (
 	"sync"
 )
 
+// Pool of hash tables for CompressBlock.
+var htPool = sync.Pool{New: func() interface{} { return make([]int, htSize) }}
+
+func recoverBlock(e *error) {
+	if r := recover(); r != nil && *e == nil {
+		*e = ErrInvalidSourceShortBuffer
+	}
+}
+
 // blockHash hashes the lower 6 bytes into a value < htSize.
 func blockHash(x uint64) uint32 {
 	const prime6bytes = 227718039650203
@@ -56,14 +65,13 @@ func CompressBlock(src, dst []byte, hashTable []int) (_ int, err error) {
 	// This significantly speeds up incompressible data and usually has very small impact on compression.
 	// bytes to skip =  1 + (bytes since last match >> adaptSkipLog)
 	const adaptSkipLog = 7
-	if len(hashTable) < htSize {
-		htIface := htPool.Get()
-		defer htPool.Put(htIface)
-		hashTable = (*(htIface).(*[htSize]int))[:]
+	if cap(hashTable) < htSize {
+		hashTable = htPool.Get().([]int)
+		defer htPool.Put(hashTable)
+	} else {
+		hashTable = hashTable[:htSize]
 	}
-	// Prove to the compiler the table has at least htSize elements.
-	// The compiler can see that "uint32() >> hashShift" cannot be out of bounds.
-	hashTable = hashTable[:htSize]
+	_ = hashTable[htSize-1]
 
 	// si: Current position of the search.
 	// anchor: Position of the current literals.
@@ -225,13 +233,6 @@ lastLiterals:
 	return di, nil
 }
 
-// Pool of hash tables for CompressBlock.
-var htPool = sync.Pool{
-	New: func() interface{} {
-		return new([htSize]int)
-	},
-}
-
 // blockHash hashes 4 bytes into a value < winSize.
 func blockHashHC(x uint32) uint32 {
 	const hasher uint32 = 2654435761 // Knuth multiplicative hash.
@@ -249,7 +250,7 @@ func blockHashHC(x uint32) uint32 {
 // the compressed size is 0 and no error, then the data is incompressible.
 //
 // An error is returned if the destination buffer is too small.
-func CompressBlockHC(src, dst []byte, depth int) (_ int, err error) {
+func CompressBlockHC(src, dst []byte, depth CompressionLevel, hashTable []int) (_ int, err error) {
 	defer recoverBlock(&err)
 
 	// Return 0, nil only if the destination buffer size is < CompressBlockBound.
@@ -264,7 +265,16 @@ func CompressBlockHC(src, dst []byte, depth int) (_ int, err error) {
 
 	// hashTable: stores the last position found for a given hash
 	// chainTable: stores previous positions for a given hash
-	var hashTable, chainTable [winSize]int
+	if cap(hashTable) < htSize {
+		hashTable = htPool.Get().([]int)
+		defer htPool.Put(hashTable)
+	} else {
+		hashTable = hashTable[:htSize]
+	}
+	_ = hashTable[htSize-1]
+	chainTable := htPool.Get().([]int)
+	defer htPool.Put(chainTable)
+	_ = chainTable[htSize-1]
 
 	if depth <= 0 {
 		depth = winSize

+ 2 - 2
block_test.go

@@ -111,7 +111,7 @@ func TestCompressUncompressBlock(t *testing.T) {
 			t.Run(fmt.Sprintf("%s HC", tc.file), func(t *testing.T) {
 				// t.Parallel()
 				nhc = run(t, tc, func(src, dst []byte) (int, error) {
-					return lz4.CompressBlockHC(src, dst, -1)
+					return lz4.CompressBlockHC(src, dst, 16, nil)
 				})
 			})
 		})
@@ -153,7 +153,7 @@ func TestCompressCornerCase_CopyDstUpperBound(t *testing.T) {
 	t.Run(fmt.Sprintf("%s HC", file), func(t *testing.T) {
 		t.Parallel()
 		run(src, func(src, dst []byte) (int, error) {
-			return lz4.CompressBlockHC(src, dst, -1)
+			return lz4.CompressBlockHC(src, dst, 16, nil)
 		})
 	})
 }

+ 0 - 23
debug.go

@@ -1,23 +0,0 @@
-// +build lz4debug
-
-package lz4
-
-import (
-	"fmt"
-	"os"
-	"path/filepath"
-	"runtime"
-)
-
-const debugFlag = true
-
-func debug(args ...interface{}) {
-	_, file, line, _ := runtime.Caller(1)
-	file = filepath.Base(file)
-
-	f := fmt.Sprintf("LZ4: %s:%d %s", file, line, args[0])
-	if f[len(f)-1] != '\n' {
-		f += "\n"
-	}
-	fmt.Fprintf(os.Stderr, f, args[1:]...)
-}

+ 0 - 7
debug_stub.go

@@ -1,7 +0,0 @@
-// +build !lz4debug
-
-package lz4
-
-const debugFlag = false
-
-func debug(args ...interface{}) {}

+ 0 - 30
errors.go

@@ -1,30 +0,0 @@
-package lz4
-
-import (
-	"errors"
-	"fmt"
-	"os"
-	rdebug "runtime/debug"
-)
-
-var (
-	// ErrInvalidSourceShortBuffer is returned by UncompressBlock or CompressBLock when a compressed
-	// block is corrupted or the destination buffer is not large enough for the uncompressed data.
-	ErrInvalidSourceShortBuffer = errors.New("lz4: invalid source or destination buffer too short")
-	// ErrInvalid is returned when reading an invalid LZ4 archive.
-	ErrInvalid = errors.New("lz4: bad magic number")
-	// ErrBlockDependency is returned when attempting to decompress an archive created with block dependency.
-	ErrBlockDependency = errors.New("lz4: block dependency not supported")
-	// ErrUnsupportedSeek is returned when attempting to Seek any way but forward from the current position.
-	ErrUnsupportedSeek = errors.New("lz4: can only seek forward from io.SeekCurrent")
-)
-
-func recoverBlock(e *error) {
-	if r := recover(); r != nil && *e == nil {
-		if debugFlag {
-			fmt.Fprintln(os.Stderr, r)
-			rdebug.PrintStack()
-		}
-		*e = ErrInvalidSourceShortBuffer
-	}
-}

+ 2 - 3
example_test.go

@@ -16,7 +16,7 @@ func Example() {
 
 	// The pipe will uncompress the data from the writer.
 	pr, pw := io.Pipe()
-	zw := lz4.NewWriter(pw)
+	zw, _ := lz4.NewWriter(pw)
 	zr := lz4.NewReader(pr)
 
 	go func() {
@@ -36,9 +36,8 @@ func ExampleCompressBlock() {
 	s := "hello world"
 	data := []byte(strings.Repeat(s, 100))
 	buf := make([]byte, len(data))
-	ht := make([]int, 64<<10) // buffer for the compression table
 
-	n, err := lz4.CompressBlock(data, buf, ht)
+	n, err := lz4.CompressBlock(data, buf, nil)
 	if err != nil {
 		fmt.Println(err)
 	}

+ 340 - 0
frame.go

@@ -0,0 +1,340 @@
+package lz4
+
+import (
+	"encoding/binary"
+	"fmt"
+	"io"
+	"io/ioutil"
+
+	"github.com/pierrec/lz4/internal/xxh32"
+)
+
+//go:generate go run gen.go
+
+type Frame struct {
+	Magic      uint32
+	Descriptor FrameDescriptor
+	Blocks     Blocks
+	Checksum   uint32
+	checksum   xxh32.XXHZero
+}
+
+func (f *Frame) initW(w *_Writer) {
+	f.Magic = frameMagic
+	f.Descriptor.initW(w)
+	f.Blocks.initW(w)
+	f.checksum.Reset()
+}
+
+func (f *Frame) closeW(w *_Writer) error {
+	if err := f.Blocks.closeW(w); err != nil {
+		return err
+	}
+	buf := w.buf[:0]
+	if f.Descriptor.Flags.ContentChecksum() {
+		buf = f.checksum.Sum(buf)
+	}
+	// End mark (data block size of uint32(0)).
+	buf = append(buf, 0, 0, 0, 0)
+	_, err := w.src.Write(buf)
+	return err
+}
+
+func (f *Frame) initR(r *_Reader) error {
+	if f.Magic > 0 {
+		// Header already read.
+		return nil
+	}
+newFrame:
+	if err := readUint32(r.src, r.buf[:], &f.Magic); err != nil {
+		return err
+	}
+	switch m := f.Magic; {
+	case m == frameMagic:
+	// All 16 values of frameSkipMagic are valid.
+	case m>>8 == frameSkipMagic>>8:
+		var skip uint32
+		if err := binary.Read(r.src, binary.LittleEndian, &skip); err != nil {
+			return err
+		}
+		if _, err := io.CopyN(ioutil.Discard, r.src, int64(skip)); err != nil {
+			return err
+		}
+		goto newFrame
+	default:
+		return ErrInvalid
+	}
+	if err := f.Descriptor.initR(r); err != nil {
+		return err
+	}
+	f.Blocks.initR(r)
+	f.checksum.Reset()
+	return nil
+}
+
+func (f *Frame) closeR(r *_Reader) error {
+	f.Magic = 0
+	if !f.Descriptor.Flags.ContentChecksum() {
+		return nil
+	}
+	if err := readUint32(r.src, r.buf[:], &f.Checksum); err != nil {
+		return err
+	}
+	if c := f.checksum.Sum32(); c != f.Checksum {
+		return fmt.Errorf("%w: got %x; expected %x", ErrInvalidFrameChecksum, c, f.Checksum)
+	}
+	return nil
+}
+
+type FrameDescriptor struct {
+	Flags       DescriptorFlags
+	ContentSize uint64
+	Checksum    uint8
+}
+
+func (fd *FrameDescriptor) initW(_ *_Writer) {
+	fd.Flags.VersionSet(1)
+	fd.Flags.BlockIndependenceSet(false)
+}
+
+func (fd *FrameDescriptor) write(w *_Writer) error {
+	if fd.Checksum > 0 {
+		// Header already written.
+		return nil
+	}
+
+	buf := w.buf[:]
+	binary.LittleEndian.PutUint16(buf, uint16(fd.Flags))
+
+	var checksum uint32
+	if fd.Flags.Size() {
+		checksum = xxh32.ChecksumZero10(uint16(fd.Flags), fd.ContentSize)
+		binary.LittleEndian.PutUint64(buf[2:], fd.ContentSize)
+		buf = buf[:10]
+	} else {
+		checksum = xxh32.Uint32Zero(uint32(fd.Flags))
+		buf = buf[:2]
+	}
+	fd.Checksum = byte(checksum >> 8)
+	buf = append(buf, fd.Checksum)
+
+	_, err := w.src.Write(buf)
+	return err
+}
+
+func (fd *FrameDescriptor) initR(r *_Reader) error {
+	buf := r.buf[:2]
+	if _, err := io.ReadFull(r.src, buf); err != nil {
+		return err
+	}
+	descr := binary.LittleEndian.Uint64(buf)
+	fd.Flags = DescriptorFlags(descr)
+
+	var checksum uint32
+	if fd.Flags.Size() {
+		buf = buf[:9]
+		if _, err := io.ReadFull(r.src, buf); err != nil {
+			return err
+		}
+		fd.ContentSize = binary.LittleEndian.Uint64(buf)
+		checksum = xxh32.ChecksumZero10(uint16(fd.Flags), fd.ContentSize)
+	} else {
+		buf = buf[:1]
+		var err error
+		if br, ok := r.src.(io.ByteReader); ok {
+			buf[0], err = br.ReadByte()
+		} else {
+			_, err = io.ReadFull(r.src, buf)
+		}
+		if err != nil {
+			return err
+		}
+		checksum = xxh32.Uint32Zero(uint32(fd.Flags))
+	}
+	fd.Checksum = buf[len(buf)-1]
+	if c := byte(checksum >> 8); fd.Checksum != c {
+		return fmt.Errorf("lz4: %w: got %x; expected %x", ErrInvalidHeaderChecksum, c, fd.Checksum)
+	}
+
+	return nil
+}
+
+type Blocks struct {
+	Block  *FrameDataBlock
+	Blocks chan chan *FrameDataBlock
+	err    error
+}
+
+func (b *Blocks) initW(w *_Writer) {
+	size := w.frame.Descriptor.Flags.BlockSizeIndex()
+	if w.isNotConcurrent() {
+		b.Blocks = nil
+		b.Block = newFrameDataBlock(size)
+		return
+	}
+	if cap(b.Blocks) != w.num {
+		b.Blocks = make(chan chan *FrameDataBlock, w.num)
+	}
+	// goroutine managing concurrent block compression goroutines.
+	go func() {
+		// Process next block compression item.
+		for c := range b.Blocks {
+			// Read the next compressed block result.
+			// Waiting here ensures that the blocks are output in the order they were sent.
+			// The incoming channel is always closed as it indicates to the caller that
+			// the block has been processed.
+			block := <-c
+			if block == nil {
+				// Notify the block compression routine that we are done with its result.
+				// This is used when a sentinel block is sent to terminate the compression.
+				close(c)
+				return
+			}
+			// Do not attempt to write the block upon any previous failure.
+			if b.err == nil {
+				// Write the block.
+				if err := block.write(w); err != nil && b.err == nil {
+					// Keep the first error.
+					b.err = err
+					// All pending compression goroutines need to shut down, so we need to keep going.
+				}
+			}
+			close(c)
+		}
+	}()
+}
+
+func (b *Blocks) closeW(w *_Writer) error {
+	if w.isNotConcurrent() {
+		b.Block.closeW(w)
+		b.Block = nil
+		return nil
+	}
+	c := make(chan *FrameDataBlock)
+	b.Blocks <- c
+	c <- nil
+	<-c
+	err := b.err
+	b.err = nil
+	return err
+}
+
+func (b *Blocks) initR(r *_Reader) {
+	size := r.frame.Descriptor.Flags.BlockSizeIndex()
+	b.Block = newFrameDataBlock(size)
+}
+
+func newFrameDataBlock(size BlockSizeIndex) *FrameDataBlock {
+	return &FrameDataBlock{Data: size.get()}
+}
+
+type FrameDataBlock struct {
+	Size     DataBlockSize
+	Data     []byte
+	Checksum uint32
+}
+
+func (b *FrameDataBlock) closeW(w *_Writer) {
+	size := w.frame.Descriptor.Flags.BlockSizeIndex()
+	size.put(b.Data)
+}
+
+// Block compression errors are ignored since the buffer is sized appropriately.
+func (b *FrameDataBlock) compress(w *_Writer, src []byte, ht []int) *FrameDataBlock {
+	dst := b.Data
+	var n int
+	switch w.level {
+	case Fast:
+		n, _ = CompressBlock(src, dst, ht)
+	default:
+		n, _ = CompressBlockHC(src, dst, w.level, ht)
+	}
+	if n == 0 {
+		b.Size.compressedSet(false)
+		dst = src
+	} else {
+		b.Size.compressedSet(true)
+		dst = dst[:n]
+	}
+	b.Data = dst
+	b.Size.sizeSet(len(dst))
+
+	if w.frame.Descriptor.Flags.BlockChecksum() {
+		b.Checksum = xxh32.ChecksumZero(src)
+	}
+	if w.frame.Descriptor.Flags.ContentChecksum() {
+		_, _ = w.frame.checksum.Write(src)
+	}
+	return b
+}
+
+func (b *FrameDataBlock) write(w *_Writer) error {
+	buf := w.buf[:]
+	out := w.src
+
+	binary.LittleEndian.PutUint32(buf, uint32(b.Size))
+	if _, err := out.Write(buf[:4]); err != nil {
+		return err
+	}
+
+	if _, err := out.Write(b.Data); err != nil {
+		return err
+	}
+
+	if b.Checksum == 0 {
+		return nil
+	}
+	binary.LittleEndian.PutUint32(buf, b.Checksum)
+	_, err := out.Write(buf[:4])
+	return err
+}
+
+func (b *FrameDataBlock) uncompress(r *_Reader, dst []byte) (int, error) {
+	var x uint32
+	if err := readUint32(r.src, r.buf[:], &x); err != nil {
+		return 0, err
+	}
+	b.Size = DataBlockSize(x)
+	if b.Size == 0 {
+		return 0, io.EOF
+	}
+
+	isCompressed := b.Size.compressed()
+	var data []byte
+	if isCompressed {
+		data = b.Data
+	} else {
+		data = dst
+	}
+	if _, err := io.ReadFull(r.src, data[:b.Size.size()]); err != nil {
+		return 0, err
+	}
+	if isCompressed {
+		n, err := UncompressBlock(data, dst)
+		if err != nil {
+			return 0, err
+		}
+		data = dst[:n]
+	}
+
+	if r.frame.Descriptor.Flags.BlockChecksum() {
+		if err := readUint32(r.src, r.buf[:], &b.Checksum); err != nil {
+			return 0, err
+		}
+		if c := xxh32.ChecksumZero(data); c != b.Checksum {
+			return 0, fmt.Errorf("lz4: %w: got %x; expected %x", ErrInvalidBlockChecksum, c, b.Checksum)
+		}
+	}
+	if r.frame.Descriptor.Flags.ContentChecksum() {
+		_, _ = r.frame.checksum.Write(data)
+	}
+	return len(data), nil
+}
+
+func readUint32(r io.Reader, buf []byte, x *uint32) error {
+	if _, err := io.ReadFull(r, buf[:4]); err != nil {
+		return err
+	}
+	*x = binary.LittleEndian.Uint32(buf)
+	return nil
+}

+ 49 - 0
frame_gen.go

@@ -0,0 +1,49 @@
+// Code generated by `gen.exe`. DO NOT EDIT.
+
+package lz4
+
+// DescriptorFlags is defined as follow:
+//   field              bits
+//   -----              ----
+//   _                  4
+//   BlockSizeIndex     3
+//   _                  1
+//   _                  2
+//   ContentChecksum    1
+//   Size               1
+//   BlockChecksum      1
+//   BlockIndependence  1
+//   Version            2
+type DescriptorFlags uint16
+
+// Getters.
+func (x DescriptorFlags) BlockSizeIndex() BlockSizeIndex { return BlockSizeIndex(x>>4&0x7) }
+func (x DescriptorFlags) ContentChecksum() bool          { return x>>10&1 != 0 }
+func (x DescriptorFlags) Size() bool                     { return x>>11&1 != 0 }
+func (x DescriptorFlags) BlockChecksum() bool            { return x>>12&1 != 0 }
+func (x DescriptorFlags) BlockIndependence() bool        { return x>>13&1 != 0 }
+func (x DescriptorFlags) Version() uint16                { return uint16(x>>14&0x3) }
+
+// Setters.
+func (x *DescriptorFlags) BlockSizeIndexSet(v BlockSizeIndex) *DescriptorFlags { *x = *x&^(0x7<<4) | (DescriptorFlags(v)&0x7<<4); return x }
+func (x *DescriptorFlags) ContentChecksumSet(v bool) *DescriptorFlags { const b = 1<<10; if v { *x = *x&^b | b } else { *x &^= b }; return x }
+func (x *DescriptorFlags) SizeSet(v bool) *DescriptorFlags { const b = 1<<11; if v { *x = *x&^b | b } else { *x &^= b }; return x }
+func (x *DescriptorFlags) BlockChecksumSet(v bool) *DescriptorFlags { const b = 1<<12; if v { *x = *x&^b | b } else { *x &^= b }; return x }
+func (x *DescriptorFlags) BlockIndependenceSet(v bool) *DescriptorFlags { const b = 1<<13; if v { *x = *x&^b | b } else { *x &^= b }; return x }
+func (x *DescriptorFlags) VersionSet(v uint16) *DescriptorFlags { *x = *x&^(0x3<<14) | (DescriptorFlags(v)&0x3<<14); return x }
+// Code generated by `gen.exe`. DO NOT EDIT.
+
+// DataBlockSize is defined as follow:
+//   field       bits
+//   -----       ----
+//   size        31
+//   compressed  1
+type DataBlockSize uint32
+
+// Getters.
+func (x DataBlockSize) size() int        { return int(x&0x7FFFFFFF) }
+func (x DataBlockSize) compressed() bool { return x>>31&1 != 0 }
+
+// Setters.
+func (x *DataBlockSize) sizeSet(v int) *DataBlockSize { *x = *x&^0x7FFFFFFF | DataBlockSize(v)&0x7FFFFFFF; return x }
+func (x *DataBlockSize) compressedSet(v bool) *DataBlockSize { const b = 1<<31; if v { *x = *x&^b | b } else { *x &^= b }; return x }

+ 51 - 0
gen.go

@@ -0,0 +1,51 @@
+//+build ignore
+
+package main
+
+import (
+	"log"
+	"os"
+
+	"github.com/pierrec/lz4"
+	"github.com/pierrec/packer"
+)
+
+type DescriptorFlags struct {
+	// BD
+	_              [4]int
+	BlockSizeIndex [3]lz4.BlockSizeIndex
+	_              [1]int
+	// FLG
+	_                 [2]int
+	ContentChecksum   [1]bool
+	Size              [1]bool
+	BlockChecksum     [1]bool
+	BlockIndependence [1]bool
+	Version           [2]uint16
+}
+
+type DataBlockSize struct {
+	size       [31]int
+	compressed bool
+}
+
+func main() {
+	out, err := os.Create("frame_gen.go")
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer out.Close()
+
+	pkg := "v4"
+	for i, t := range []interface{}{
+		DescriptorFlags{}, DataBlockSize{},
+	} {
+		if i > 0 {
+			pkg = ""
+		}
+		err := packer.GenPackedStruct(out, &packer.Config{PkgName: pkg}, t)
+		if err != nil {
+			log.Fatalf("%T: %v", t, err)
+		}
+	}
+}

+ 9 - 0
go.mod

@@ -0,0 +1,9 @@
+module github.com/pierrec/lz4
+
+go 1.14
+
+require (
+	github.com/frankban/quicktest v1.9.0 // indirect
+	github.com/pierrec/packer v0.0.0-20200419211718-decbba9fa6fa // indirect
+	golang.org/x/tools v0.0.0-20200420001825-978e26b7c37c // indirect
+)

+ 30 - 0
go.sum

@@ -0,0 +1,30 @@
+github.com/frankban/quicktest v1.9.0 h1:jfEA+Psfr/pHsRJYPpHiNu7PGJnGctNxvTaM3K1EyXk=
+github.com/frankban/quicktest v1.9.0/go.mod h1:ui7WezCLWMWxVWr1GETZY3smRy0G4KWq9vcPtJmFl7Y=
+github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
+github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=
+github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
+github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
+github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
+github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/pierrec/packer v0.0.0-20200419211718-decbba9fa6fa h1:yVA4oBytvOmjYjiQzeJL83TlMb/g2y7WbnrspRVtPMk=
+github.com/pierrec/packer v0.0.0-20200419211718-decbba9fa6fa/go.mod h1:GKrs5lzNeoNBN0l+jHHePFsqiNgXrZT5vQvGt28rWjI=
+github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20200420001825-978e26b7c37c h1:JzwTM5XxGxiCwZEIZQPG46csyhWQxQlu2uSi3bEza34=
+golang.org/x/tools v0.0.0-20200420001825-978e26b7c37c/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

+ 24 - 1
internal/xxh32/xxh32zero.go

@@ -51,7 +51,7 @@ func (xxh *XXHZero) Size() int {
 	return 4
 }
 
-// BlockSize gives the minimum number of bytes accepted by Write().
+// BlockSize gives the minimum number of bytes accepted by Write().
 func (xxh *XXHZero) BlockSize() int {
 	return 1
 }
@@ -182,6 +182,29 @@ func ChecksumZero(input []byte) uint32 {
 	return h32
 }
 
+// ChecksumZero10 computes the zero-seed checksum of a 10-byte input
+func ChecksumZero10(x uint16, y uint64) uint32 {
+	h32 := 10 + prime5
+
+	h32 += (uint32(x)<<16 | uint32(y>>48)) * prime3
+	h32 = rol17(h32) * prime4
+	h32 += uint32(y>>16) * prime3
+	h32 = rol17(h32) * prime4
+
+	h32 += uint32(y>>8) & 0xFF * prime5
+	h32 = rol11(h32) * prime1
+	h32 += uint32(y) & 0xFF * prime5
+	h32 = rol11(h32) * prime1
+
+	h32 ^= h32 >> 15
+	h32 *= prime2
+	h32 ^= h32 >> 13
+	h32 *= prime3
+	h32 ^= h32 >> 16
+
+	return h32
+}
+
 // Uint32Zero hashes x with seed 0.
 func Uint32Zero(x uint32) uint32 {
 	h := prime5 + 4 + x*prime3

+ 1 - 1
internal/xxh32/xxh32zero_test.go

@@ -35,7 +35,7 @@ var testdata = []test{
 func TestZeroBlockSize(t *testing.T) {
 	var xxh xxh32.XXHZero
 	if s := xxh.BlockSize(); s <= 0 {
-		t.Errorf("invalid BlockSize: %d", s)
+		t.Errorf("invalid BlockSizeIndex: %d", s)
 	}
 }
 

+ 27 - 93
lz4.go

@@ -1,35 +1,14 @@
-// Package lz4 implements reading and writing lz4 compressed data (a frame),
-// as specified in http://fastcompression.blogspot.fr/2013/04/lz4-streaming-format-final.html.
-//
-// Although the block level compression and decompression functions are exposed and are fully compatible
-// with the lz4 block format definition, they are low level and should not be used directly.
-// For a complete description of an lz4 compressed block, see:
-// http://fastcompression.blogspot.fr/2011/05/lz4-explained.html
-//
-// See https://github.com/Cyan4973/lz4 for the reference C implementation.
-//
 package lz4
 
-import "math/bits"
-
-import "sync"
-
 const (
-	// Extension is the LZ4 frame file name extension
-	Extension = ".lz4"
-	// Version is the LZ4 frame format version
-	Version = 1
-
 	frameMagic     uint32 = 0x184D2204
 	frameSkipMagic uint32 = 0x184D2A50
 
 	// The following constants are used to setup the compression algorithm.
-	minMatch            = 4  // the minimum size of the match sequence size (4 bytes)
-	winSizeLog          = 16 // LZ4 64Kb window size limit
-	winSize             = 1 << winSizeLog
-	winMask             = winSize - 1 // 64Kb window of previous data for dependent blocks
-	compressedBlockFlag = 1 << 31
-	compressedBlockMask = compressedBlockFlag - 1
+	minMatch   = 4           // the minimum size of the match sequence size (4 bytes)
+	winSizeLog = 16          // LZ4 64Kb window size limit
+	winSize    = 1 << winSizeLog
+	winMask    = winSize - 1 // 64Kb window of previous data for dependent blocks
 
 	// hashLog determines the size of the hash table used to quickly find a previous match position.
 	// Its value influences the compression speed and memory usage, the lower the faster,
@@ -41,73 +20,28 @@ const (
 	mfLimit = 10 + minMatch // The last match cannot start within the last 14 bytes.
 )
 
-// map the block max size id with its value in bytes: 64Kb, 256Kb, 1Mb and 4Mb.
-const (
-	blockSize64K = 1 << (16 + 2*iota)
-	blockSize256K
-	blockSize1M
-	blockSize4M
-)
-
-var (
-	// Keep a pool of buffers for each valid block sizes.
-	bsMapValue = [...]*sync.Pool{
-		newBufferPool(2 * blockSize64K),
-		newBufferPool(2 * blockSize256K),
-		newBufferPool(2 * blockSize1M),
-		newBufferPool(2 * blockSize4M),
-	}
-)
+type _error string
 
-// newBufferPool returns a pool for buffers of the given size.
-func newBufferPool(size int) *sync.Pool {
-	return &sync.Pool{
-		New: func() interface{} {
-			return make([]byte, size)
-		},
-	}
-}
+func (e _error) Error() string { return string(e) }
 
-// getBuffer returns a buffer to its pool.
-func getBuffer(size int) []byte {
-	idx := blockSizeValueToIndex(size) - 4
-	return bsMapValue[idx].Get().([]byte)
-}
-
-// putBuffer returns a buffer to its pool.
-func putBuffer(size int, buf []byte) {
-	if cap(buf) > 0 {
-		idx := blockSizeValueToIndex(size) - 4
-		bsMapValue[idx].Put(buf[:cap(buf)])
-	}
-}
-func blockSizeIndexToValue(i byte) int {
-	return 1 << (16 + 2*uint(i))
-}
-func isValidBlockSize(size int) bool {
-	const blockSizeMask = blockSize64K | blockSize256K | blockSize1M | blockSize4M
-
-	return size&blockSizeMask > 0 && bits.OnesCount(uint(size)) == 1
-}
-func blockSizeValueToIndex(size int) byte {
-	return 4 + byte(bits.TrailingZeros(uint(size)>>16)/2)
-}
-
-// Header describes the various flags that can be set on a Writer or obtained from a Reader.
-// The default values match those of the LZ4 frame format definition
-// (http://fastcompression.blogspot.com/2013/04/lz4-streaming-format-final.html).
-//
-// NB. in a Reader, in case of concatenated frames, the Header values may change between Read() calls.
-// It is the caller's responsibility to check them if necessary.
-type Header struct {
-	BlockChecksum    bool   // Compressed blocks checksum flag.
-	NoChecksum       bool   // Frame checksum flag.
-	BlockMaxSize     int    // Size of the uncompressed data block (one of [64KB, 256KB, 1MB, 4MB]). Default=4MB.
-	Size             uint64 // Frame total size. It is _not_ computed by the Writer.
-	CompressionLevel int    // Compression level (higher is better, use 0 for fastest compression).
-	done             bool   // Header processed flag (Read or Write and checked).
-}
-
-func (h *Header) Reset() {
-	h.done = false
-}
+const (
+	// ErrInvalidSourceShortBuffer is returned by UncompressBlock or CompressBLock when a compressed
+	// block is corrupted or the destination buffer is not large enough for the uncompressed data.
+	ErrInvalidSourceShortBuffer _error = "lz4: invalid source or destination buffer too short"
+	// ErrClosed is returned when calling Write/Read or Close on an already closed Writer/Reader.
+	ErrClosed _error = "lz4: closed Writer"
+	// ErrInvalid is returned when reading an invalid LZ4 archive.
+	ErrInvalid _error = "lz4: bad magic number"
+	// ErrBlockDependency is returned when attempting to decompress an archive created with block dependency.
+	ErrBlockDependency _error = "lz4: block dependency not supported"
+	// ErrUnsupportedSeek is returned when attempting to Seek any way but forward from the current position.
+	ErrUnsupportedSeek _error = "lz4: can only seek forward from io.SeekCurrent"
+	// ErrInternalUnhandledState is an internal error.
+	ErrInternalUnhandledState _error = "lz4: unhandled state"
+	// ErrInvalidHeaderChecksum is returned when reading a frame header with an invalid checksum.
+	ErrInvalidHeaderChecksum _error = "lz4: invalid header checksum"
+	// ErrInvalidBlockChecksum is returned when reading a frame data block with an invalid checksum.
+	ErrInvalidBlockChecksum _error = "lz4: invalid block checksum"
+	// ErrInvalidFrameChecksum is returned when reading a frame with an invalid content checksum.
+	ErrInvalidFrameChecksum _error = "lz4: invalid frame checksum"
+)

+ 0 - 29
lz4_go1.10.go

@@ -1,29 +0,0 @@
-//+build go1.10
-
-package lz4
-
-import (
-	"fmt"
-	"strings"
-)
-
-func (h Header) String() string {
-	var s strings.Builder
-
-	s.WriteString(fmt.Sprintf("%T{", h))
-	if h.BlockChecksum {
-		s.WriteString("BlockChecksum: true ")
-	}
-	if h.NoChecksum {
-		s.WriteString("NoChecksum: true ")
-	}
-	if bs := h.BlockMaxSize; bs != 0 && bs != 4<<20 {
-		s.WriteString(fmt.Sprintf("BlockMaxSize: %d ", bs))
-	}
-	if l := h.CompressionLevel; l != 0 {
-		s.WriteString(fmt.Sprintf("CompressionLevel: %d ", l))
-	}
-	s.WriteByte('}')
-
-	return s.String()
-}

+ 0 - 29
lz4_notgo1.10.go

@@ -1,29 +0,0 @@
-//+build !go1.10
-
-package lz4
-
-import (
-	"bytes"
-	"fmt"
-)
-
-func (h Header) String() string {
-	var s bytes.Buffer
-
-	s.WriteString(fmt.Sprintf("%T{", h))
-	if h.BlockChecksum {
-		s.WriteString("BlockChecksum: true ")
-	}
-	if h.NoChecksum {
-		s.WriteString("NoChecksum: true ")
-	}
-	if bs := h.BlockMaxSize; bs != 0 && bs != 4<<20 {
-		s.WriteString(fmt.Sprintf("BlockMaxSize: %d ", bs))
-	}
-	if l := h.CompressionLevel; l != 0 {
-		s.WriteString(fmt.Sprintf("CompressionLevel: %d ", l))
-	}
-	s.WriteByte('}')
-
-	return s.String()
-}

+ 165 - 0
options.go

@@ -0,0 +1,165 @@
+package lz4
+
+import (
+	"fmt"
+	"runtime"
+	"sync"
+)
+
+//go:generate go run golang.org/x/tools/cmd/stringer -type=BlockSize,CompressionLevel -output options_gen.go
+
+// Option defines the parameters to setup an LZ4 Writer or Reader.
+type Option func(*_Writer) error
+
+// Default options.
+var (
+	defaultBlockSizeOption = BlockSizeOption(Block4Mb)
+	defaultChecksumOption  = ChecksumOption(true)
+	defaultConcurrency     = ConcurrencyOption(1)
+)
+
+const (
+	Block64Kb BlockSize = 1 << (16 + iota*2)
+	Block256Kb
+	Block1Mb
+	Block4Mb
+)
+
+var (
+	blockPool64K  = sync.Pool{New: func() interface{} { return make([]byte, Block64Kb) }}
+	blockPool256K = sync.Pool{New: func() interface{} { return make([]byte, Block256Kb) }}
+	blockPool1M   = sync.Pool{New: func() interface{} { return make([]byte, Block1Mb) }}
+	blockPool4M   = sync.Pool{New: func() interface{} { return make([]byte, Block4Mb) }}
+)
+
+// BlockSize defines the size of the blocks to be compressed.
+type BlockSize uint32
+
+func (b BlockSize) isValid() bool {
+	return b.index() > 0
+}
+
+func (b BlockSize) index() BlockSizeIndex {
+	switch b {
+	case Block64Kb:
+		return 4
+	case Block256Kb:
+		return 5
+	case Block1Mb:
+		return 6
+	case Block4Mb:
+		return 7
+	}
+	return 0
+}
+
+type BlockSizeIndex uint8
+
+func (b BlockSizeIndex) get() []byte {
+	var buf interface{}
+	switch b {
+	case 4:
+		buf = blockPool64K.Get()
+	case 5:
+		buf = blockPool256K.Get()
+	case 6:
+		buf = blockPool1M.Get()
+	case 7:
+		buf = blockPool4M.Get()
+	}
+	return buf.([]byte)
+}
+
+func (b BlockSizeIndex) put(buf []byte) {
+	switch b {
+	case 4:
+		blockPool64K.Put(buf)
+	case 5:
+		blockPool256K.Put(buf)
+	case 6:
+		blockPool1M.Put(buf)
+	case 7:
+		blockPool4M.Put(buf)
+	}
+}
+
+// BlockSizeOption defines the maximum size of compressed blocks (default=Block4Mb).
+func BlockSizeOption(size BlockSize) Option {
+	return func(w *_Writer) error {
+		if !size.isValid() {
+			return fmt.Errorf("lz4: invalid block size %d", size)
+		}
+		w.frame.Descriptor.Flags.BlockSizeIndexSet(size.index())
+		return nil
+	}
+}
+
+// BlockChecksumOption enables or disables block checksum (default=false).
+func BlockChecksumOption(flag bool) Option {
+	return func(w *_Writer) error {
+		w.frame.Descriptor.Flags.BlockChecksumSet(flag)
+		return nil
+	}
+}
+
+// ChecksumOption enables/disables all blocks checksum (default=true).
+func ChecksumOption(flag bool) Option {
+	return func(w *_Writer) error {
+		w.frame.Descriptor.Flags.ContentChecksumSet(flag)
+		return nil
+	}
+}
+
+// SizeOption sets the size of the original uncompressed data (default=0).
+func SizeOption(size uint64) Option {
+	return func(w *_Writer) error {
+		w.frame.Descriptor.Flags.SizeSet(size > 0)
+		w.frame.Descriptor.ContentSize = size
+		return nil
+	}
+}
+
+// ConcurrencyOption sets the number of go routines used for compression.
+// If n<0, then the output of runtime.GOMAXPROCS(0) is used.
+func ConcurrencyOption(n int) Option {
+	return func(w *_Writer) error {
+		switch n {
+		case 0, 1:
+		default:
+			if n < 0 {
+				n = runtime.GOMAXPROCS(0)
+			}
+		}
+		w.num = n
+		return nil
+	}
+}
+
+// CompressionLevel defines the level of compression to use. The higher the better, but slower, compression.
+type CompressionLevel uint32
+
+const (
+	Fast   CompressionLevel = 0
+	Level1 CompressionLevel = 1 << (8 + iota)
+	Level2
+	Level3
+	Level4
+	Level5
+	Level6
+	Level7
+	Level8
+	Level9
+)
+
+// CompressionLevelOption defines the compression level (default=Fast).
+func CompressionLevelOption(level CompressionLevel) Option {
+	return func(w *_Writer) error {
+		switch level {
+		case Fast, Level1, Level2, Level3, Level4, Level5, Level6, Level7, Level8, Level9:
+		default:
+			return fmt.Errorf("lz4: invalid compression level %d", level)
+		}
+		w.level = level
+		return nil
+	}
+}

+ 92 - 0
options_gen.go

@@ -0,0 +1,92 @@
+// Code generated by "stringer -type=BlockSize,CompressionLevel -output options_gen.go"; DO NOT EDIT.
+
+package lz4
+
+import "strconv"
+
+func _() {
+	// An "invalid array index" compiler error signifies that the constant values have changed.
+	// Re-run the stringer command to generate them again.
+	var x [1]struct{}
+	_ = x[Block64Kb-65536]
+	_ = x[Block256Kb-262144]
+	_ = x[Block1Mb-1048576]
+	_ = x[Block4Mb-4194304]
+}
+
+const (
+	_BlockSize_name_0 = "Block64Kb"
+	_BlockSize_name_1 = "Block256Kb"
+	_BlockSize_name_2 = "Block1Mb"
+	_BlockSize_name_3 = "Block4Mb"
+)
+
+func (i BlockSize) String() string {
+	switch {
+	case i == 65536:
+		return _BlockSize_name_0
+	case i == 262144:
+		return _BlockSize_name_1
+	case i == 1048576:
+		return _BlockSize_name_2
+	case i == 4194304:
+		return _BlockSize_name_3
+	default:
+		return "BlockSize(" + strconv.FormatInt(int64(i), 10) + ")"
+	}
+}
+func _() {
+	// An "invalid array index" compiler error signifies that the constant values have changed.
+	// Re-run the stringer command to generate them again.
+	var x [1]struct{}
+	_ = x[Fast-0]
+	_ = x[Level1-512]
+	_ = x[Level2-1024]
+	_ = x[Level3-2048]
+	_ = x[Level4-4096]
+	_ = x[Level5-8192]
+	_ = x[Level6-16384]
+	_ = x[Level7-32768]
+	_ = x[Level8-65536]
+	_ = x[Level9-131072]
+}
+
+const (
+	_CompressionLevel_name_0 = "Fast"
+	_CompressionLevel_name_1 = "Level1"
+	_CompressionLevel_name_2 = "Level2"
+	_CompressionLevel_name_3 = "Level3"
+	_CompressionLevel_name_4 = "Level4"
+	_CompressionLevel_name_5 = "Level5"
+	_CompressionLevel_name_6 = "Level6"
+	_CompressionLevel_name_7 = "Level7"
+	_CompressionLevel_name_8 = "Level8"
+	_CompressionLevel_name_9 = "Level9"
+)
+
+func (i CompressionLevel) String() string {
+	switch {
+	case i == 0:
+		return _CompressionLevel_name_0
+	case i == 512:
+		return _CompressionLevel_name_1
+	case i == 1024:
+		return _CompressionLevel_name_2
+	case i == 2048:
+		return _CompressionLevel_name_3
+	case i == 4096:
+		return _CompressionLevel_name_4
+	case i == 8192:
+		return _CompressionLevel_name_5
+	case i == 16384:
+		return _CompressionLevel_name_6
+	case i == 32768:
+		return _CompressionLevel_name_7
+	case i == 65536:
+		return _CompressionLevel_name_8
+	case i == 131072:
+		return _CompressionLevel_name_9
+	default:
+		return "CompressionLevel(" + strconv.FormatInt(int64(i), 10) + ")"
+	}
+}

+ 96 - 304
reader.go

@@ -1,335 +1,127 @@
 package lz4
 
 import (
-	"encoding/binary"
-	"fmt"
 	"io"
-	"io/ioutil"
-
-	"github.com/pierrec/lz4/internal/xxh32"
 )
 
-// Reader implements the LZ4 frame decoder.
-// The Header is set after the first call to Read().
-// The Header may change between Read() calls in case of concatenated frames.
-type Reader struct {
-	Header
-	// Handler called when a block has been successfully read.
-	// It provides the number of bytes read.
-	OnBlockDone func(size int)
-
-	buf      [8]byte       // Scrap buffer.
-	pos      int64         // Current position in src.
-	src      io.Reader     // Source.
-	zdata    []byte        // Compressed data.
-	data     []byte        // Uncompressed data.
-	idx      int           // Index of unread bytes into data.
-	checksum xxh32.XXHZero // Frame hash.
-	skip     int64         // Bytes to skip before next read.
-	dpos     int64         // Position in dest
+var readerStates = []aState{
+	noState:     newState,
+	newState:    headerState,
+	headerState: readState,
+	readState:   closedState,
+	closedState: newState,
+	errorState:  newState,
 }
 
 // NewReader returns a new LZ4 frame decoder.
-// No access to the underlying io.Reader is performed.
-func NewReader(src io.Reader) *Reader {
-	r := &Reader{src: src}
-	return r
+func NewReader(r io.Reader) io.Reader {
+	zr := &_Reader{src: r}
+	zr.state.init(readerStates)
+	return zr
 }
 
-// readHeader checks the frame magic number and parses the frame descriptor.
-// Skippable frames are supported even as a first frame although the LZ4
-// specifications recommends skippable frames not to be used as first frames.
-func (z *Reader) readHeader(first bool) error {
-	defer z.checksum.Reset()
-
-	buf := z.buf[:]
-	for {
-		magic, err := z.readUint32()
-		if err != nil {
-			z.pos += 4
-			if !first && err == io.ErrUnexpectedEOF {
-				return io.EOF
-			}
-			return err
-		}
-		if magic == frameMagic {
-			break
-		}
-		if magic>>8 != frameSkipMagic>>8 {
-			return ErrInvalid
-		}
-		skipSize, err := z.readUint32()
-		if err != nil {
-			return err
-		}
-		z.pos += 4
-		m, err := io.CopyN(ioutil.Discard, z.src, int64(skipSize))
-		if err != nil {
-			return err
-		}
-		z.pos += m
-	}
-
-	// Header.
-	if _, err := io.ReadFull(z.src, buf[:2]); err != nil {
-		return err
-	}
-	z.pos += 8
-
-	b := buf[0]
-	if v := b >> 6; v != Version {
-		return fmt.Errorf("lz4: invalid version: got %d; expected %d", v, Version)
-	}
-	if b>>5&1 == 0 {
-		return ErrBlockDependency
-	}
-	z.BlockChecksum = b>>4&1 > 0
-	frameSize := b>>3&1 > 0
-	z.NoChecksum = b>>2&1 == 0
-
-	bmsID := buf[1] >> 4 & 0x7
-	if bmsID < 4 || bmsID > 7 {
-		return fmt.Errorf("lz4: invalid block max size ID: %d", bmsID)
-	}
-	bSize := blockSizeIndexToValue(bmsID - 4)
-	z.BlockMaxSize = bSize
-
-	// Allocate the compressed/uncompressed buffers.
-	// The compressed buffer cannot exceed the uncompressed one.
-	if n := 2 * bSize; cap(z.zdata) < n {
-		z.zdata = make([]byte, n, n)
-	}
-	if debugFlag {
-		debug("header block max size id=%d size=%d", bmsID, bSize)
-	}
-	z.zdata = z.zdata[:bSize]
-	z.data = z.zdata[:cap(z.zdata)][bSize:]
-	z.idx = len(z.data)
-
-	_, _ = z.checksum.Write(buf[0:2])
+type _Reader struct {
+	state _State
+	buf   [9]byte   // frame descriptor needs at most 8+1=9 bytes
+	src   io.Reader // source reader
+	frame Frame     // frame being read
+	data  []byte    // pending data
+	idx   int       // size of pending data
+}
 
-	if frameSize {
-		buf := buf[:8]
-		if _, err := io.ReadFull(z.src, buf); err != nil {
-			return err
+// Size returns the size of the underlying uncompressed data, if set in the stream.
+func (r *_Reader) Size() int {
+	switch r.state.state {
+	case readState, closedState:
+		if r.frame.Descriptor.Flags.Size() {
+			return int(r.frame.Descriptor.ContentSize)
 		}
-		z.Size = binary.LittleEndian.Uint64(buf)
-		z.pos += 8
-		_, _ = z.checksum.Write(buf)
 	}
-
-	// Header checksum.
-	if _, err := io.ReadFull(z.src, buf[:1]); err != nil {
-		return err
-	}
-	z.pos++
-	if h := byte(z.checksum.Sum32() >> 8 & 0xFF); h != buf[0] {
-		return fmt.Errorf("lz4: invalid header checksum: got %x; expected %x", buf[0], h)
-	}
-
-	z.Header.done = true
-	if debugFlag {
-		debug("header read: %v", z.Header)
-	}
-
-	return nil
+	return 0
 }
 
-// Read decompresses data from the underlying source into the supplied buffer.
-//
-// Since there can be multiple streams concatenated, Header values may
-// change between calls to Read(). If that is the case, no data is actually read from
-// the underlying io.Reader, to allow for potential input buffer resizing.
-func (z *Reader) Read(buf []byte) (int, error) {
-	if debugFlag {
-		debug("Read buf len=%d", len(buf))
-	}
-	if !z.Header.done {
-		if err := z.readHeader(true); err != nil {
-			return 0, err
-		}
-		if debugFlag {
-			debug("header read OK compressed buffer %d / %d uncompressed buffer %d : %d index=%d",
-				len(z.zdata), cap(z.zdata), len(z.data), cap(z.data), z.idx)
+func (r *_Reader) Read(buf []byte) (n int, err error) {
+	defer r.state.check(&err)
+	switch r.state.state {
+	case closedState, errorState:
+		return 0, r.state.err
+	case newState:
+		// First initialization.
+		r.state.next(nil)
+		if err = r.frame.initR(r); r.state.next(err) {
+			return
 		}
+		r.state.next(nil)
+		r.data = r.frame.Descriptor.Flags.BlockSizeIndex().get()
+	default:
+		return 0, r.state.fail()
 	}
-
 	if len(buf) == 0 {
-		return 0, nil
+		return
 	}
 
-	if z.idx == len(z.data) {
-		// No data ready for reading, process the next block.
-		if debugFlag {
-			debug("reading block from writer")
-		}
-		// Reset uncompressed buffer
-		z.data = z.zdata[:cap(z.zdata)][len(z.zdata):]
-
-		// Block length: 0 = end of frame, highest bit set: uncompressed.
-		bLen, err := z.readUint32()
-		if err != nil {
-			return 0, err
-		}
-		z.pos += 4
-
-		if bLen == 0 {
-			// End of frame reached.
-			if !z.NoChecksum {
-				// Validate the frame checksum.
-				checksum, err := z.readUint32()
-				if err != nil {
-					return 0, err
-				}
-				if debugFlag {
-					debug("frame checksum got=%x / want=%x", z.checksum.Sum32(), checksum)
-				}
-				z.pos += 4
-				if h := z.checksum.Sum32(); checksum != h {
-					return 0, fmt.Errorf("lz4: invalid frame checksum: got %x; expected %x", h, checksum)
-				}
-			}
-
-			// Get ready for the next concatenated frame and keep the position.
-			pos := z.pos
-			z.Reset(z.src)
-			z.pos = pos
-
-			// Since multiple frames can be concatenated, check for more.
-			return 0, z.readHeader(false)
+	if r.idx > 0 {
+		// Some left over data, use it.
+		bn := copy(buf, r.data[r.idx:])
+		n += bn
+		r.idx += bn
+		if r.idx == len(r.data) {
+			// All data read, get ready for the next Read.
+			r.idx = 0
 		}
-
-		if debugFlag {
-			debug("raw block size %d", bLen)
+		return
+	}
+	// No uncompressed data yet.
+	var bn int
+	for len(buf) >= len(r.data) {
+		// Input buffer large enough and no pending data: uncompress directly into it.
+		switch bn, err = r.frame.Blocks.Block.uncompress(r, buf); err {
+		case nil:
+			n += bn
+			buf = buf[bn:]
+		case io.EOF:
+			goto close
+		default:
+			return
 		}
-		if bLen&compressedBlockFlag > 0 {
-			// Uncompressed block.
-			bLen &= compressedBlockMask
-			if debugFlag {
-				debug("uncompressed block size %d", bLen)
-			}
-			if int(bLen) > cap(z.data) {
-				return 0, fmt.Errorf("lz4: invalid block size: %d", bLen)
-			}
-			z.data = z.data[:bLen]
-			if _, err := io.ReadFull(z.src, z.data); err != nil {
-				return 0, err
-			}
-			z.pos += int64(bLen)
-			if z.OnBlockDone != nil {
-				z.OnBlockDone(int(bLen))
-			}
-
-			if z.BlockChecksum {
-				checksum, err := z.readUint32()
-				if err != nil {
-					return 0, err
-				}
-				z.pos += 4
-
-				if h := xxh32.ChecksumZero(z.data); h != checksum {
-					return 0, fmt.Errorf("lz4: invalid block checksum: got %x; expected %x", h, checksum)
-				}
-			}
-
-		} else {
-			// Compressed block.
-			if debugFlag {
-				debug("compressed block size %d", bLen)
-			}
-			if int(bLen) > cap(z.data) {
-				return 0, fmt.Errorf("lz4: invalid block size: %d", bLen)
-			}
-			zdata := z.zdata[:bLen]
-			if _, err := io.ReadFull(z.src, zdata); err != nil {
-				return 0, err
-			}
-			z.pos += int64(bLen)
-
-			if z.BlockChecksum {
-				checksum, err := z.readUint32()
-				if err != nil {
-					return 0, err
-				}
-				z.pos += 4
-
-				if h := xxh32.ChecksumZero(zdata); h != checksum {
-					return 0, fmt.Errorf("lz4: invalid block checksum: got %x; expected %x", h, checksum)
-				}
-			}
-
-			n, err := UncompressBlock(zdata, z.data)
-			if err != nil {
-				return 0, err
-			}
-			z.data = z.data[:n]
-			if z.OnBlockDone != nil {
-				z.OnBlockDone(n)
-			}
-		}
-
-		if !z.NoChecksum {
-			_, _ = z.checksum.Write(z.data)
-			if debugFlag {
-				debug("current frame checksum %x", z.checksum.Sum32())
-			}
-		}
-		z.idx = 0
-	}
-
-	if z.skip > int64(len(z.data[z.idx:])) {
-		z.skip -= int64(len(z.data[z.idx:]))
-		z.dpos += int64(len(z.data[z.idx:]))
-		z.idx = len(z.data)
-		return 0, nil
 	}
-
-	z.idx += int(z.skip)
-	z.dpos += z.skip
-	z.skip = 0
-
-	n := copy(buf, z.data[z.idx:])
-	z.idx += n
-	z.dpos += int64(n)
-	if debugFlag {
-		debug("copied %d bytes to input", n)
-	}
-
-	return n, nil
+	if n > 0 {
+		// Some data was read, done for now.
+		return
+	}
+	// Read the next block.
+	switch bn, err = r.frame.Blocks.Block.uncompress(r, r.data); err {
+	case nil:
+		n += bn
+	case io.EOF:
+		goto close
+	}
+	return
+close:
+	n += bn
+	err = r.frame.closeR(r)
+	r.frame.Descriptor.Flags.BlockSizeIndex().put(r.data)
+	r.reset(nil)
+	return
 }
 
-// Seek implements io.Seeker, but supports seeking forward from the current
-// position only. Any other seek will return an error. Allows skipping output
-// bytes which aren't needed, which in some scenarios is faster than reading
-// and discarding them.
-// Note this may cause future calls to Read() to read 0 bytes if all of the
-// data they would have returned is skipped.
-func (z *Reader) Seek(offset int64, whence int) (int64, error) {
-	if offset < 0 || whence != io.SeekCurrent {
-		return z.dpos + z.skip, ErrUnsupportedSeek
-	}
-	z.skip += offset
-	return z.dpos + z.skip, nil
+func (r *_Reader) reset(reader io.Reader) {
+	r.src = reader
+	r.data = nil
+	r.idx = 0
 }
 
-// Reset discards the Reader's state and makes it equivalent to the
-// result of its original state from NewReader, but reading from r instead.
-// This permits reusing a Reader rather than allocating a new one.
-func (z *Reader) Reset(r io.Reader) {
-	z.Header = Header{}
-	z.pos = 0
-	z.src = r
-	z.zdata = z.zdata[:0]
-	z.data = z.data[:0]
-	z.idx = 0
-	z.checksum.Reset()
+// Reset clears the state of the Reader r such that it is equivalent to its
+// initial state from NewReader, but instead reading from reader.
+// No access to reader is performed.
+//
+// Any pending uncompressed data is discarded.
+func (r *_Reader) Reset(reader io.Reader) {
+	r.reset(reader)
+	r.state.state = noState
+	r.state.next(nil)
 }
 
-// readUint32 reads an uint32 into the supplied buffer.
-// The idea is to make use of the already allocated buffers avoiding additional allocations.
-func (z *Reader) readUint32() (uint32, error) {
-	buf := z.buf[:4]
-	_, err := io.ReadFull(z.src, buf)
-	x := binary.LittleEndian.Uint32(buf)
-	return x, err
+func (r *_Reader) Seek(offset int64, whence int) (int64, error) {
+	panic("TODO")
 }

+ 65 - 0
state.go

@@ -0,0 +1,65 @@
+package lz4
+
+import (
+	"errors"
+	"fmt"
+	"io"
+)
+
+//go:generate go run golang.org/x/tools/cmd/stringer -type=aState -output state_gen.go
+
+const (
+	noState     aState = iota // uninitialized reader
+	errorState                // unrecoverable error encountered
+	newState                  // instantiated object
+	headerState               // processing header
+	readState                 // reading data
+	writeState                // writing data
+	closedState               // all done
+)
+
+type (
+	aState uint8
+	_State struct {
+		states []aState
+		state  aState
+		err    error
+	}
+)
+
+func (s *_State) init(states []aState) *_State {
+	s.states = states
+	s.state = states[0]
+	return s
+}
+
+// next advances s to its next state unless it is passed a non-nil error,
+// in which case s enters the error state. It reports whether s is in error.
+func (s *_State) next(err error) bool {
+	if err != nil {
+		s.err = fmt.Errorf("%s: %w", s.state, err)
+		s.state = errorState
+		return true
+	}
+	s.state = s.states[s.state]
+	return false
+}
+
+// check records a non-nil *errp into s.err and moves s to the error state, unless the error is io.EOF.
+func (s *_State) check(errp *error) {
+	if s.state == errorState || errp == nil {
+		return
+	}
+	if err := *errp; err != nil {
+		s.err = fmt.Errorf("%s: %w", s.state, err)
+		if !errors.Is(err, io.EOF) {
+			s.state = errorState
+		}
+	}
+}
+
+func (s *_State) fail() error {
+	s.state = errorState
+	s.err = fmt.Errorf("%w: next state for %q", ErrInternalUnhandledState, s.state)
+	return s.err
+}

+ 29 - 0
state_gen.go

@@ -0,0 +1,29 @@
+// Code generated by "stringer -type=aState -output state_gen.go"; DO NOT EDIT.
+
+package lz4
+
+import "strconv"
+
+func _() {
+	// An "invalid array index" compiler error signifies that the constant values have changed.
+	// Re-run the stringer command to generate them again.
+	var x [1]struct{}
+	_ = x[noState-0]
+	_ = x[errorState-1]
+	_ = x[newState-2]
+	_ = x[headerState-3]
+	_ = x[readState-4]
+	_ = x[writeState-5]
+	_ = x[closedState-6]
+}
+
+const _aState_name = "noStateerrorStatenewStateheaderStatereadStatewriteStateclosedState"
+
+var _aState_index = [...]uint8{0, 7, 17, 25, 36, 45, 55, 66}
+
+func (i aState) String() string {
+	if i >= aState(len(_aState_index)-1) {
+		return "aState(" + strconv.FormatInt(int64(i), 10) + ")"
+	}
+	return _aState_name[_aState_index[i]:_aState_index[i+1]]
+}

+ 119 - 356
writer.go

@@ -1,408 +1,171 @@
 package lz4
 
-import (
-	"encoding/binary"
-	"fmt"
-	"github.com/pierrec/lz4/internal/xxh32"
-	"io"
-	"runtime"
-)
-
-// zResult contains the results of compressing a block.
-type zResult struct {
-	size     uint32 // Block header
-	data     []byte // Compressed data
-	checksum uint32 // Data checksum
-}
-
-// Writer implements the LZ4 frame encoder.
-type Writer struct {
-	Header
-	// Handler called when a block has been successfully written out.
-	// It provides the number of bytes written.
-	OnBlockDone func(size int)
-
-	buf       [19]byte      // magic number(4) + header(flags(2)+[Size(8)+DictID(4)]+checksum(1)) does not exceed 19 bytes
-	dst       io.Writer     // Destination.
-	checksum  xxh32.XXHZero // Frame checksum.
-	data      []byte        // Data to be compressed + buffer for compressed data.
-	idx       int           // Index into data.
-	hashtable [winSize]int  // Hash table used in CompressBlock().
-
-	// For concurrency.
-	c   chan chan zResult // Channel for block compression goroutines and writer goroutine.
-	err error             // Any error encountered while writing to the underlying destination.
+import "io"
+
+var writerStates = []aState{
+	noState:     newState,
+	newState:    headerState,
+	headerState: writeState,
+	writeState:  closedState,
+	closedState: newState,
+	errorState:  newState,
 }
 
 // NewWriter returns a new LZ4 frame encoder.
-// No access to the underlying io.Writer is performed.
-// The supplied Header is checked at the first Write.
-// It is ok to change it before the first Write but then not until a Reset() is performed.
-func NewWriter(dst io.Writer) *Writer {
-	z := new(Writer)
-	z.Reset(dst)
-	return z
-}
-
-// WithConcurrency sets the number of concurrent go routines used for compression.
-// A negative value sets the concurrency to GOMAXPROCS.
-func (z *Writer) WithConcurrency(n int) *Writer {
-	switch {
-	case n == 0 || n == 1:
-		z.c = nil
-		return z
-	case n < 0:
-		n = runtime.GOMAXPROCS(0)
-	}
-	z.c = make(chan chan zResult, n)
-	// Writer goroutine managing concurrent block compression goroutines.
-	go func() {
-		// Process next block compression item.
-		for c := range z.c {
-			// Read the next compressed block result.
-			// Waiting here ensures that the blocks are output in the order they were sent.
-			// The incoming channel is always closed as it indicates to the caller that
-			// the block has been processed.
-			res := <-c
-			n := len(res.data)
-			if n == 0 {
-				// Notify the block compression routine that we are done with its result.
-				// This is used when a sentinel block is sent to terminate the compression.
-				close(c)
-				return
-			}
-			// Write the block.
-			if err := z.writeUint32(res.size); err != nil && z.err == nil {
-				z.err = err
-			}
-			if _, err := z.dst.Write(res.data); err != nil && z.err == nil {
-				z.err = err
-			}
-			if z.BlockChecksum {
-				if err := z.writeUint32(res.checksum); err != nil && z.err == nil {
-					z.err = err
-				}
-			}
-			if isCompressed := res.size&compressedBlockFlag == 0; isCompressed {
-				// It is now safe to release the buffer as no longer in use by any goroutine.
-				putBuffer(cap(res.data), res.data)
-			}
-			if h := z.OnBlockDone; h != nil {
-				h(n)
-			}
-			close(c)
-		}
-	}()
-	return z
+func NewWriter(w io.Writer, options ...Option) (io.WriteCloser, error) {
+	zw := new(_Writer)
+	_ = defaultBlockSizeOption(zw)
+	_ = defaultChecksumOption(zw)
+	_ = defaultConcurrency(zw)
+	if err := zw.Reset(w, options...); err != nil {
+		return nil, err
+	}
+	return zw, nil
 }
 
-// newBuffers instantiates new buffers which size matches the one in Header.
-// The returned buffers are for decompression and compression respectively.
-func (z *Writer) newBuffers() {
-	bSize := z.Header.BlockMaxSize
-	buf := getBuffer(bSize)
-	z.data = buf[:bSize] // Uncompressed buffer is the first half.
+type _Writer struct {
+	state _State
+	buf   [11]byte         // frame descriptor needs at most 4+8+1=11 bytes
+	src   io.Writer        // destination writer
+	level CompressionLevel // how hard to try
+	num   int              // concurrency level
+	frame Frame            // frame being built
+	ht    []int            // hash table (set if no concurrency)
+	data  []byte           // pending data
+	idx   int              // size of pending data
 }
 
-// freeBuffers puts the writer's buffers back to the pool.
-func (z *Writer) freeBuffers() {
-	// Put the buffer back into the pool, if any.
-	putBuffer(z.Header.BlockMaxSize, z.data)
-	z.data = nil
+func (w *_Writer) isNotConcurrent() bool {
+	return w.num == 1
 }
 
-// writeHeader builds and writes the header (magic+header) to the underlying io.Writer.
-func (z *Writer) writeHeader() error {
-	// Default to 4Mb if BlockMaxSize is not set.
-	if z.Header.BlockMaxSize == 0 {
-		z.Header.BlockMaxSize = blockSize4M
-	}
-	// The only option that needs to be validated.
-	bSize := z.Header.BlockMaxSize
-	if !isValidBlockSize(z.Header.BlockMaxSize) {
-		return fmt.Errorf("lz4: invalid block max size: %d", bSize)
-	}
-	// Allocate the compressed/uncompressed buffers.
-	// The compressed buffer cannot exceed the uncompressed one.
-	z.newBuffers()
-	z.idx = 0
-
-	// Size is optional.
-	buf := z.buf[:]
-
-	// Set the fixed size data: magic number, block max size and flags.
-	binary.LittleEndian.PutUint32(buf[0:], frameMagic)
-	flg := byte(Version << 6)
-	flg |= 1 << 5 // No block dependency.
-	if z.Header.BlockChecksum {
-		flg |= 1 << 4
-	}
-	if z.Header.Size > 0 {
-		flg |= 1 << 3
-	}
-	if !z.Header.NoChecksum {
-		flg |= 1 << 2
-	}
-	buf[4] = flg
-	buf[5] = blockSizeValueToIndex(z.Header.BlockMaxSize) << 4
-
-	// Current buffer size: magic(4) + flags(1) + block max size (1).
-	n := 6
-	// Optional items.
-	if z.Header.Size > 0 {
-		binary.LittleEndian.PutUint64(buf[n:], z.Header.Size)
-		n += 8
-	}
-
-	// The header checksum includes the flags, block max size and optional Size.
-	buf[n] = byte(xxh32.ChecksumZero(buf[4:n]) >> 8 & 0xFF)
-	z.checksum.Reset()
-
-	// Header ready, write it out.
-	if _, err := z.dst.Write(buf[0 : n+1]); err != nil {
-		return err
-	}
-	z.Header.done = true
-	if debugFlag {
-		debug("wrote header %v", z.Header)
-	}
-
-	return nil
-}
-
-// Write compresses data from the supplied buffer into the underlying io.Writer.
-// Write does not return until the data has been written.
-func (z *Writer) Write(buf []byte) (int, error) {
-	if !z.Header.done {
-		if err := z.writeHeader(); err != nil {
-			return 0, err
+func (w *_Writer) Write(buf []byte) (n int, err error) {
+	defer w.state.check(&err)
+	switch w.state.state {
+	case closedState, errorState:
+		return 0, w.state.err
+	case newState:
+		w.state.next(nil)
+		if err = w.frame.Descriptor.write(w); w.state.next(err) {
+			return
 		}
-	}
-	if debugFlag {
-		debug("input buffer len=%d index=%d", len(buf), z.idx)
+	default:
+		return 0, w.state.fail()
 	}
 
-	zn := len(z.data)
-	var n int
+	zn := len(w.data)
 	for len(buf) > 0 {
-		if z.idx == 0 && len(buf) >= zn {
+		if w.idx == 0 && len(buf) >= zn {
 			// Avoid a copy as there is enough data for a block.
-			if err := z.compressBlock(buf[:zn]); err != nil {
-				return n, err
+			if err = w.write(); err != nil {
+				return
 			}
 			n += zn
 			buf = buf[zn:]
 			continue
 		}
 		// Accumulate the data to be compressed.
-		m := copy(z.data[z.idx:], buf)
+		m := copy(w.data[w.idx:], buf)
 		n += m
-		z.idx += m
+		w.idx += m
 		buf = buf[m:]
-		if debugFlag {
-			debug("%d bytes copied to buf, current index %d", n, z.idx)
-		}
 
-		if z.idx < len(z.data) {
+		if w.idx < len(w.data) {
 			// Buffer not filled.
-			if debugFlag {
-				debug("need more data for compression")
-			}
-			return n, nil
+			return
 		}
 
 		// Buffer full.
-		if err := z.compressBlock(z.data); err != nil {
-			return n, err
+		if err = w.write(); err != nil {
+			return
 		}
-		z.idx = 0
+		w.idx = 0
 	}
-
-	return n, nil
+	return
 }
 
-// compressBlock compresses a block.
-func (z *Writer) compressBlock(data []byte) error {
-	if !z.NoChecksum {
-		_, _ = z.checksum.Write(data)
-	}
-
-	if z.c != nil {
-		c := make(chan zResult)
-		z.c <- c // Send now to guarantee order
-		go writerCompressBlock(c, z.Header, data)
-		return nil
-	}
-
-	zdata := z.data[z.Header.BlockMaxSize:cap(z.data)]
-	// The compressed block size cannot exceed the input's.
-	var zn int
-
-	if level := z.Header.CompressionLevel; level != 0 {
-		zn, _ = CompressBlockHC(data, zdata, level)
-	} else {
-		zn, _ = CompressBlock(data, zdata, z.hashtable[:])
-	}
-
-	var bLen uint32
-	if debugFlag {
-		debug("block compression %d => %d", len(data), zn)
-	}
-	if zn > 0 && zn < len(data) {
-		// Compressible and compressed size smaller than uncompressed: ok!
-		bLen = uint32(zn)
-		zdata = zdata[:zn]
-	} else {
-		// Uncompressed block.
-		bLen = uint32(len(data)) | compressedBlockFlag
-		zdata = data
-	}
-	if debugFlag {
-		debug("block compression to be written len=%d data len=%d", bLen, len(zdata))
-	}
-
-	// Write the block.
-	if err := z.writeUint32(bLen); err != nil {
-		return err
-	}
-	written, err := z.dst.Write(zdata)
-	if err != nil {
-		return err
-	}
-	if h := z.OnBlockDone; h != nil {
-		h(written)
-	}
-
-	if !z.BlockChecksum {
-		if debugFlag {
-			debug("current frame checksum %x", z.checksum.Sum32())
+func (w *_Writer) write() error {
+	if w.isNotConcurrent() {
+		return w.frame.Blocks.Block.compress(w, w.data, w.ht).write(w)
+	}
+	size := w.frame.Descriptor.Flags.BlockSizeIndex()
+	c := make(chan *FrameDataBlock)
+	w.frame.Blocks.Blocks <- c
+	go func(c chan *FrameDataBlock, data []byte, size BlockSizeIndex) {
+		b := newFrameDataBlock(size)
+		zdata := b.Data
+		c <- b.compress(w, data, nil)
+		// Wait for the compressed or uncompressed data to no longer be in use
+		// and free the allocated buffers
+		if !b.Size.compressed() {
+			zdata, data = data, zdata
 		}
-		return nil
-	}
-	checksum := xxh32.ChecksumZero(zdata)
-	if debugFlag {
-		debug("block checksum %x", checksum)
-		defer func() { debug("current frame checksum %x", z.checksum.Sum32()) }()
-	}
-	return z.writeUint32(checksum)
-}
+		size.put(data)
+		<-c
+		size.put(zdata)
+	}(c, w.data, size)
 
-// Flush flushes any pending compressed data to the underlying writer.
-// Flush does not return until the data has been written.
-// If the underlying writer returns an error, Flush returns that error.
-func (z *Writer) Flush() error {
-	if debugFlag {
-		debug("flush with index %d", z.idx)
-	}
-	if z.idx == 0 {
-		return nil
+	if w.idx > 0 {
+		// Not closed.
+		w.data = size.get()
 	}
+	w.idx = 0
 
-	data := z.data[:z.idx]
-	z.idx = 0
-	if z.c == nil {
-		return z.compressBlock(data)
-	}
-	if !z.NoChecksum {
-		_, _ = z.checksum.Write(data)
-	}
-	c := make(chan zResult)
-	z.c <- c
-	writerCompressBlock(c, z.Header, data)
 	return nil
 }
 
-func (z *Writer) close() error {
-	if z.c == nil {
+// Close closes the Writer, flushing any unwritten data to the underlying io.Writer,
+// but does not close the underlying io.Writer.
+func (w *_Writer) Close() error {
+	switch w.state.state {
+	case writeState:
+	case errorState:
+		return w.state.err
+	default:
 		return nil
 	}
-	// Send a sentinel block (no data to compress) to terminate the writer main goroutine.
-	c := make(chan zResult)
-	z.c <- c
-	c <- zResult{}
-	// Wait for the main goroutine to complete.
-	<-c
-	// At this point the main goroutine has shut down or is about to return.
-	z.c = nil
-	return z.err
-}
-
-// Close closes the Writer, flushing any unwritten data to the underlying io.Writer, but does not close the underlying io.Writer.
-func (z *Writer) Close() error {
-	if !z.Header.done {
-		if err := z.writeHeader(); err != nil {
+	var err error
+	defer func() { w.state.next(err) }()
+	if idx := w.idx; idx > 0 {
+		// Flush pending data.
+		w.data = w.data[:idx]
+		w.idx = 0
+		if err = w.write(); err != nil {
 			return err
 		}
+		w.data = nil
 	}
-	if err := z.Flush(); err != nil {
-		return err
-	}
-	if err := z.close(); err != nil {
-		return err
-	}
-	z.freeBuffers()
-
-	if debugFlag {
-		debug("writing last empty block")
-	}
-	if err := z.writeUint32(0); err != nil {
-		return err
+	if w.isNotConcurrent() {
+		htPool.Put(w.ht)
+		size := w.frame.Descriptor.Flags.BlockSizeIndex()
+		size.put(w.data)
 	}
-	if z.NoChecksum {
-		return nil
-	}
-	checksum := z.checksum.Sum32()
-	if debugFlag {
-		debug("stream checksum %x", checksum)
-	}
-	return z.writeUint32(checksum)
-}
-
-// Reset clears the state of the Writer z such that it is equivalent to its
-// initial state from NewWriter, but instead writing to w.
-// No access to the underlying io.Writer is performed.
-func (z *Writer) Reset(w io.Writer) {
-	n := cap(z.c)
-	_ = z.close()
-	z.freeBuffers()
-	z.Header.Reset()
-	z.dst = w
-	z.checksum.Reset()
-	z.idx = 0
-	z.err = nil
-	z.WithConcurrency(n)
+	return w.frame.closeW(w)
 }
 
-// writeUint32 writes a uint32 to the underlying writer.
-func (z *Writer) writeUint32(x uint32) error {
-	buf := z.buf[:4]
-	binary.LittleEndian.PutUint32(buf, x)
-	_, err := z.dst.Write(buf)
-	return err
-}
-
-// writerCompressBlock compresses data into a pooled buffer and writes its result
-// out to the input channel.
-func writerCompressBlock(c chan zResult, header Header, data []byte) {
-	zdata := getBuffer(header.BlockMaxSize)
-	// The compressed block size cannot exceed the input's.
-	var zn int
-	if level := header.CompressionLevel; level != 0 {
-		zn, _ = CompressBlockHC(data, zdata, level)
-	} else {
-		var hashTable [winSize]int
-		zn, _ = CompressBlock(data, zdata, hashTable[:])
+// Reset clears the state of the Writer w such that it is equivalent to its
+// initial state from NewWriter, but instead writing to writer.
+// Reset keeps the previous options unless overwritten by the supplied ones.
+// No access to writer is performed.
+//
+// w.Close must be called before Reset.
+func (w *_Writer) Reset(writer io.Writer, options ...Option) (err error) {
+	for _, o := range options {
+		if err = o(w); err != nil {
+			break
+		}
 	}
-	var res zResult
-	if zn > 0 && zn < len(data) {
-		res.size = uint32(zn)
-		res.data = zdata[:zn]
-	} else {
-		res.size = uint32(len(data)) | compressedBlockFlag
-		res.data = data
+	w.state.state = noState
+	if w.state.next(err) {
+		return
 	}
-	if header.BlockChecksum {
-		res.checksum = xxh32.ChecksumZero(res.data)
+	w.src = writer
+	w.frame.initW(w)
+	size := w.frame.Descriptor.Flags.BlockSizeIndex()
+	w.data = size.get()
+	w.idx = 0
+	if w.isNotConcurrent() {
+		w.ht = htPool.Get().([]int)
 	}
-	c <- res
+	return nil
 }

+ 12 - 14
writer_test.go

@@ -8,8 +8,6 @@ import (
 	"os"
 	"reflect"
 	"testing"
-
-	"github.com/pierrec/lz4"
 )
 
 func TestWriter(t *testing.T) {
@@ -26,7 +24,7 @@ func TestWriter(t *testing.T) {
 
 	for _, fname := range goldenFiles {
 		for _, size := range []int{0, 4} {
-			for _, header := range []lz4.Header{
+			for _, header := range []Header{
 				{}, // Default header.
 				{BlockChecksum: true},
 				{NoChecksum: true},
@@ -48,7 +46,7 @@ func TestWriter(t *testing.T) {
 
 					// Compress.
 					var zout bytes.Buffer
-					zw := lz4.NewWriter(&zout)
+					zw := NewWriter(&zout)
 					zw.Header = header
 					zw.WithConcurrency(size)
 					_, err = io.Copy(zw, r)
@@ -62,7 +60,7 @@ func TestWriter(t *testing.T) {
 
 					// Uncompress.
 					var out bytes.Buffer
-					zr := lz4.NewReader(&zout)
+					zr := NewReader(&zout)
 					n, err := io.Copy(&out, zr)
 					if err != nil {
 						t.Fatal(err)
@@ -84,8 +82,8 @@ func TestWriter(t *testing.T) {
 
 func TestIssue41(t *testing.T) {
 	r, w := io.Pipe()
-	zw := lz4.NewWriter(w)
-	zr := lz4.NewReader(r)
+	zw := NewWriter(w)
+	zr := NewReader(r)
 
 	data := "x"
 	go func() {
@@ -112,7 +110,7 @@ func TestIssue43(t *testing.T) {
 		}
 		defer f.Close()
 
-		zw := lz4.NewWriter(w)
+		zw := NewWriter(w)
 		defer zw.Close()
 
 		_, err = io.Copy(zw, f)
@@ -120,7 +118,7 @@ func TestIssue43(t *testing.T) {
 			t.Fatal(err)
 		}
 	}()
-	_, err := io.Copy(ioutil.Discard, lz4.NewReader(r))
+	_, err := io.Copy(ioutil.Discard, NewReader(r))
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -135,14 +133,14 @@ func TestIssue51(t *testing.T) {
 	zbuf := make([]byte, 8192)
 	ht := make([]int, htSize)
 
-	n, err := lz4.CompressBlock(data, zbuf, ht)
+	n, err := CompressBlock(data, zbuf, ht)
 	if err != nil {
 		t.Fatal(err)
 	}
 	zbuf = zbuf[:n]
 
 	buf := make([]byte, 8192)
-	n, err = lz4.UncompressBlock(zbuf, buf)
+	n, err = UncompressBlock(zbuf, buf)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -159,11 +157,11 @@ func TestIssue71(t *testing.T) {
 	} {
 		t.Run(tc, func(t *testing.T) {
 			src := []byte(tc)
-			bound := lz4.CompressBlockBound(len(tc))
+			bound := CompressBlockBound(len(tc))
 
 			// Small buffer.
 			zSmall := make([]byte, bound-1)
-			n, err := lz4.CompressBlock(src, zSmall, nil)
+			n, err := CompressBlock(src, zSmall, nil)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -173,7 +171,7 @@ func TestIssue71(t *testing.T) {
 
 			// Large enough buffer.
 			zLarge := make([]byte, bound)
-			n, err = lz4.CompressBlock(src, zLarge, nil)
+			n, err = CompressBlock(src, zLarge, nil)
 			if err != nil {
 				t.Fatal(err)
 			}