Browse Source

Merge pull request #1170 from bobrik/zstd

Zstd support for compression
Vlad Gorodetsky 6 years ago
parent
commit
6bfdd61781
3 changed files with 81 additions and 4 deletions
  1. 26 4
      message.go
  2. 44 0
      message_test.go
  3. 11 0
      record_batch.go

+ 26 - 4
message.go

@@ -7,6 +7,7 @@ import (
 	"io/ioutil"
 	"time"
 
+	"github.com/DataDog/zstd"
 	"github.com/eapache/go-xerial-snappy"
 	"github.com/pierrec/lz4"
 )
@@ -14,14 +15,15 @@ import (
 // CompressionCodec represents the various compression codecs recognized by Kafka in messages.
 type CompressionCodec int8
 
-// only the last two bits are really used
-const compressionCodecMask int8 = 0x03
+// The lowest 3 bits contain the compression codec used for the message
+const compressionCodecMask int8 = 0x07
 
 const (
 	CompressionNone   CompressionCodec = 0
 	CompressionGZIP   CompressionCodec = 1
 	CompressionSnappy CompressionCodec = 2
 	CompressionLZ4    CompressionCodec = 3
+	CompressionZSTD   CompressionCodec = 4
 )
 
 func (cc CompressionCodec) String() string {
@@ -113,7 +115,18 @@ func (m *Message) encode(pe packetEncoder) error {
 			}
 			m.compressedCache = buf.Bytes()
 			payload = m.compressedCache
-
+		case CompressionZSTD:
+			if len(m.Value) == 0 {
+				// Hardcoded empty ZSTD frame, see: https://github.com/DataDog/zstd/issues/41
+				m.compressedCache = []byte{0x28, 0xb5, 0x2f, 0xfd, 0x24, 0x00, 0x01, 0x00, 0x00, 0x99, 0xe9, 0xd8, 0x51}
+			} else {
+				c, err := zstd.CompressLevel(nil, m.Value, m.CompressionLevel)
+				if err != nil {
+					return err
+				}
+				m.compressedCache = c
+			}
+			payload = m.compressedCache
 		default:
 			return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", m.Codec)}
 		}
@@ -207,7 +220,16 @@ func (m *Message) decode(pd packetDecoder) (err error) {
 		if err := m.decodeSet(); err != nil {
 			return err
 		}
-
+	case CompressionZSTD:
+		if m.Value == nil {
+			break
+		}
+		if m.Value, err = zstd.Decompress(nil, m.Value); err != nil {
+			return err
+		}
+		if err := m.decodeSet(); err != nil {
+			return err
+		}
 	default:
 		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", m.Codec)}
 	}

+ 44 - 0
message_test.go

@@ -52,6 +52,17 @@ var (
 		5, 93, 204, 2, // LZ4 checksum
 	}
 
+	emptyZSTDMessage = []byte{
+		252, 62, 137, 23, // CRC
+		0x01,                          // version byte
+		0x04,                          // attribute flags: lz4
+		0, 0, 1, 88, 141, 205, 89, 56, // timestamp
+		0xFF, 0xFF, 0xFF, 0xFF, // key
+		0x00, 0x00, 0x00, 0x0d, // len
+		// ZSTD data
+		0x28, 0xb5, 0x2f, 0xfd, 0x24, 0x00, 0x01, 0x00, 0x00, 0x99, 0xe9, 0xd8, 0x51,
+	}
+
 	emptyBulkSnappyMessage = []byte{
 		180, 47, 53, 209, //CRC
 		0x00,                   // magic version byte
@@ -86,6 +97,17 @@ var (
 		112, 185, 52, 0, 0, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 121, 87, 72, 224, 0, 0, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 14, 121, 87, 72, 224, 0, 0, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0,
 		71, 129, 23, 111, // LZ4 checksum
 	}
+
+	emptyBulkZSTDMessage = []byte{
+		203, 151, 133, 28, // CRC
+		0x01,                                  // Version
+		0x04,                                  // attribute flags (ZSTD)
+		255, 255, 249, 209, 212, 181, 73, 201, // timestamp
+		0xFF, 0xFF, 0xFF, 0xFF, // key
+		0x00, 0x00, 0x00, 0x26, // len
+		// ZSTD data
+		0x28, 0xb5, 0x2f, 0xfd, 0x24, 0x34, 0xcd, 0x0, 0x0, 0x78, 0x0, 0x0, 0xe, 0x79, 0x57, 0x48, 0xe0, 0x0, 0x0, 0xff, 0xff, 0xff, 0xff, 0x0, 0x1, 0x3, 0x0, 0x3d, 0xbd, 0x0, 0x3b, 0x15, 0x0, 0xb, 0xd2, 0x34, 0xc1, 0x78,
+	}
 )
 
 func TestMessageEncoding(t *testing.T) {
@@ -101,6 +123,12 @@ func TestMessageEncoding(t *testing.T) {
 	message.Timestamp = time.Unix(1479847795, 0)
 	message.Version = 1
 	testEncodable(t, "empty lz4", &message, emptyLZ4Message)
+
+	message.Value = []byte{}
+	message.Codec = CompressionZSTD
+	message.Timestamp = time.Unix(1479847795, 0)
+	message.Version = 1
+	testEncodable(t, "empty zstd", &message, emptyZSTDMessage)
 }
 
 func TestMessageDecoding(t *testing.T) {
@@ -179,6 +207,22 @@ func TestMessageDecodingBulkLZ4(t *testing.T) {
 	}
 }
 
+func TestMessageDecodingBulkZSTD(t *testing.T) {
+	message := Message{}
+	testDecodable(t, "bulk zstd", &message, emptyBulkZSTDMessage)
+	if message.Codec != CompressionZSTD {
+		t.Errorf("Decoding produced codec %d, but expected %d.", message.Codec, CompressionZSTD)
+	}
+	if message.Key != nil {
+		t.Errorf("Decoding produced key %+v, but none was expected.", message.Key)
+	}
+	if message.Set == nil {
+		t.Error("Decoding produced no set, but one was expected.")
+	} else if len(message.Set.Messages) != 2 {
+		t.Errorf("Decoding produced a set with %d messages, but 2 were expected.", len(message.Set.Messages))
+	}
+}
+
 func TestMessageDecodingVersion1(t *testing.T) {
 	message := Message{Version: 1}
 	testDecodable(t, "decoding empty v1 message", &message, emptyV1Message)

+ 11 - 0
record_batch.go

@@ -7,6 +7,7 @@ import (
 	"io/ioutil"
 	"time"
 
+	"github.com/DataDog/zstd"
 	"github.com/eapache/go-xerial-snappy"
 	"github.com/pierrec/lz4"
 )
@@ -193,6 +194,10 @@ func (b *RecordBatch) decode(pd packetDecoder) (err error) {
 		if recBuffer, err = ioutil.ReadAll(reader); err != nil {
 			return err
 		}
+	case CompressionZSTD:
+		if recBuffer, err = zstd.Decompress(nil, recBuffer); err != nil {
+			return err
+		}
 	default:
 		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", b.Codec)}
 	}
@@ -248,6 +253,12 @@ func (b *RecordBatch) encodeRecords(pe packetEncoder) error {
 			return err
 		}
 		b.compressedRecords = buf.Bytes()
+	case CompressionZSTD:
+		c, err := zstd.CompressLevel(nil, raw, b.CompressionLevel)
+		if err != nil {
+			return err
+		}
+		b.compressedRecords = c
 	default:
 		return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
 	}