
Writer: added concurrency support. Fixes #55

Pierre Curto, 6 years ago
parent commit 9085dacd1e
4 changed files with 157 additions and 140 deletions
  1. cmd/lz4c/compress.go (+3 -2)
  2. lz4.go (+6 -0)
  3. writer.go (+95 -88)
  4. writer_test.go (+53 -50)
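
The headline change is the new Writer.WithConcurrency method. A minimal usage sketch, assuming the usual github.com/pierrec/lz4 import path and only the API visible in this diff (judging by the test changes below, 0 keeps the serial path):

	package main

	import (
		"io"
		"os"

		"github.com/pierrec/lz4"
	)

	func main() {
		zw := lz4.NewWriter(os.Stdout)
		zw.WithConcurrency(4) // compress up to 4 blocks in flight; 0 disables concurrency
		if _, err := io.Copy(zw, os.Stdin); err != nil {
			panic(err)
		}
		if err := zw.Close(); err != nil {
			panic(err)
		}
	}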

cmd/lz4c/compress.go (+3 -2)

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"io"
 	"os"
+	"sync/atomic"
 
 	"code.cloudfoundry.org/bytefmt"
 	"github.com/schollz/progressbar"
@@ -65,7 +66,7 @@ func Compress(fs *flag.FlagSet) cmdflag.Handler {
 
 			// Accumulate compressed bytes num.
 			var (
-				zsize int
+				zsize int64
 				size  = finfo.Size()
 			)
 			if size > 0 {
@@ -80,7 +81,7 @@ func Compress(fs *flag.FlagSet) cmdflag.Handler {
 				)
 				zw.OnBlockDone = func(n int) {
 					_ = bar.Add(1)
-					zsize += n
+					atomic.AddInt64(&zsize, int64(n))
 				}
 			}
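
Since OnBlockDone can now fire from the Writer's output goroutine rather than from the goroutine calling Write, the counter moves to an int64 updated through sync/atomic. Any read that can race with in-flight blocks should go through an atomic load as well; a small sketch in the context of the compress.go code above (the final print is my own illustration, not code from this commit):

	var zsize int64
	zw.OnBlockDone = func(n int) {
		atomic.AddInt64(&zsize, int64(n)) // may run on the writer's output goroutine
	}
	// ... stream data through zw, then:
	fmt.Println("compressed bytes:", atomic.LoadInt64(&zsize))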
 

lz4.go (+6 -0)

@@ -68,6 +68,12 @@ func newBufferPool(size int) *sync.Pool {
 	}
 }
 
+// getBuffer retrieves a buffer for the given size from its pool.
+func getBuffer(size int) []byte {
+	idx := blockSizeValueToIndex(size) - 4
+	return bsMapValue[idx].Get().([]byte)
+}
+
 // putBuffer returns a buffer to its pool.
 func putBuffer(size int, buf []byte) {
 	if cap(buf) > 0 {

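getBuffer is the missing counterpart to the existing putBuffer: buffers are recycled through per-size sync.Pool instances, so each compression goroutine can grab a scratch buffer without allocating. A stand-alone sketch of the same pattern with hypothetical names (the real code indexes a bsMapValue slice via blockSizeValueToIndex):

	// One pool per supported block size; illustrative only.
	var pools = map[int]*sync.Pool{
		64 << 10:  {New: func() interface{} { return make([]byte, 64<<10) }},
		256 << 10: {New: func() interface{} { return make([]byte, 256<<10) }},
	}

	func get(size int) []byte { return pools[size].Get().([]byte) }

	func put(size int, buf []byte) {
		if cap(buf) > 0 {
			pools[size].Put(buf[:cap(buf)])
		}
	}
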
writer.go (+95 -88)

@@ -40,7 +40,6 @@ type Writer struct {
 // It is ok to change it before the first Write but then not until a Reset() is performed.
 func NewWriter(dst io.Writer) *Writer {
 	z := new(Writer)
-	//z.WithConcurrency(4)
 	z.Reset(dst)
 	return z
 }
@@ -62,6 +61,8 @@ func (z *Writer) WithConcurrency(n int) *Writer {
 		for c := range z.c {
 			// Read the next compressed block result.
 			// Waiting here ensures that the blocks are output in the order they were sent.
+			// The incoming channel is always closed, as closing it signals to the
+			// caller that the block has been processed.
 			res := <-c
 			n := len(res.data)
 			if n == 0 {
@@ -82,9 +83,14 @@ func (z *Writer) WithConcurrency(n int) *Writer {
 					z.err = err
 				}
 			}
+			if isCompressed := res.size&compressedBlockFlag == 0; isCompressed {
+				// It is now safe to release the buffer, as it is no longer in use by any goroutine.
+				putBuffer(cap(res.data), res.data)
+			}
 			if h := z.OnBlockDone; h != nil {
 				h(n)
 			}
+			close(c)
 		}
 	}()
 	return z
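
The ordering guarantee rests on a channel-of-channels pattern: the producer sends a fresh per-block channel into z.c before spawning the compression goroutine, and this single consumer drains those channels strictly in arrival order, so blocks are written out in input order no matter which goroutine finishes first. A self-contained sketch of the pattern, independent of the lz4 types:

	package main

	import (
		"fmt"
		"strings"
	)

	func main() {
		results := make(chan chan string, 4) // plays the role of z.c
		done := make(chan struct{})

		// Single consumer: receives per-item channels in enqueue order, so
		// output follows input order regardless of goroutine scheduling.
		go func() {
			defer close(done)
			for c := range results {
				fmt.Println(<-c)
			}
		}()

		for _, w := range []string{"alpha", "beta", "gamma"} {
			c := make(chan string)
			results <- c // enqueue before starting the work, as Write does
			go func(w string, c chan string) {
				c <- strings.ToUpper(w) // stand-in for compressing one block
			}(w, c)
		}
		close(results)
		<-done
	}
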
@@ -94,8 +100,7 @@ func (z *Writer) WithConcurrency(n int) *Writer {
 // The returned buffers are for decompression and compression respectively.
 func (z *Writer) newBuffers() {
 	bSize := z.Header.BlockMaxSize
-	idx := blockSizeValueToIndex(bSize) - 4
-	buf := bsMapValue[idx].Get().([]byte)
+	buf := getBuffer(bSize)
 	z.data = buf[:bSize] // Uncompressed buffer is the first half.
 }
 
@@ -219,93 +224,67 @@ func (z *Writer) Write(buf []byte) (int, error) {
 // compressBlock compresses a block.
 func (z *Writer) compressBlock(data []byte) error {
 	if !z.NoChecksum {
-		z.checksum.Write(data)
+		_, _ = z.checksum.Write(data)
+	}
+
+	if z.c != nil {
+		c := make(chan zResult)
+		z.c <- c // Send now to guarantee order
+		go writerCompressBlock(c, z.Header, data)
+		return nil
 	}
 
 	zdata := z.data[z.Header.BlockMaxSize:cap(z.data)]
-	if z.c == nil {
-		// The compressed block size cannot exceed the input's.
-		var zn int
+	// The compressed block size cannot exceed the input's.
+	var zn int
 
-		if level := z.Header.CompressionLevel; level != 0 {
-			zn, _ = CompressBlockHC(data, zdata, level)
-		} else {
-			zn, _ = CompressBlock(data, zdata, z.hashtable[:])
-		}
+	if level := z.Header.CompressionLevel; level != 0 {
+		zn, _ = CompressBlockHC(data, zdata, level)
+	} else {
+		zn, _ = CompressBlock(data, zdata, z.hashtable[:])
+	}
 
-		var bLen uint32
-		if debugFlag {
-			debug("block compression %d => %d", len(data), zn)
-		}
-		if zn > 0 && zn < len(data) {
-			// Compressible and compressed size smaller than uncompressed: ok!
-			bLen = uint32(zn)
-			zdata = zdata[:zn]
-		} else {
-			// Uncompressed block.
-			bLen = uint32(len(data)) | compressedBlockFlag
-			zdata = data
-		}
-		if debugFlag {
-			debug("block compression to be written len=%d data len=%d", bLen, len(zdata))
-		}
+	var bLen uint32
+	if debugFlag {
+		debug("block compression %d => %d", len(data), zn)
+	}
+	if zn > 0 && zn < len(data) {
+		// Compressible and compressed size smaller than uncompressed: ok!
+		bLen = uint32(zn)
+		zdata = zdata[:zn]
+	} else {
+		// Uncompressed block.
+		bLen = uint32(len(data)) | compressedBlockFlag
+		zdata = data
+	}
+	if debugFlag {
+		debug("block compression to be written len=%d data len=%d", bLen, len(zdata))
+	}
 
-		// Write the block.
-		if err := z.writeUint32(bLen); err != nil {
-			return err
-		}
-		written, err := z.dst.Write(zdata)
-		if err != nil {
-			return err
-		}
-		if h := z.OnBlockDone; h != nil {
-			h(written)
-		}
+	// Write the block.
+	if err := z.writeUint32(bLen); err != nil {
+		return err
+	}
+	written, err := z.dst.Write(zdata)
+	if err != nil {
+		return err
+	}
+	if h := z.OnBlockDone; h != nil {
+		h(written)
+	}
 
-		if !z.BlockChecksum {
-			if debugFlag {
-				debug("current frame checksum %x", z.checksum.Sum32())
-			}
-			return nil
-		}
-		checksum := xxh32.ChecksumZero(zdata)
+	if !z.BlockChecksum {
 		if debugFlag {
-			debug("block checksum %x", checksum)
-			defer func() { debug("current frame checksum %x", z.checksum.Sum32()) }()
+			debug("current frame checksum %x", z.checksum.Sum32())
 		}
-		return z.writeUint32(checksum)
+		return nil
 	}
-
-	odata := z.data
-	z.newBuffers()
-	c := make(chan zResult)
-	z.c <- c // Send now to guarantee order
-	go func(header Header) {
-		// The compressed block size cannot exceed the input's.
-		var zn int
-		if level := header.CompressionLevel; level != 0 {
-			zn, _ = CompressBlockHC(data, zdata, level)
-		} else {
-			var hashTable [winSize]int
-			zn, _ = CompressBlock(data, zdata, hashTable[:])
-		}
-		var res zResult
-		if zn > 0 && zn < len(data) {
-			// Compressible and compressed size smaller than uncompressed: ok!
-			res.size = uint32(zn)
-			res.data = zdata[:zn]
-		} else {
-			// Uncompressed block.
-			res.size = uint32(len(data)) | compressedBlockFlag
-			res.data = data
-		}
-		if header.BlockChecksum {
-			res.checksum = xxh32.ChecksumZero(res.data)
-		}
-		c <- res
-		putBuffer(header.BlockMaxSize, odata)
-	}(z.Header)
-	return nil
+	checksum := xxh32.ChecksumZero(zdata)
+	if debugFlag {
+		debug("block checksum %x", checksum)
+		defer func() { debug("current frame checksum %x", z.checksum.Sum32()) }()
+	}
+	return z.writeUint32(checksum)
 }
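
For context on the framing both paths share: the 4-byte block header carries the stored length, and compressedBlockFlag is the high bit which, despite its name, marks a block stored uncompressed because compression did not shrink it. That is why the consumer's res.size&compressedBlockFlag == 0 test above identifies compressed blocks, whose pooled buffer can then be recycled. A worked example, assuming the flag is 1<<31 as in the LZ4 frame format:

	const compressedBlockFlag = 1 << 31 // assumption: high bit, per the LZ4 frame format

	// A 100 KB block that compressed down to 40 KB:
	bLen := uint32(40 << 10) // 0x0000A000: high bit clear, compressed

	// The same block stored raw because compression gained nothing:
	rawLen := uint32(100<<10) | compressedBlockFlag // 0x80019000: high bit set, uncompressed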
 
 // Flush flushes any pending compressed data to the underlying writer.
@@ -319,15 +298,17 @@ func (z *Writer) Flush() error {
 		return nil
 	}
 
-	// Disable concurrency for now.
-	c := z.c
-	z.c = nil
-	if err := z.compressBlock(z.data[:z.idx]); err != nil {
-		return err
-	}
-	z.c = c // Restore concurrency.
-
+	data := z.data[:z.idx]
 	z.idx = 0
+	if z.c == nil {
+		return z.compressBlock(data)
+	}
+	if !z.NoChecksum {
+		_, _ = z.checksum.Write(data)
+	}
+	c := make(chan zResult)
+	z.c <- c
+	writerCompressBlock(c, z.Header, data)
 	return nil
 }
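
Note the deliberate asymmetry with compressBlock: Flush runs writerCompressBlock on the caller's goroutine instead of spawning it. Because the per-block channel is unbuffered, the send inside writerCompressBlock cannot complete until the output goroutine has accepted the result, so by the time Flush returns the partial block has been handed off, in order, ahead of any subsequent writes. This replaces the old workaround of temporarily nil-ing z.c to disable concurrency around the flush.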
 
@@ -399,3 +380,29 @@ func (z *Writer) writeUint32(x uint32) error {
 	_, err := z.dst.Write(buf)
 	return err
 }
+
+// writerCompressBlock compresses data into a pooled buffer and sends the
+// result on the supplied channel.
+func writerCompressBlock(c chan zResult, header Header, data []byte) {
+	zdata := getBuffer(header.BlockMaxSize)
+	// The compressed block size cannot exceed the input's.
+	var zn int
+	if level := header.CompressionLevel; level != 0 {
+		zn, _ = CompressBlockHC(data, zdata, level)
+	} else {
+		var hashTable [winSize]int
+		zn, _ = CompressBlock(data, zdata, hashTable[:])
+	}
+	var res zResult
+	if zn > 0 && zn < len(data) {
+		res.size = uint32(zn)
+		res.data = zdata[:zn]
+	} else {
+		res.size = uint32(len(data)) | compressedBlockFlag
+		res.data = data
+	}
+	if header.BlockChecksum {
+		res.checksum = xxh32.ChecksumZero(res.data)
+	}
+	c <- res
+}
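
One detail worth calling out: each writerCompressBlock invocation declares its own [winSize]int hash table instead of reusing z.hashtable, since several of these goroutines may compress at once; only the serial path inside compressBlock can safely share the Writer's table.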

writer_test.go (+53 -50)

@@ -25,56 +25,59 @@ func TestWriter(t *testing.T) {
 	}
 
 	for _, fname := range goldenFiles {
-		for _, header := range []lz4.Header{
-			{}, // Default header.
-			{BlockChecksum: true},
-			{NoChecksum: true},
-			{BlockMaxSize: 64 << 10}, // 64Kb
-			{CompressionLevel: 10},
-			{Size: 123},
-		} {
-			label := fmt.Sprintf("%s/%s", fname, header)
-			t.Run(label, func(t *testing.T) {
-				fname := fname
-				header := header
-				t.Parallel()
-
-				raw, err := ioutil.ReadFile(fname)
-				if err != nil {
-					t.Fatal(err)
-				}
-				r := bytes.NewReader(raw)
-
-				// Compress.
-				var zout bytes.Buffer
-				zw := lz4.NewWriter(&zout)
-				zw.Header = header
-				_, err = io.Copy(zw, r)
-				if err != nil {
-					t.Fatal(err)
-				}
-				err = zw.Close()
-				if err != nil {
-					t.Fatal(err)
-				}
-
-				// Uncompress.
-				var out bytes.Buffer
-				zr := lz4.NewReader(&zout)
-				n, err := io.Copy(&out, zr)
-				if err != nil {
-					t.Fatal(err)
-				}
-
-				// The uncompressed data must be the same as the initial input.
-				if got, want := int(n), len(raw); got != want {
-					t.Errorf("invalid sizes: got %d; want %d", got, want)
-				}
-
-				if got, want := out.Bytes(), raw; !reflect.DeepEqual(got, want) {
-					t.Fatal("uncompressed data does not match original")
-				}
-			})
+		for _, size := range []int{0, 4} { // concurrency level; 0 exercises the serial path
+			for _, header := range []lz4.Header{
+				{}, // Default header.
+				{BlockChecksum: true},
+				{NoChecksum: true},
+				{BlockMaxSize: 64 << 10}, // 64KB
+				{CompressionLevel: 10},
+				{Size: 123},
+			} {
+				label := fmt.Sprintf("%s/%d/%s", fname, size, header)
+				t.Run(label, func(t *testing.T) {
+					fname := fname
+					header := header
+					size := size
+					t.Parallel()
+
+					raw, err := ioutil.ReadFile(fname)
+					if err != nil {
+						t.Fatal(err)
+					}
+					r := bytes.NewReader(raw)
+
+					// Compress.
+					var zout bytes.Buffer
+					zw := lz4.NewWriter(&zout)
+					zw.Header = header
+					zw.WithConcurrency(size)
+					_, err = io.Copy(zw, r)
+					if err != nil {
+						t.Fatal(err)
+					}
+					err = zw.Close()
+					if err != nil {
+						t.Fatal(err)
+					}
+
+					// Uncompress.
+					var out bytes.Buffer
+					zr := lz4.NewReader(&zout)
+					n, err := io.Copy(&out, zr)
+					if err != nil {
+						t.Fatal(err)
+					}
+
+					// The uncompressed data must be the same as the initial input.
+					if got, want := int(n), len(raw); got != want {
+						t.Errorf("invalid sizes: got %d; want %d", got, want)
+					}
+
+					if got, want := out.Bytes(), raw; !reflect.DeepEqual(got, want) {
+						t.Fatal("uncompressed data does not match original")
+					}
+				})
+			}
 		}
 	}
 }