Browse Source

Handle invalid buffers being put back into the Block pool

Pierre.Curto 5 years ago
parent
commit
4b88d85f64
4 changed files with 26 additions and 10 deletions
  1. 14 5
      internal/lz4block/blocks.go
  2. 8 5
      internal/lz4stream/frame.go
  3. 2 0
      reader_test.go
  4. 2 0
      writer_test.go

+ 14 - 5
internal/lz4block/blocks.go

@@ -61,15 +61,24 @@ func (b BlockSizeIndex) Get() []byte {
 }
 }
 
 
 func (b BlockSizeIndex) Put(buf []byte) {
 func (b BlockSizeIndex) Put(buf []byte) {
-	switch b {
+	// Safeguard: do not allow invalid buffers.
+	switch c := uint32(cap(buf)); b {
 	case 4:
 	case 4:
-		BlockPool64K.Put(buf)
+		if c == Block64Kb {
+			BlockPool64K.Put(buf[:c])
+		}
 	case 5:
 	case 5:
-		BlockPool256K.Put(buf)
+		if c == Block256Kb {
+			BlockPool256K.Put(buf[:c])
+		}
 	case 6:
 	case 6:
-		BlockPool1M.Put(buf)
+		if c == Block1Mb {
+			BlockPool1M.Put(buf[:c])
+		}
 	case 7:
 	case 7:
-		BlockPool4M.Put(buf)
+		if c == Block4Mb {
+			BlockPool4M.Put(buf[:c])
+		}
 	}
 	}
 }
 }
 
 

+ 8 - 5
internal/lz4stream/frame.go

@@ -260,11 +260,14 @@ type FrameDataBlock struct {
 }
 }
 
 
 func (b *FrameDataBlock) CloseW(f *Frame) {
 func (b *FrameDataBlock) CloseW(f *Frame) {
-	size := f.Descriptor.Flags.BlockSizeIndex()
-	size.Put(b.data)
-	b.Data = nil
-	b.data = nil
-	b.src = nil
+	if b.data != nil {
+		// Block was not already closed.
+		size := f.Descriptor.Flags.BlockSizeIndex()
+		size.Put(b.data)
+		b.Data = nil
+		b.data = nil
+		b.src = nil
+	}
 }
 }
 
 
 // Block compression errors are ignored since the buffer is sized appropriately.
 // Block compression errors are ignored since the buffer is sized appropriately.

+ 2 - 0
reader_test.go

@@ -91,6 +91,8 @@ func TestReader_Reset(t *testing.T) {
 
 
 	buf.Reset()
 	buf.Reset()
 	src.Reset(data)
 	src.Reset(data)
+	// Another time to maybe trigger some edge case.
+	src.Reset(data)
 	zr.Reset(src)
 	zr.Reset(src)
 	if _, err := io.Copy(buf, zr); err != nil {
 	if _, err := io.Copy(buf, zr); err != nil {
 		t.Fatal(err)
 		t.Fatal(err)

+ 2 - 0
writer_test.go

@@ -98,6 +98,8 @@ func TestWriter_Reset(t *testing.T) {
 	buf.Reset()
 	buf.Reset()
 	src.Reset(data)
 	src.Reset(data)
 	zw.Reset(buf)
 	zw.Reset(buf)
+	zw.Reset(buf)
+	// Another time to maybe trigger some edge case.
 	if _, err := io.Copy(zw, src); err != nil {
 	if _, err := io.Copy(zw, src); err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}