Browse Source

Reuse lz4 and gzip readers

Use sync.Pool to reuse lz4 and gzip reader objects across
decompressions. lz4 in particular makes a large allocation per-reader,
so you spend all your time in GC if you make a new reader per-message.

Benchmarking reading 500 messages/s with 3 consumers and 32
partitions, lz4 consumer CPU fell from ~120% to ~5%. gzip went from
~20% to ~5%.
Muir Manders 7 years ago
parent
commit
afe6b1d906
3 changed files with 69 additions and 65 deletions
  1. 63 0
      decompress.go
  2. 3 39
      message.go
  3. 3 26
      record_batch.go

+ 63 - 0
decompress.go

@@ -0,0 +1,63 @@
+package sarama
+
+import (
+	"bytes"
+	"compress/gzip"
+	"fmt"
+	"io/ioutil"
+	"sync"
+
+	"github.com/eapache/go-xerial-snappy"
+	"github.com/pierrec/lz4"
+)
+
+// lz4Pool and gzipPool hold decompression readers for reuse by decompress.
+var (
+	// lz4Pool builds a reader with a nil source on demand when it is empty.
+	lz4Pool = sync.Pool{
+		New: func() interface{} { return lz4.NewReader(nil) },
+	}
+	// gzipPool has no New func, so Get returns nil when it is empty.
+	gzipPool sync.Pool
+)
+
+func decompress(cc CompressionCodec, data []byte) ([]byte, error) {
+	switch cc {
+	case CompressionNone:
+		return data, nil
+	case CompressionGZIP:
+		var (
+			err        error
+			reader     *gzip.Reader
+			readerIntf = gzipPool.Get()
+		)
+		if readerIntf != nil {
+			reader = readerIntf.(*gzip.Reader)
+		} else {
+			reader, err = gzip.NewReader(bytes.NewReader(data))
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		defer gzipPool.Put(reader)
+
+		if err := reader.Reset(bytes.NewReader(data)); err != nil {
+			return nil, err
+		}
+
+		return ioutil.ReadAll(reader)
+	case CompressionSnappy:
+		return snappy.Decode(data)
+	case CompressionLZ4:
+		reader := lz4Pool.Get().(*lz4.Reader)
+		defer lz4Pool.Put(reader)
+
+		reader.Reset(bytes.NewReader(data))
+		return ioutil.ReadAll(reader)
+	case CompressionZSTD:
+		return zstdDecompress(nil, data)
+	default:
+		return nil, PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", cc)}
+	}
+}

+ 3 - 39
message.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"compress/gzip"
 	"fmt"
-	"io/ioutil"
 	"time"
 
 	"github.com/eapache/go-xerial-snappy"
@@ -179,53 +178,18 @@ func (m *Message) decode(pd packetDecoder) (err error) {
 	switch m.Codec {
 	case CompressionNone:
 		// nothing to do
-	case CompressionGZIP:
+	default:
 		if m.Value == nil {
 			break
 		}
-		reader, err := gzip.NewReader(bytes.NewReader(m.Value))
+
+		m.Value, err = decompress(m.Codec, m.Value)
 		if err != nil {
 			return err
 		}
-		if m.Value, err = ioutil.ReadAll(reader); err != nil {
-			return err
-		}
 		if err := m.decodeSet(); err != nil {
 			return err
 		}
-	case CompressionSnappy:
-		if m.Value == nil {
-			break
-		}
-		if m.Value, err = snappy.Decode(m.Value); err != nil {
-			return err
-		}
-		if err := m.decodeSet(); err != nil {
-			return err
-		}
-	case CompressionLZ4:
-		if m.Value == nil {
-			break
-		}
-		reader := lz4.NewReader(bytes.NewReader(m.Value))
-		if m.Value, err = ioutil.ReadAll(reader); err != nil {
-			return err
-		}
-		if err := m.decodeSet(); err != nil {
-			return err
-		}
-	case CompressionZSTD:
-		if m.Value == nil {
-			break
-		}
-		if m.Value, err = zstdDecompress(nil, m.Value); err != nil {
-			return err
-		}
-		if err := m.decodeSet(); err != nil {
-			return err
-		}
-	default:
-		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", m.Codec)}
 	}
 
 	return pd.pop()

+ 3 - 26
record_batch.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"compress/gzip"
 	"fmt"
-	"io/ioutil"
 	"time"
 
 	"github.com/eapache/go-xerial-snappy"
@@ -174,31 +173,9 @@ func (b *RecordBatch) decode(pd packetDecoder) (err error) {
 		return err
 	}
 
-	switch b.Codec {
-	case CompressionNone:
-	case CompressionGZIP:
-		reader, err := gzip.NewReader(bytes.NewReader(recBuffer))
-		if err != nil {
-			return err
-		}
-		if recBuffer, err = ioutil.ReadAll(reader); err != nil {
-			return err
-		}
-	case CompressionSnappy:
-		if recBuffer, err = snappy.Decode(recBuffer); err != nil {
-			return err
-		}
-	case CompressionLZ4:
-		reader := lz4.NewReader(bytes.NewReader(recBuffer))
-		if recBuffer, err = ioutil.ReadAll(reader); err != nil {
-			return err
-		}
-	case CompressionZSTD:
-		if recBuffer, err = zstdDecompress(nil, recBuffer); err != nil {
-			return err
-		}
-	default:
-		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", b.Codec)}
+	recBuffer, err = decompress(b.Codec, recBuffer)
+	if err != nil {
+		return err
 	}
 
 	b.recordsLen = len(recBuffer)