Merge pull request #1185 from retailnext/reuse-compression-objects

Reuse compression objects
Vlad Gorodetsky committed 7 years ago
commit 94536b3e82
4 changed files with 151 additions and 165 deletions
  1. compress.go (+75 −0)
  2. decompress.go (+63 −0)
  3. message.go (+8 −90)
  4. record_batch.go (+5 −75)

compress.go (+75 −0)

@@ -0,0 +1,75 @@
+package sarama
+
+import (
+	"bytes"
+	"compress/gzip"
+	"fmt"
+	"sync"
+
+	"github.com/eapache/go-xerial-snappy"
+	"github.com/pierrec/lz4"
+)
+
+var (
+	lz4WriterPool = sync.Pool{
+		New: func() interface{} {
+			return lz4.NewWriter(nil)
+		},
+	}
+
+	gzipWriterPool = sync.Pool{
+		New: func() interface{} {
+			return gzip.NewWriter(nil)
+		},
+	}
+)
+
+func compress(cc CompressionCodec, level int, data []byte) ([]byte, error) {
+	switch cc {
+	case CompressionNone:
+		return data, nil
+	case CompressionGZIP:
+		var (
+			err    error
+			buf    bytes.Buffer
+			writer *gzip.Writer
+		)
+		if level != CompressionLevelDefault {
+			writer, err = gzip.NewWriterLevel(&buf, level)
+			if err != nil {
+				return nil, err
+			}
+		} else {
+			writer = gzipWriterPool.Get().(*gzip.Writer)
+			defer gzipWriterPool.Put(writer)
+			writer.Reset(&buf)
+		}
+		if _, err := writer.Write(data); err != nil {
+			return nil, err
+		}
+		if err := writer.Close(); err != nil {
+			return nil, err
+		}
+		return buf.Bytes(), nil
+	case CompressionSnappy:
+		return snappy.Encode(data), nil
+	case CompressionLZ4:
+		writer := lz4WriterPool.Get().(*lz4.Writer)
+		defer lz4WriterPool.Put(writer)
+
+		var buf bytes.Buffer
+		writer.Reset(&buf)
+
+		if _, err := writer.Write(data); err != nil {
+			return nil, err
+		}
+		if err := writer.Close(); err != nil {
+			return nil, err
+		}
+		return buf.Bytes(), nil
+	case CompressionZSTD:
+		return zstdCompressLevel(nil, data, level)
+	default:
+		return nil, PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", cc)}
+	}
+}
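
For reference, a minimal, self-contained sketch of the writer-pooling pattern compress.go now relies on: take a *gzip.Writer from a sync.Pool, Reset it onto a fresh buffer, and return it to the pool when done. The names below (writerPool, gzipCompress) are illustrative only, not part of this PR.

package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"sync"
)

// Illustrative pool, mirroring the gzipWriterPool added in compress.go.
var writerPool = sync.Pool{
	New: func() interface{} { return gzip.NewWriter(nil) },
}

// gzipCompress reuses a pooled writer instead of allocating one per call.
func gzipCompress(data []byte) ([]byte, error) {
	w := writerPool.Get().(*gzip.Writer)
	defer writerPool.Put(w)

	var buf bytes.Buffer
	w.Reset(&buf) // re-point the recycled writer at a fresh buffer
	if _, err := w.Write(data); err != nil {
		return nil, err
	}
	// Close flushes buffered data and writes the gzip footer.
	if err := w.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

func main() {
	out, err := gzipCompress([]byte("hello, sarama"))
	fmt.Println(len(out), err)
}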

decompress.go (+63 −0)

@@ -0,0 +1,63 @@
+package sarama
+
+import (
+	"bytes"
+	"compress/gzip"
+	"fmt"
+	"io/ioutil"
+	"sync"
+
+	"github.com/eapache/go-xerial-snappy"
+	"github.com/pierrec/lz4"
+)
+
+var (
+	lz4ReaderPool = sync.Pool{
+		New: func() interface{} {
+			return lz4.NewReader(nil)
+		},
+	}
+
+	gzipReaderPool sync.Pool
+)
+
+func decompress(cc CompressionCodec, data []byte) ([]byte, error) {
+	switch cc {
+	case CompressionNone:
+		return data, nil
+	case CompressionGZIP:
+		var (
+			err        error
+			reader     *gzip.Reader
+			readerIntf = gzipReaderPool.Get()
+		)
+		if readerIntf != nil {
+			reader = readerIntf.(*gzip.Reader)
+		} else {
+			reader, err = gzip.NewReader(bytes.NewReader(data))
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		defer gzipReaderPool.Put(reader)
+
+		if err := reader.Reset(bytes.NewReader(data)); err != nil {
+			return nil, err
+		}
+
+		return ioutil.ReadAll(reader)
+	case CompressionSnappy:
+		return snappy.Decode(data)
+	case CompressionLZ4:
+		reader := lz4ReaderPool.Get().(*lz4.Reader)
+		defer lz4ReaderPool.Put(reader)
+
+		reader.Reset(bytes.NewReader(data))
+		return ioutil.ReadAll(reader)
+	case CompressionZSTD:
+		return zstdDecompress(nil, data)
+	default:
+		return nil, PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", cc)}
+	}
+}
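
Note one subtlety above: gzipReaderPool is declared without a New function because gzip.NewReader must be handed a valid gzip stream up front, so the first caller constructs the reader and later callers Reset a recycled one. A standalone sketch of that lazy-construction pattern (illustrative names, not PR code):

package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io/ioutil"
	"sync"
)

// No New func: a *gzip.Reader can only be constructed from a valid stream.
var readerPool sync.Pool

func gzipDecompress(data []byte) ([]byte, error) {
	var (
		r   *gzip.Reader
		err error
	)
	if v := readerPool.Get(); v != nil {
		// Recycled reader: re-point it at the new input.
		r = v.(*gzip.Reader)
		if err = r.Reset(bytes.NewReader(data)); err != nil {
			return nil, err
		}
	} else {
		// Pool was empty: build the first reader from the data itself.
		if r, err = gzip.NewReader(bytes.NewReader(data)); err != nil {
			return nil, err
		}
	}
	defer readerPool.Put(r)
	return ioutil.ReadAll(r)
}

func main() {
	var buf bytes.Buffer
	w := gzip.NewWriter(&buf)
	w.Write([]byte("hello, sarama"))
	w.Close()
	out, err := gzipDecompress(buf.Bytes())
	fmt.Println(string(out), err)
}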

message.go (+8 −90)

@@ -1,14 +1,8 @@
 package sarama
 
 import (
-	"bytes"
-	"compress/gzip"
 	"fmt"
-	"io/ioutil"
 	"time"
-
-	"github.com/eapache/go-xerial-snappy"
-	"github.com/pierrec/lz4"
 )
 
 // CompressionCodec represents the various compression codecs recognized by Kafka in messages.
@@ -77,53 +71,12 @@ func (m *Message) encode(pe packetEncoder) error {
 		payload = m.compressedCache
 		m.compressedCache = nil
 	} else if m.Value != nil {
-		switch m.Codec {
-		case CompressionNone:
-			payload = m.Value
-		case CompressionGZIP:
-			var buf bytes.Buffer
-			var writer *gzip.Writer
-			if m.CompressionLevel != CompressionLevelDefault {
-				writer, err = gzip.NewWriterLevel(&buf, m.CompressionLevel)
-				if err != nil {
-					return err
-				}
-			} else {
-				writer = gzip.NewWriter(&buf)
-			}
-			if _, err = writer.Write(m.Value); err != nil {
-				return err
-			}
-			if err = writer.Close(); err != nil {
-				return err
-			}
-			m.compressedCache = buf.Bytes()
-			payload = m.compressedCache
-		case CompressionSnappy:
-			tmp := snappy.Encode(m.Value)
-			m.compressedCache = tmp
-			payload = m.compressedCache
-		case CompressionLZ4:
-			var buf bytes.Buffer
-			writer := lz4.NewWriter(&buf)
-			if _, err = writer.Write(m.Value); err != nil {
-				return err
-			}
-			if err = writer.Close(); err != nil {
-				return err
-			}
-			m.compressedCache = buf.Bytes()
-			payload = m.compressedCache
-		case CompressionZSTD:
-			c, err := zstdCompressLevel(nil, m.Value, m.CompressionLevel)
-			if err != nil {
-				return err
-			}
-			m.compressedCache = c
-			payload = m.compressedCache
-		default:
-			return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", m.Codec)}
+
+		payload, err = compress(m.Codec, m.CompressionLevel, m.Value)
+		if err != nil {
+			return err
 		}
+		m.compressedCache = payload
 		// Keep in mind the compressed payload size for metric gathering
 		m.compressedSize = len(payload)
 	}
@@ -179,53 +132,18 @@ func (m *Message) decode(pd packetDecoder) (err error) {
 	switch m.Codec {
 	case CompressionNone:
 		// nothing to do
-	case CompressionGZIP:
+	default:
 		if m.Value == nil {
 			break
 		}
-		reader, err := gzip.NewReader(bytes.NewReader(m.Value))
+
+		m.Value, err = decompress(m.Codec, m.Value)
 		if err != nil {
 			return err
 		}
-		if m.Value, err = ioutil.ReadAll(reader); err != nil {
-			return err
-		}
 		if err := m.decodeSet(); err != nil {
 			return err
 		}
-	case CompressionSnappy:
-		if m.Value == nil {
-			break
-		}
-		if m.Value, err = snappy.Decode(m.Value); err != nil {
-			return err
-		}
-		if err := m.decodeSet(); err != nil {
-			return err
-		}
-	case CompressionLZ4:
-		if m.Value == nil {
-			break
-		}
-		reader := lz4.NewReader(bytes.NewReader(m.Value))
-		if m.Value, err = ioutil.ReadAll(reader); err != nil {
-			return err
-		}
-		if err := m.decodeSet(); err != nil {
-			return err
-		}
-	case CompressionZSTD:
-		if m.Value == nil {
-			break
-		}
-		if m.Value, err = zstdDecompress(nil, m.Value); err != nil {
-			return err
-		}
-		if err := m.decodeSet(); err != nil {
-			return err
-		}
-	default:
-		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", m.Codec)}
 	}
 
 	return pd.pop()
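
With both helpers in place, the per-codec switch in Message.decode collapses to a single default branch; an unrecognized codec now surfaces as the PacketDecodingError returned by decompress rather than by message.go itself. A round-trip check along these lines would exercise the new pair directly (a sketch only; roundTrip is hypothetical and not included in this diff):

// roundTrip would live alongside the new helpers in package sarama, since
// compress and decompress are unexported. The codec constants and
// CompressionLevelDefault already exist in the library.
func roundTrip(cc CompressionCodec, data []byte) error {
	c, err := compress(cc, CompressionLevelDefault, data)
	if err != nil {
		return err
	}
	d, err := decompress(cc, c)
	if err != nil {
		return err
	}
	if !bytes.Equal(d, data) {
		return fmt.Errorf("round-trip mismatch for codec %d", cc)
	}
	return nil
}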

record_batch.go (+5 −75)

@@ -1,14 +1,8 @@
 package sarama
 
 import (
-	"bytes"
-	"compress/gzip"
 	"fmt"
-	"io/ioutil"
 	"time"
-
-	"github.com/eapache/go-xerial-snappy"
-	"github.com/pierrec/lz4"
 )
 
 const recordBatchOverhead = 49
@@ -174,31 +168,9 @@ func (b *RecordBatch) decode(pd packetDecoder) (err error) {
 		return err
 	}
 
-	switch b.Codec {
-	case CompressionNone:
-	case CompressionGZIP:
-		reader, err := gzip.NewReader(bytes.NewReader(recBuffer))
-		if err != nil {
-			return err
-		}
-		if recBuffer, err = ioutil.ReadAll(reader); err != nil {
-			return err
-		}
-	case CompressionSnappy:
-		if recBuffer, err = snappy.Decode(recBuffer); err != nil {
-			return err
-		}
-	case CompressionLZ4:
-		reader := lz4.NewReader(bytes.NewReader(recBuffer))
-		if recBuffer, err = ioutil.ReadAll(reader); err != nil {
-			return err
-		}
-	case CompressionZSTD:
-		if recBuffer, err = zstdDecompress(nil, recBuffer); err != nil {
-			return err
-		}
-	default:
-		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", b.Codec)}
+	recBuffer, err = decompress(b.Codec, recBuffer)
+	if err != nil {
+		return err
 	}
 
 	b.recordsLen = len(recBuffer)
@@ -219,50 +191,8 @@ func (b *RecordBatch) encodeRecords(pe packetEncoder) error {
 	}
 	b.recordsLen = len(raw)
 
-	switch b.Codec {
-	case CompressionNone:
-		b.compressedRecords = raw
-	case CompressionGZIP:
-		var buf bytes.Buffer
-		var writer *gzip.Writer
-		if b.CompressionLevel != CompressionLevelDefault {
-			writer, err = gzip.NewWriterLevel(&buf, b.CompressionLevel)
-			if err != nil {
-				return err
-			}
-		} else {
-			writer = gzip.NewWriter(&buf)
-		}
-		if _, err := writer.Write(raw); err != nil {
-			return err
-		}
-		if err := writer.Close(); err != nil {
-			return err
-		}
-		b.compressedRecords = buf.Bytes()
-	case CompressionSnappy:
-		b.compressedRecords = snappy.Encode(raw)
-	case CompressionLZ4:
-		var buf bytes.Buffer
-		writer := lz4.NewWriter(&buf)
-		if _, err := writer.Write(raw); err != nil {
-			return err
-		}
-		if err := writer.Close(); err != nil {
-			return err
-		}
-		b.compressedRecords = buf.Bytes()
-	case CompressionZSTD:
-		c, err := zstdCompressLevel(nil, raw, b.CompressionLevel)
-		if err != nil {
-			return err
-		}
-		b.compressedRecords = c
-	default:
-		return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
-	}
-
-	return nil
+	b.compressedRecords, err = compress(b.Codec, b.CompressionLevel, raw)
+	return err
 }
 
 func (b *RecordBatch) computeAttributes() int16 {
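
Since the point of the change is to avoid allocating a fresh compressor or decompressor per message, an allocation-reporting benchmark is the natural way to observe the win. A hypothetical sketch, not included in this diff:

// In package sarama's tests; compress is the unexported helper added above.
func BenchmarkCompressGZIP(b *testing.B) {
	data := bytes.Repeat([]byte("kafka"), 1024)
	b.ReportAllocs() // pooled writers should show far fewer allocs/op
	for i := 0; i < b.N; i++ {
		if _, err := compress(CompressionGZIP, CompressionLevelDefault, data); err != nil {
			b.Fatal(err)
		}
	}
}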