
Merge pull request #973 from wladh/records

Add implementation of Kafka 0.11 Records
Evan Huus · 8 years ago · commit eca6c1cfdb
10 changed files with 913 additions and 13 deletions
  1. length_field.go (+19 -8)
  2. packet_decoder.go (+9 -0)
  3. packet_encoder.go (+11 -0)
  4. prep_encoder.go (+9 -0)
  5. real_decoder.go (+12 -5)
  6. record.go (+105 -0)
  7. record_batch.go (+251 -0)
  8. record_test.go (+264 -0)
  9. records.go (+96 -0)
  10. records_test.go (+137 -0)
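
For orientation, this is the v2 ("default") record batch layout the new files implement. Field order matches RecordBatch.encode in record_batch.go, and the fixed fields after Length sum to the 49-byte recordBatchOverhead constant:

	FirstOffset          int64
	Length               int32   covers everything after this field
	PartitionLeaderEpoch int32
	Version (magic)      int8    always 2 for this format
	CRC                  int32   CRC-32C (Castagnoli) over the remaining fields
	Attributes           int16   bits 0-2: codec, bit 5: control
	LastOffsetDelta      int32
	FirstTimestamp       int64
	MaxTimestamp         int64
	ProducerID           int64
	ProducerEpoch        int16
	FirstSequence        int32
	NumRecords           int32
	Records              varint-framed records, possibly compressed as one blob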

length_field.go (+19 -8)

@@ -33,24 +33,35 @@ type varintLengthField struct {
 	length      int64
 }
 
-func newVarintLengthField(pd packetDecoder) (*varintLengthField, error) {
-	n, err := pd.getVarint()
-	if err != nil {
-		return nil, err
-	}
-	return &varintLengthField{length: n}, nil
+func (l *varintLengthField) decode(pd packetDecoder) error {
+	var err error
+	l.length, err = pd.getVarint()
+	return err
 }
 
 func (l *varintLengthField) saveOffset(in int) {
 	l.startOffset = in
 }
 
+func (l *varintLengthField) adjustLength(currOffset int) int {
+	oldFieldSize := l.reserveLength()
+	l.length = int64(currOffset - l.startOffset - oldFieldSize)
+
+	return l.reserveLength() - oldFieldSize
+}
+
 func (l *varintLengthField) reserveLength() int {
-	return 0
+	var tmp [binary.MaxVarintLen64]byte
+	return binary.PutVarint(tmp[:], l.length)
+}
+
+func (l *varintLengthField) run(curOffset int, buf []byte) error {
+	binary.PutVarint(buf[l.startOffset:], l.length)
+	return nil
 }
 
 func (l *varintLengthField) check(curOffset int, buf []byte) error {
-	if int64(curOffset-l.startOffset) != l.length {
+	if int64(curOffset-l.startOffset-l.reserveLength()) != l.length {
 		return PacketDecodingError{"length field invalid"}
 	}
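
Why adjustLength returns a byte delta: the width of a varint length field depends on its value, so the space reserved during the prep pass can turn out too small (or too large) once the enclosed payload has been sized. A standalone sketch of the arithmetic (plain Go, mirroring reserveLength and adjustLength above):

package main

import (
	"encoding/binary"
	"fmt"
)

// varintWidth mirrors varintLengthField.reserveLength: the number of bytes
// binary.PutVarint needs for n (zig-zag encoded).
func varintWidth(n int64) int {
	var tmp [binary.MaxVarintLen64]byte
	return binary.PutVarint(tmp[:], n)
}

func main() {
	// The prep pass reserved varintWidth(0) == 1 byte for a provisional length.
	// Suppose the enclosed payload then turns out to be 200 bytes:
	startOffset, currOffset := 0, 201 // 1 reserved byte + 200 payload bytes
	oldWidth := varintWidth(0)
	length := int64(currOffset - startOffset - oldWidth) // 200
	delta := varintWidth(length) - oldWidth              // zigzag(200) needs 2 bytes
	fmt.Println(length, delta)                           // prints: 200 1
}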
 

packet_decoder.go (+9 -0)

@@ -46,3 +46,12 @@ type pushDecoder interface {
 	// of data from the saved offset, and verify it based on the data between the saved offset and curOffset.
 	check(curOffset int, buf []byte) error
 }
+
+// dynamicPushDecoder extends the interface of pushDecoder for use cases where the length of the
+// field itself is unknown until its value has been decoded (for instance, varint-encoded length
+// fields).
+// During push, the dynamicPushDecoder.decode() method will be called instead of reserveLength().
+type dynamicPushDecoder interface {
+	pushDecoder
+	decoder
+}
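
Why decode() runs during push instead of a fixed reserveLength(): a varint's width on the wire is only known once its bytes are read. A standalone illustration using only the standard library:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	buf := make([]byte, binary.MaxVarintLen64)

	// A length of 5 takes one byte on the wire; a length of 300 takes two.
	// A fixed-size reserve cannot know which until the value is decoded.
	n := binary.PutVarint(buf, 300)
	fmt.Println("encoded width:", n) // 2

	v, read := binary.Varint(buf)
	fmt.Println("decoded:", v, "bytes consumed:", read) // 300 2
}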

packet_encoder.go (+11 -0)

@@ -50,3 +50,14 @@ type pushEncoder interface {
 	// of data to the saved offset, based on the data between the saved offset and curOffset.
 	run(curOffset int, buf []byte) error
 }
+
+// dynamicPushEncoder extends the interface of pushEncoder for use cases where the length of the
+// field itself is unknown until its value has been computed (for instance, varint-encoded length
+// fields).
+type dynamicPushEncoder interface {
+	pushEncoder
+
+	// Called during pop() to adjust the length of the field.
+	// It should return the difference in bytes between the last computed length and the current length.
+	adjustLength(currOffset int) int
+}

prep_encoder.go (+9 -0)

@@ -9,6 +9,7 @@ import (
 )
 
 type prepEncoder struct {
+	stack  []pushEncoder
 	length int
 }
 
@@ -119,10 +120,18 @@ func (pe *prepEncoder) offset() int {
 // stackable
 
 func (pe *prepEncoder) push(in pushEncoder) {
+	in.saveOffset(pe.length)
 	pe.length += in.reserveLength()
+	pe.stack = append(pe.stack, in)
 }
 
 func (pe *prepEncoder) pop() error {
+	in := pe.stack[len(pe.stack)-1]
+	pe.stack = pe.stack[:len(pe.stack)-1]
+	if dpe, ok := in.(dynamicPushEncoder); ok {
+		pe.length += dpe.adjustLength(pe.length)
+	}
+
 	return nil
 }
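
For context, the stack added to prepEncoder serves sarama's two-pass encoding: the prep pass sizes the message (letting pop() shrink or grow varint length fields via adjustLength), then the real pass writes into an exactly-sized buffer. A simplified sketch of that driver (hypothetical helper name; sarama's actual encode helper also takes a second argument, seen as nil in record_batch.go below):

func encodeTwoPass(e encoder) ([]byte, error) {
	var prep prepEncoder
	if err := e.encode(&prep); err != nil { // pass 1: compute the final length
		return nil, err
	}
	real := realEncoder{raw: make([]byte, prep.length)}
	if err := e.encode(&real); err != nil { // pass 2: fill the sized buffer
		return nil, err
	}
	return real.raw, nil
}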
 

real_decoder.go (+12 -5)

@@ -79,7 +79,7 @@ func (rd *realDecoder) getArrayLength() (int, error) {
 		rd.off = len(rd.raw)
 		return -1, ErrInsufficientData
 	}
-	tmp := int(binary.BigEndian.Uint32(rd.raw[rd.off:]))
+	tmp := int(int32(binary.BigEndian.Uint32(rd.raw[rd.off:])))
 	rd.off += 4
 	if tmp > rd.remaining() {
 		rd.off = len(rd.raw)
@@ -260,10 +260,17 @@ func (rd *realDecoder) getRawBytes(length int) ([]byte, error) {
 func (rd *realDecoder) push(in pushDecoder) error {
 	in.saveOffset(rd.off)
 
-	reserve := in.reserveLength()
-	if rd.remaining() < reserve {
-		rd.off = len(rd.raw)
-		return ErrInsufficientData
+	var reserve int
+	if dpd, ok := in.(dynamicPushDecoder); ok {
+		if err := dpd.decode(rd); err != nil {
+			return err
+		}
+	} else {
+		reserve = in.reserveLength()
+		if rd.remaining() < reserve {
+			rd.off = len(rd.raw)
+			return ErrInsufficientData
+		}
 	}
 
 	rd.stack = append(rd.stack, in)
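
The getArrayLength change above casts through int32 before widening to int, so a wire value of 0xFFFFFFFF (Kafka's -1 null-array marker) decodes as -1 on 64-bit platforms instead of a huge positive count. A standalone demonstration:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	raw := []byte{0xFF, 0xFF, 0xFF, 0xFF} // -1 as a big-endian int32

	fmt.Println(int(binary.BigEndian.Uint32(raw)))        // 4294967295 on 64-bit: wrong
	fmt.Println(int(int32(binary.BigEndian.Uint32(raw)))) // -1: the null marker
}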

record.go (+105 -0)

@@ -0,0 +1,105 @@
+package sarama
+
+const (
+	controlMask = 0x20
+)
+
+type RecordHeader struct {
+	Key   []byte
+	Value []byte
+}
+
+func (h *RecordHeader) encode(pe packetEncoder) error {
+	if err := pe.putVarintBytes(h.Key); err != nil {
+		return err
+	}
+	return pe.putVarintBytes(h.Value)
+}
+
+func (h *RecordHeader) decode(pd packetDecoder) (err error) {
+	if h.Key, err = pd.getVarintBytes(); err != nil {
+		return err
+	}
+
+	if h.Value, err = pd.getVarintBytes(); err != nil {
+		return err
+	}
+	return nil
+}
+
+type Record struct {
+	Attributes     int8
+	TimestampDelta int64
+	OffsetDelta    int64
+	Key            []byte
+	Value          []byte
+	Headers        []*RecordHeader
+
+	length varintLengthField
+}
+
+func (r *Record) encode(pe packetEncoder) error {
+	pe.push(&r.length)
+	pe.putInt8(r.Attributes)
+	pe.putVarint(r.TimestampDelta)
+	pe.putVarint(r.OffsetDelta)
+	if err := pe.putVarintBytes(r.Key); err != nil {
+		return err
+	}
+	if err := pe.putVarintBytes(r.Value); err != nil {
+		return err
+	}
+	pe.putVarint(int64(len(r.Headers)))
+
+	for _, h := range r.Headers {
+		if err := h.encode(pe); err != nil {
+			return err
+		}
+	}
+
+	return pe.pop()
+}
+
+func (r *Record) decode(pd packetDecoder) (err error) {
+	if err = pd.push(&r.length); err != nil {
+		return err
+	}
+
+	if r.Attributes, err = pd.getInt8(); err != nil {
+		return err
+	}
+
+	if r.TimestampDelta, err = pd.getVarint(); err != nil {
+		return err
+	}
+
+	if r.OffsetDelta, err = pd.getVarint(); err != nil {
+		return err
+	}
+
+	if r.Key, err = pd.getVarintBytes(); err != nil {
+		return err
+	}
+
+	if r.Value, err = pd.getVarintBytes(); err != nil {
+		return err
+	}
+
+	numHeaders, err := pd.getVarint()
+	if err != nil {
+		return err
+	}
+
+	if numHeaders >= 0 {
+		r.Headers = make([]*RecordHeader, numHeaders)
+	}
+	for i := int64(0); i < numHeaders; i++ {
+		hdr := new(RecordHeader)
+		if err := hdr.decode(pd); err != nil {
+			return err
+		}
+		r.Headers[i] = hdr
+	}
+
+	return pd.pop()
+}
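
Record fields are zig-zag varints (binary.PutVarint), which explains the wire bytes in record_test.go below: a TimestampDelta of 5 appears as the byte 10, a key length of 4 as 8, a single header as 2, and so on. A standalone check:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	buf := make([]byte, binary.MaxVarintLen64)

	n := binary.PutVarint(buf, 5) // zigzag(5) == 10: the "Timestamp Delta" byte
	fmt.Printf("% x\n", buf[:n])  // 0a

	n = binary.PutVarint(buf, 4) // zigzag(4) == 8: the "Key Length" byte
	fmt.Printf("% x\n", buf[:n]) // 08

	n = binary.PutVarint(buf, -1) // small negatives stay compact: zigzag(-1) == 1
	fmt.Printf("% x\n", buf[:n])  // 01
}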

record_batch.go (+251 -0)

@@ -0,0 +1,251 @@
+package sarama
+
+import (
+	"bytes"
+	"compress/gzip"
+	"fmt"
+	"io/ioutil"
+
+	"github.com/eapache/go-xerial-snappy"
+	"github.com/pierrec/lz4"
+)
+
+const recordBatchOverhead = 49
+
+type recordsArray []*Record
+
+func (e recordsArray) encode(pe packetEncoder) error {
+	for _, r := range e {
+		if err := r.encode(pe); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (e recordsArray) decode(pd packetDecoder) error {
+	for i := range e {
+		rec := &Record{}
+		if err := rec.decode(pd); err != nil {
+			return err
+		}
+		e[i] = rec
+	}
+	return nil
+}
+
+type RecordBatch struct {
+	FirstOffset           int64
+	PartitionLeaderEpoch  int32
+	Version               int8
+	Codec                 CompressionCodec
+	Control               bool
+	LastOffsetDelta       int32
+	FirstTimestamp        int64
+	MaxTimestamp          int64
+	ProducerID            int64
+	ProducerEpoch         int16
+	FirstSequence         int32
+	Records               []*Record
+	PartialTrailingRecord bool
+
+	compressedRecords []byte
+	recordsLen        int // uncompressed records size
+}
+
+func (b *RecordBatch) encode(pe packetEncoder) error {
+	if b.Version != 2 {
+		return PacketEncodingError{fmt.Sprintf("unsupported record batch version (%d)", b.Version)}
+	}
+	pe.putInt64(b.FirstOffset)
+	pe.push(&lengthField{})
+	pe.putInt32(b.PartitionLeaderEpoch)
+	pe.putInt8(b.Version)
+	pe.push(newCRC32Field(crcCastagnoli))
+	pe.putInt16(b.computeAttributes())
+	pe.putInt32(b.LastOffsetDelta)
+	pe.putInt64(b.FirstTimestamp)
+	pe.putInt64(b.MaxTimestamp)
+	pe.putInt64(b.ProducerID)
+	pe.putInt16(b.ProducerEpoch)
+	pe.putInt32(b.FirstSequence)
+
+	if err := pe.putArrayLength(len(b.Records)); err != nil {
+		return err
+	}
+
+	if b.compressedRecords == nil {
+		if err := b.encodeRecords(pe); err != nil {
+			return err
+		}
+	}
+	if err := pe.putRawBytes(b.compressedRecords); err != nil {
+		return err
+	}
+
+	if err := pe.pop(); err != nil {
+		return err
+	}
+	return pe.pop()
+}
+
+func (b *RecordBatch) decode(pd packetDecoder) (err error) {
+	if b.FirstOffset, err = pd.getInt64(); err != nil {
+		return err
+	}
+
+	batchLen, err := pd.getInt32()
+	if err != nil {
+		return err
+	}
+
+	if b.PartitionLeaderEpoch, err = pd.getInt32(); err != nil {
+		return err
+	}
+
+	if b.Version, err = pd.getInt8(); err != nil {
+		return err
+	}
+
+	if err = pd.push(&crc32Field{polynomial: crcCastagnoli}); err != nil {
+		return err
+	}
+
+	attributes, err := pd.getInt16()
+	if err != nil {
+		return err
+	}
+	b.Codec = CompressionCodec(int8(attributes) & compressionCodecMask)
+	b.Control = attributes&controlMask == controlMask
+
+	if b.LastOffsetDelta, err = pd.getInt32(); err != nil {
+		return err
+	}
+
+	if b.FirstTimestamp, err = pd.getInt64(); err != nil {
+		return err
+	}
+
+	if b.MaxTimestamp, err = pd.getInt64(); err != nil {
+		return err
+	}
+
+	if b.ProducerID, err = pd.getInt64(); err != nil {
+		return err
+	}
+
+	if b.ProducerEpoch, err = pd.getInt16(); err != nil {
+		return err
+	}
+
+	if b.FirstSequence, err = pd.getInt32(); err != nil {
+		return err
+	}
+
+	numRecs, err := pd.getArrayLength()
+	if err != nil {
+		return err
+	}
+	if numRecs >= 0 {
+		b.Records = make([]*Record, numRecs)
+	}
+
+	bufSize := int(batchLen) - recordBatchOverhead
+	recBuffer, err := pd.getRawBytes(bufSize)
+	if err != nil {
+		return err
+	}
+
+	if err = pd.pop(); err != nil {
+		return err
+	}
+
+	switch b.Codec {
+	case CompressionNone:
+	case CompressionGZIP:
+		reader, err := gzip.NewReader(bytes.NewReader(recBuffer))
+		if err != nil {
+			return err
+		}
+		if recBuffer, err = ioutil.ReadAll(reader); err != nil {
+			return err
+		}
+	case CompressionSnappy:
+		if recBuffer, err = snappy.Decode(recBuffer); err != nil {
+			return err
+		}
+	case CompressionLZ4:
+		reader := lz4.NewReader(bytes.NewReader(recBuffer))
+		if recBuffer, err = ioutil.ReadAll(reader); err != nil {
+			return err
+		}
+	default:
+		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", b.Codec)}
+	}
+
+	err = decode(recBuffer, recordsArray(b.Records))
+	if err == ErrInsufficientData {
+		b.PartialTrailingRecord = true
+		b.Records = nil
+		return nil
+	}
+	return err
+}
+
+func (b *RecordBatch) encodeRecords(pe packetEncoder) error {
+	var raw []byte
+	if b.Codec != CompressionNone {
+		var err error
+		if raw, err = encode(recordsArray(b.Records), nil); err != nil {
+			return err
+		}
+		b.recordsLen = len(raw)
+	}
+
+	switch b.Codec {
+	case CompressionNone:
+		offset := pe.offset()
+		if err := recordsArray(b.Records).encode(pe); err != nil {
+			return err
+		}
+		b.recordsLen = pe.offset() - offset
+	case CompressionGZIP:
+		var buf bytes.Buffer
+		writer := gzip.NewWriter(&buf)
+		if _, err := writer.Write(raw); err != nil {
+			return err
+		}
+		if err := writer.Close(); err != nil {
+			return err
+		}
+		b.compressedRecords = buf.Bytes()
+	case CompressionSnappy:
+		b.compressedRecords = snappy.Encode(raw)
+	case CompressionLZ4:
+		var buf bytes.Buffer
+		writer := lz4.NewWriter(&buf)
+		if _, err := writer.Write(raw); err != nil {
+			return err
+		}
+		if err := writer.Close(); err != nil {
+			return err
+		}
+		b.compressedRecords = buf.Bytes()
+	default:
+		return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
+	}
+
+	return nil
+}
+
+func (b *RecordBatch) computeAttributes() int16 {
+	attr := int16(b.Codec) & int16(compressionCodecMask)
+	if b.Control {
+		attr |= controlMask
+	}
+	return attr
+}
+
+func (b *RecordBatch) addRecord(r *Record) {
+	b.Records = append(b.Records, r)
+}
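
computeAttributes packs the codec into the low bits and the control flag into bit 5, matching the attribute bytes in the test vectors below (0x00 none, 0x01 gzip, 0x02 snappy, 0x03 lz4, 0x20 control). A standalone sketch; the 0x07 codec mask here is assumed from the Kafka v2 record batch spec, where compression occupies bits 0-2:

package main

import "fmt"

const (
	codecMask   = 0x07 // bits 0-2: compression codec (assumed per the Kafka spec)
	controlFlag = 0x20 // bit 5: control batch, the controlMask constant in record.go
)

func attributes(codec int16, control bool) int16 {
	attr := codec & codecMask
	if control {
		attr |= controlFlag
	}
	return attr
}

func main() {
	fmt.Printf("0x%02x\n", attributes(1, false)) // 0x01: the "gzipped record" test case
	fmt.Printf("0x%02x\n", attributes(0, true))  // 0x20: the "control batch" test case
}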

record_test.go (+264 -0)

@@ -0,0 +1,264 @@
+package sarama
+
+import (
+	"reflect"
+	"runtime"
+	"strconv"
+	"strings"
+	"testing"
+
+	"github.com/davecgh/go-spew/spew"
+)
+
+var recordBatchTestCases = []struct {
+	name         string
+	batch        RecordBatch
+	encoded      []byte
+	oldGoEncoded []byte // used for gzipped content on Go versions prior to 1.8
+}{
+	{
+		name:  "empty record",
+		batch: RecordBatch{Version: 2, Records: []*Record{}},
+		encoded: []byte{
+			0, 0, 0, 0, 0, 0, 0, 0, // First Offset
+			0, 0, 0, 49, // Length
+			0, 0, 0, 0, // Partition Leader Epoch
+			2,                // Version
+			89, 95, 183, 221, // CRC
+			0, 0, // Attributes
+			0, 0, 0, 0, // Last Offset Delta
+			0, 0, 0, 0, 0, 0, 0, 0, // First Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
+			0, 0, // Producer Epoch
+			0, 0, 0, 0, // First Sequence
+			0, 0, 0, 0, // Number of Records
+		},
+	},
+	{
+		name:  "control batch",
+		batch: RecordBatch{Version: 2, Control: true, Records: []*Record{}},
+		encoded: []byte{
+			0, 0, 0, 0, 0, 0, 0, 0, // First Offset
+			0, 0, 0, 49, // Length
+			0, 0, 0, 0, // Partition Leader Epoch
+			2,               // Version
+			81, 46, 67, 217, // CRC
+			0, 32, // Attributes
+			0, 0, 0, 0, // Last Offset Delta
+			0, 0, 0, 0, 0, 0, 0, 0, // First Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
+			0, 0, // Producer Epoch
+			0, 0, 0, 0, // First Sequence
+			0, 0, 0, 0, // Number of Records
+		},
+	},
+	{
+		name: "uncompressed record",
+		batch: RecordBatch{
+			Version:        2,
+			FirstTimestamp: 10,
+			Records: []*Record{{
+				TimestampDelta: 5,
+				Key:            []byte{1, 2, 3, 4},
+				Value:          []byte{5, 6, 7},
+				Headers: []*RecordHeader{{
+					Key:   []byte{8, 9, 10},
+					Value: []byte{11, 12},
+				}},
+			}},
+		},
+		encoded: []byte{
+			0, 0, 0, 0, 0, 0, 0, 0, // First Offset
+			0, 0, 0, 70, // Length
+			0, 0, 0, 0, // Partition Leader Epoch
+			2,                // Version
+			219, 71, 20, 201, // CRC
+			0, 0, // Attributes
+			0, 0, 0, 0, // Last Offset Delta
+			0, 0, 0, 0, 0, 0, 0, 10, // First Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
+			0, 0, // Producer Epoch
+			0, 0, 0, 0, // First Sequence
+			0, 0, 0, 1, // Number of Records
+			40, // Record Length
+			0,  // Attributes
+			10, // Timestamp Delta
+			0,  // Offset Delta
+			8,  // Key Length
+			1, 2, 3, 4,
+			6, // Value Length
+			5, 6, 7,
+			2,        // Number of Headers
+			6,        // Header Key Length
+			8, 9, 10, // Header Key
+			4,      // Header Value Length
+			11, 12, // Header Value
+		},
+	},
+	{
+		name: "gzipped record",
+		batch: RecordBatch{
+			Version:        2,
+			Codec:          CompressionGZIP,
+			FirstTimestamp: 10,
+			Records: []*Record{{
+				TimestampDelta: 5,
+				Key:            []byte{1, 2, 3, 4},
+				Value:          []byte{5, 6, 7},
+				Headers: []*RecordHeader{{
+					Key:   []byte{8, 9, 10},
+					Value: []byte{11, 12},
+				}},
+			}},
+		},
+		encoded: []byte{
+			0, 0, 0, 0, 0, 0, 0, 0, // First Offset
+			0, 0, 0, 94, // Length
+			0, 0, 0, 0, // Partition Leader Epoch
+			2,                // Version
+			15, 156, 184, 78, // CRC
+			0, 1, // Attributes
+			0, 0, 0, 0, // Last Offset Delta
+			0, 0, 0, 0, 0, 0, 0, 10, // First Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
+			0, 0, // Producer Epoch
+			0, 0, 0, 0, // First Sequence
+			0, 0, 0, 1, // Number of Records
+			31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 210, 96, 224, 98, 224, 96, 100, 98, 102, 97, 99, 101,
+			99, 103, 98, 227, 224, 228, 98, 225, 230, 1, 4, 0, 0, 255, 255, 173, 201, 88, 103, 21, 0, 0, 0,
+		},
+		oldGoEncoded: []byte{
+			0, 0, 0, 0, 0, 0, 0, 0, // First Offset
+			0, 0, 0, 94, // Length
+			0, 0, 0, 0, // Partition Leader Epoch
+			2,               // Version
+			144, 168, 0, 33, // CRC
+			0, 1, // Attributes
+			0, 0, 0, 0, // Last Offset Delta
+			0, 0, 0, 0, 0, 0, 0, 10, // First Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
+			0, 0, // Producer Epoch
+			0, 0, 0, 0, // First Sequence
+			0, 0, 0, 1, // Number of Records
+			31, 139, 8, 0, 0, 9, 110, 136, 0, 255, 210, 96, 224, 98, 224, 96, 100, 98, 102, 97, 99, 101,
+			99, 103, 98, 227, 224, 228, 98, 225, 230, 1, 4, 0, 0, 255, 255, 173, 201, 88, 103, 21, 0, 0, 0,
+		},
+	},
+	{
+		name: "snappy compressed record",
+		batch: RecordBatch{
+			Version:        2,
+			Codec:          CompressionSnappy,
+			FirstTimestamp: 10,
+			Records: []*Record{{
+				TimestampDelta: 5,
+				Key:            []byte{1, 2, 3, 4},
+				Value:          []byte{5, 6, 7},
+				Headers: []*RecordHeader{{
+					Key:   []byte{8, 9, 10},
+					Value: []byte{11, 12},
+				}},
+			}},
+		},
+		encoded: []byte{
+			0, 0, 0, 0, 0, 0, 0, 0, // First Offset
+			0, 0, 0, 72, // Length
+			0, 0, 0, 0, // Partition Leader Epoch
+			2,               // Version
+			95, 173, 35, 17, // CRC
+			0, 2, // Attributes
+			0, 0, 0, 0, // Last Offset Delta
+			0, 0, 0, 0, 0, 0, 0, 10, // First Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
+			0, 0, // Producer Epoch
+			0, 0, 0, 0, // First Sequence
+			0, 0, 0, 1, // Number of Records
+			21, 80, 40, 0, 10, 0, 8, 1, 2, 3, 4, 6, 5, 6, 7, 2, 6, 8, 9, 10, 4, 11, 12,
+		},
+	},
+	{
+		name: "lz4 compressed record",
+		batch: RecordBatch{
+			Version:        2,
+			Codec:          CompressionLZ4,
+			FirstTimestamp: 10,
+			Records: []*Record{{
+				TimestampDelta: 5,
+				Key:            []byte{1, 2, 3, 4},
+				Value:          []byte{5, 6, 7},
+				Headers: []*RecordHeader{{
+					Key:   []byte{8, 9, 10},
+					Value: []byte{11, 12},
+				}},
+			}},
+		},
+		encoded: []byte{
+			0, 0, 0, 0, 0, 0, 0, 0, // First Offset
+			0, 0, 0, 89, // Length
+			0, 0, 0, 0, // Partition Leader Epoch
+			2,                // Version
+			129, 238, 43, 82, // CRC
+			0, 3, // Attributes
+			0, 0, 0, 0, // Last Offset Delta
+			0, 0, 0, 0, 0, 0, 0, 10, // First Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
+			0, 0, // Producer Epoch
+			0, 0, 0, 0, // First Sequence
+			0, 0, 0, 1, // Number of Records
+			4, 34, 77, 24, 100, 112, 185, 21, 0, 0, 128, 40, 0, 10, 0, 8, 1, 2, 3, 4, 6, 5, 6, 7, 2,
+			6, 8, 9, 10, 4, 11, 12, 0, 0, 0, 0, 12, 59, 239, 146,
+		},
+	},
+}
+
+func isOldGo(t *testing.T) bool {
+	v := strings.Split(runtime.Version()[2:], ".")
+	if len(v) < 2 {
+		t.Logf("Can't parse version: %s", runtime.Version())
+		return false
+	}
+	maj, err := strconv.Atoi(v[0])
+	if err != nil {
+		t.Logf("Can't parse version: %s", runtime.Version())
+		return false
+	}
+	min, err := strconv.Atoi(v[1])
+	if err != nil {
+		t.Logf("Can't parse version: %s", runtime.Version())
+		return false
+	}
+	return maj < 1 || (maj == 1 && min < 8)
+}
+
+func TestRecordBatchEncoding(t *testing.T) {
+	for _, tc := range recordBatchTestCases {
+		if tc.oldGoEncoded != nil && isOldGo(t) {
+			testEncodable(t, tc.name, &tc.batch, tc.oldGoEncoded)
+		} else {
+			testEncodable(t, tc.name, &tc.batch, tc.encoded)
+		}
+	}
+}
+
+func TestRecordBatchDecoding(t *testing.T) {
+	for _, tc := range recordBatchTestCases {
+		batch := RecordBatch{}
+		testDecodable(t, tc.name, &batch, tc.encoded)
+		for _, r := range batch.Records {
+			r.length = varintLengthField{}
+		}
+		for _, r := range tc.batch.Records {
+			r.length = varintLengthField{}
+		}
+		if !reflect.DeepEqual(batch, tc.batch) {
+			t.Errorf(spew.Sprintf("invalid decode of %s\ngot %+v\nwanted %+v", tc.name, batch, tc.batch))
+		}
+	}
+}

records.go (+96 -0)

@@ -0,0 +1,96 @@
+package sarama
+
+import "fmt"
+
+const (
+	legacyRecords = iota
+	defaultRecords
+)
+
+// Records implements a union type containing either a RecordBatch or a legacy MessageSet.
+type Records struct {
+	recordsType int
+	msgSet      *MessageSet
+	recordBatch *RecordBatch
+}
+
+func newLegacyRecords(msgSet *MessageSet) Records {
+	return Records{recordsType: legacyRecords, msgSet: msgSet}
+}
+
+func newDefaultRecords(batch *RecordBatch) Records {
+	return Records{recordsType: defaultRecords, recordBatch: batch}
+}
+
+func (r *Records) encode(pe packetEncoder) error {
+	switch r.recordsType {
+	case legacyRecords:
+		if r.msgSet == nil {
+			return nil
+		}
+		return r.msgSet.encode(pe)
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return nil
+		}
+		return r.recordBatch.encode(pe)
+	}
+	return fmt.Errorf("unknown records type: %v", r.recordsType)
+}
+
+func (r *Records) decode(pd packetDecoder) error {
+	switch r.recordsType {
+	case legacyRecords:
+		r.msgSet = &MessageSet{}
+		return r.msgSet.decode(pd)
+	case defaultRecords:
+		r.recordBatch = &RecordBatch{}
+		return r.recordBatch.decode(pd)
+	}
+	return fmt.Errorf("unknown records type: %v", r.recordsType)
+}
+
+func (r *Records) numRecords() (int, error) {
+	switch r.recordsType {
+	case legacyRecords:
+		if r.msgSet == nil {
+			return 0, nil
+		}
+		return len(r.msgSet.Messages), nil
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return 0, nil
+		}
+		return len(r.recordBatch.Records), nil
+	}
+	return 0, fmt.Errorf("unknown records type: %v", r.recordsType)
+}
+
+func (r *Records) isPartial() (bool, error) {
+	switch r.recordsType {
+	case legacyRecords:
+		if r.msgSet == nil {
+			return false, nil
+		}
+		return r.msgSet.PartialTrailingMessage, nil
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return false, nil
+		}
+		return r.recordBatch.PartialTrailingRecord, nil
+	}
+	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
+}
+
+func (r *Records) isControl() (bool, error) {
+	switch r.recordsType {
+	case legacyRecords:
+		return false, nil
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return false, nil
+		}
+		return r.recordBatch.Control, nil
+	}
+	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
+}
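
A hedged usage sketch of the union: callers pick the representation from the message format version ("magic") and then treat both kinds uniformly through the accessors above. The helper name is hypothetical, not part of this PR:

// newRecords picks the on-the-wire representation from the magic byte:
// magic < 2 means a legacy MessageSet, magic 2 means a v2 RecordBatch.
func newRecords(magic int8) Records {
	if magic < 2 {
		return newLegacyRecords(&MessageSet{})
	}
	return newDefaultRecords(&RecordBatch{Version: magic})
}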

records_test.go (+137 -0)

@@ -0,0 +1,137 @@
+package sarama
+
+import (
+	"bytes"
+	"reflect"
+	"testing"
+)
+
+func TestLegacyRecords(t *testing.T) {
+	set := &MessageSet{
+		Messages: []*MessageBlock{
+			{
+				Msg: &Message{
+					Version: 1,
+				},
+			},
+		},
+	}
+	r := newLegacyRecords(set)
+
+	exp, err := encode(set, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	buf, err := encode(&r, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(buf, exp) {
+		t.Errorf("Wrong encoding for legacy records, wanted %v, got %v", exp, buf)
+	}
+
+	set = &MessageSet{}
+	r = newLegacyRecords(nil)
+
+	err = decode(exp, set)
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = decode(buf, &r)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !reflect.DeepEqual(set, r.msgSet) {
+		t.Errorf("Wrong decoding for legacy records, wanted %#+v, got %#+v", set, r.msgSet)
+	}
+
+	n, err := r.numRecords()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if n != 1 {
+		t.Errorf("Wrong number of records, wanted 1, got %d", n)
+	}
+
+	p, err := r.isPartial()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if p {
+		t.Errorf("MessageSet shouldn't have a partial trailing message")
+	}
+
+	c, err := r.isControl()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if c {
+		t.Errorf("MessageSet can't be a control batch")
+	}
+}
+
+func TestDefaultRecords(t *testing.T) {
+	batch := &RecordBatch{
+		Version: 2,
+		Records: []*Record{
+			{
+				Value: []byte{1},
+			},
+		},
+	}
+
+	r := newDefaultRecords(batch)
+
+	exp, err := encode(batch, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	buf, err := encode(&r, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(buf, exp) {
+		t.Errorf("Wrong encoding for default records, wanted %v, got %v", exp, buf)
+	}
+
+	batch = &RecordBatch{}
+	r = newDefaultRecords(nil)
+
+	err = decode(exp, batch)
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = decode(buf, &r)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !reflect.DeepEqual(batch, r.recordBatch) {
+		t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatch)
+	}
+
+	n, err := r.numRecords()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if n != 1 {
+		t.Errorf("Wrong number of records, wanted 1, got %d", n)
+	}
+
+	p, err := r.isPartial()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if p {
+		t.Errorf("RecordBatch shouldn't have a partial trailing record")
+	}
+
+	c, err := r.isControl()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if c {
+		t.Errorf("RecordBatch shouldn't be a control batch")
+	}
+}