Browse Source

fixed concurrent Writer

Pierre.Curto 5 years ago
parent
commit
98c91cba6b
3 changed files with 39 additions and 34 deletions
  1. 18 12
      internal/lz4stream/frame.go
  2. 20 21
      writer.go
  3. 1 1
      writer_test.go

+ 18 - 12
internal/lz4stream/frame.go

@@ -177,6 +177,7 @@ func (b *Blocks) initW(f *Frame, dst io.Writer, num int) {
 		b.Block = NewFrameDataBlock(size)
 		return
 	}
+	b.Block = nil
 	if cap(b.Blocks) != num {
 		b.Blocks = make(chan chan *FrameDataBlock, num)
 	}
@@ -211,8 +212,7 @@ func (b *Blocks) initW(f *Frame, dst io.Writer, num int) {
 
 func (b *Blocks) closeW(f *Frame, num int) error {
 	if num == 1 {
-		b.Block.closeW(f)
-		b.Block = nil
+		b.Block.CloseW(f)
 		return nil
 	}
 	c := make(chan *FrameDataBlock)
@@ -235,18 +235,24 @@ func NewFrameDataBlock(size lz4block.BlockSizeIndex) *FrameDataBlock {
 
 type FrameDataBlock struct {
 	Size     DataBlockSize
-	Data     []byte
+	Data     []byte // (un)compressed data
 	Checksum uint32
+	ref      []byte
+	src      []byte // uncompressed data
 }
 
-func (b *FrameDataBlock) closeW(f *Frame) {
+func (b *FrameDataBlock) CloseW(f *Frame) {
 	size := f.Descriptor.Flags.BlockSizeIndex()
-	size.Put(b.Data)
+	size.Put(b.ref)
+	b.Data = nil
+	b.ref = nil
+	b.src = nil
 }
 
 // Block compression errors are ignored since the buffer is sized appropriately.
 func (b *FrameDataBlock) Compress(f *Frame, src []byte, ht []int, level lz4block.CompressionLevel) *FrameDataBlock {
 	data := b.Data[:len(src)] // trigger the incompressible flag in CompressBlock
+	b.ref = data              // keep track of the allocated buffer so that it can be freed
 	var n int
 	switch level {
 	case lz4block.Fast:
@@ -256,24 +262,24 @@ func (b *FrameDataBlock) Compress(f *Frame, src []byte, ht []int, level lz4block
 	}
 	if n == 0 {
 		b.Size.UncompressedSet(true)
-		data = src
+		b.Data = src
 	} else {
 		b.Size.UncompressedSet(false)
-		data = data[:n]
+		b.Data = data[:n]
 	}
-	b.Data = data
-	b.Size.sizeSet(len(data))
+	b.Size.sizeSet(len(b.Data))
+	b.src = src // keep track of the source for content checksum
 
 	if f.Descriptor.Flags.BlockChecksum() {
 		b.Checksum = xxh32.ChecksumZero(src)
 	}
-	if f.Descriptor.Flags.ContentChecksum() {
-		_, _ = f.checksum.Write(src)
-	}
 	return b
 }
 
 func (b *FrameDataBlock) Write(f *Frame, dst io.Writer) error {
+	if f.Descriptor.Flags.ContentChecksum() {
+		_, _ = f.checksum.Write(b.src)
+	}
 	buf := f.buf[:]
 	binary.LittleEndian.PutUint32(buf, uint32(b.Size))
 	if _, err := dst.Write(buf[:4]); err != nil {

+ 20 - 21
writer.go

@@ -100,12 +100,16 @@ func (w *Writer) Write(buf []byte) (n int, err error) {
 		if err = w.write(w.data, true); err != nil {
 			return
 		}
+		if !w.isNotConcurrent() {
+			size := w.frame.Descriptor.Flags.BlockSizeIndex()
+			w.data = size.Get()
+		}
 		w.idx = 0
 	}
 	return
 }
 
-func (w *Writer) write(data []byte, direct bool) error {
+func (w *Writer) write(data []byte, safe bool) error {
 	if w.isNotConcurrent() {
 		defer w.handler(len(data))
 		block := w.frame.Blocks.Block
@@ -114,24 +118,17 @@ func (w *Writer) write(data []byte, direct bool) error {
 	size := w.frame.Descriptor.Flags.BlockSizeIndex()
 	c := make(chan *lz4stream.FrameDataBlock)
 	w.frame.Blocks.Blocks <- c
-	go func(c chan *lz4stream.FrameDataBlock, data []byte, size lz4block.BlockSizeIndex) {
-		defer w.handler(len(data))
+	go func(c chan *lz4stream.FrameDataBlock, data []byte, size lz4block.BlockSizeIndex, safe bool) {
 		b := lz4stream.NewFrameDataBlock(size)
-		zdata := b.Data
 		c <- b.Compress(w.frame, data, nil, w.level)
-		// Wait for the compressed or uncompressed data to no longer be in use
-		// and free the allocated buffers
-		if b.Size.Uncompressed() {
-			zdata, data = data, zdata
-		}
-		size.Put(data)
 		<-c
-		size.Put(zdata)
-	}(c, data, size)
-
-	if direct {
-		w.data = size.Get()
-	}
+		w.handler(len(data))
+		b.CloseW(w.frame)
+		if safe {
+			// Safe to put data back: its last use was in FrameDataBlock.Write(), which is called before c is closed.
+			size.Put(data)
+		}
+	}(c, data, size, safe)
 
 	return nil
 }
@@ -148,19 +145,21 @@ func (w *Writer) Close() (err error) {
 	}
 	defer func() { w.state.next(err) }()
 	if w.idx > 0 {
-		// Flush pending data.
+		// Flush pending data; pass false so that w.data is not freed here, as it is done later on.
 		if err = w.write(w.data[:w.idx], false); err != nil {
 			return err
 		}
 		w.idx = 0
 	}
+	err = w.frame.CloseW(w.src, w.num)
 	if w.isNotConcurrent() {
 		lz4block.HashTablePool.Put(w.ht)
-		size := w.frame.Descriptor.Flags.BlockSizeIndex()
-		size.Put(w.data)
-		w.data = nil
 	}
-	return w.frame.CloseW(w.src, w.num)
+	// It is now safe to free the buffer.
+	size := w.frame.Descriptor.Flags.BlockSizeIndex()
+	size.Put(w.data)
+	w.data = nil
+	return
 }
 
 // Reset clears the state of the Writer w such that it is equivalent to its

+ 1 - 1
writer_test.go

@@ -31,7 +31,7 @@ func TestWriter(t *testing.T) {
 			lz4.ConcurrencyOption(1),
 			lz4.BlockChecksumOption(true),
 			lz4.SizeOption(123),
-			//lz4.ConcurrencyOption(2),
+			lz4.ConcurrencyOption(4),
 		} {
 			label := fmt.Sprintf("%s/%s", fname, option)
 			t.Run(label, func(t *testing.T) {