Ivan Babrou пре 8 година
родитељ
комит
c8284bbf5c
7 измењених фајлова са 247 додато и 74 уклоњено
  1. 22 13
      consumer.go
  2. 39 18
      fetch_response.go
  3. 24 6
      fetch_response_test.go
  4. 3 3
      produce_request.go
  5. 2 2
      produce_set.go
  6. 126 25
      records.go
  7. 31 7
      records_test.go

+ 22 - 13
consumer.go

@@ -570,10 +570,18 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
 		return nil, block.Err
 		return nil, block.Err
 	}
 	}
 
 
-	if block.numRecords() == 0 {
+	nRecs, err := block.numRecords()
+	if err != nil {
+		return nil, err
+	}
+	if nRecs == 0 {
+		partialTrailingMessage, err := block.isPartial()
+		if err != nil {
+			return nil, err
+		}
 		// We got no messages. If we got a trailing one then we need to ask for more data.
 		// We got no messages. If we got a trailing one then we need to ask for more data.
 		// Otherwise we just poll again and wait for one to be produced...
 		// Otherwise we just poll again and wait for one to be produced...
-		if block.isPartial() {
+		if partialTrailingMessage {
 			if child.conf.Consumer.Fetch.Max > 0 && child.fetchSize == child.conf.Consumer.Fetch.Max {
 			if child.conf.Consumer.Fetch.Max > 0 && child.fetchSize == child.conf.Consumer.Fetch.Max {
 				// we can't ask for more data, we've hit the configured limit
 				// we can't ask for more data, we've hit the configured limit
 				child.sendError(ErrMessageTooLarge)
 				child.sendError(ErrMessageTooLarge)
@@ -594,27 +602,28 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
 	atomic.StoreInt64(&child.highWaterMarkOffset, block.HighWaterMarkOffset)
 	atomic.StoreInt64(&child.highWaterMarkOffset, block.HighWaterMarkOffset)
 
 
 	messages := []*ConsumerMessage{}
 	messages := []*ConsumerMessage{}
-	for _, chunk := range block.RecordsSet {
-		if chunk.msgSet != nil {
-			messageSetMessages, err := child.parseMessages(chunk.msgSet)
+	for _, records := range block.RecordsSet {
+		if control, err := records.isControl(); err != nil || control {
+			continue
+		}
+
+		switch records.recordsType {
+		case legacyRecords:
+			messageSetMessages, err := child.parseMessages(records.msgSet)
 			if err != nil {
 			if err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
 
 
 			messages = append(messages, messageSetMessages...)
 			messages = append(messages, messageSetMessages...)
-		}
-
-		if chunk.recordBatch != nil {
-			if chunk.recordBatch.Control {
-				continue
-			}
-
-			recordBatchMessages, err := child.parseRecords(chunk.recordBatch)
+		case defaultRecords:
+			recordBatchMessages, err := child.parseRecords(records.recordBatch)
 			if err != nil {
 			if err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
 
 
 			messages = append(messages, recordBatchMessages...)
 			messages = append(messages, recordBatchMessages...)
+		default:
+			return nil, fmt.Errorf("unknown records type: %v", records.recordsType)
 		}
 		}
 	}
 	}
 
 

+ 39 - 18
fetch_response.go

@@ -90,9 +90,9 @@ func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error)
 			break
 			break
 		}
 		}
 
 
-		chunk := &Records{}
-		if err := chunk.decode(recordsDecoder); err != nil {
-			// If we have at least one decoded record chunk, this is not an error
+		records := &Records{}
+		if err := records.decode(recordsDecoder); err != nil {
+			// If we have at least one decoded records, this is not an error
 			if err == ErrInsufficientData {
 			if err == ErrInsufficientData {
 				if len(b.RecordsSet) == 0 {
 				if len(b.RecordsSet) == 0 {
 					b.Partial = true
 					b.Partial = true
@@ -102,29 +102,47 @@ func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error)
 			return err
 			return err
 		}
 		}
 
 
-		// If we have at least one full record chunk, we skip incomplete ones
-		if chunk.isPartial() && len(b.RecordsSet) > 0 {
+		partial, err := records.isPartial()
+		if err != nil {
+			return err
+		}
+
+		// If we have at least one full records, we skip incomplete ones
+		if partial && len(b.RecordsSet) > 0 {
 			break
 			break
 		}
 		}
 
 
-		b.RecordsSet = append(b.RecordsSet, chunk)
+		b.RecordsSet = append(b.RecordsSet, records)
 	}
 	}
 
 
 	return nil
 	return nil
 }
 }
 
 
-func (b *FetchResponseBlock) numRecords() int {
-	s := 0
+func (b *FetchResponseBlock) numRecords() (int, error) {
+	sum := 0
 
 
-	for _, chunk := range b.RecordsSet {
-		s += chunk.numRecords()
+	for _, records := range b.RecordsSet {
+		count, err := records.numRecords()
+		if err != nil {
+			return 0, err
+		}
+
+		sum += count
 	}
 	}
 
 
-	return s
+	return sum, nil
 }
 }
 
 
-func (b *FetchResponseBlock) isPartial() bool {
-	return b.Partial || len(b.RecordsSet) == 1 && b.RecordsSet[0].isPartial()
+func (b *FetchResponseBlock) isPartial() (bool, error) {
+	if b.Partial {
+		return true, nil
+	}
+
+	if len(b.RecordsSet) == 1 {
+		return b.RecordsSet[0].isPartial()
+	}
+
+	return false, nil
 }
 }
 
 
 func (b *FetchResponseBlock) encode(pe packetEncoder, version int16) (err error) {
 func (b *FetchResponseBlock) encode(pe packetEncoder, version int16) (err error) {
@@ -146,8 +164,8 @@ func (b *FetchResponseBlock) encode(pe packetEncoder, version int16) (err error)
 	}
 	}
 
 
 	pe.push(&lengthField{})
 	pe.push(&lengthField{})
-	for _, chunk := range b.RecordsSet {
-		err = chunk.encode(pe)
+	for _, records := range b.RecordsSet {
+		err = records.encode(pe)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -331,7 +349,8 @@ func (r *FetchResponse) AddMessage(topic string, partition int32, key, value Enc
 	msg := &Message{Key: kb, Value: vb}
 	msg := &Message{Key: kb, Value: vb}
 	msgBlock := &MessageBlock{Msg: msg, Offset: offset}
 	msgBlock := &MessageBlock{Msg: msg, Offset: offset}
 	if len(frb.RecordsSet) == 0 {
 	if len(frb.RecordsSet) == 0 {
-		frb.RecordsSet = []*Records{&Records{msgSet: &MessageSet{}}}
+		records := newLegacyRecords(&MessageSet{})
+		frb.RecordsSet = []*Records{&records}
 	}
 	}
 	set := frb.RecordsSet[0].msgSet
 	set := frb.RecordsSet[0].msgSet
 	set.Messages = append(set.Messages, msgBlock)
 	set.Messages = append(set.Messages, msgBlock)
@@ -342,7 +361,8 @@ func (r *FetchResponse) AddRecord(topic string, partition int32, key, value Enco
 	kb, vb := encodeKV(key, value)
 	kb, vb := encodeKV(key, value)
 	rec := &Record{Key: kb, Value: vb, OffsetDelta: offset}
 	rec := &Record{Key: kb, Value: vb, OffsetDelta: offset}
 	if len(frb.RecordsSet) == 0 {
 	if len(frb.RecordsSet) == 0 {
-		frb.RecordsSet = []*Records{&Records{recordBatch: &RecordBatch{Version: 2}}}
+		records := newDefaultRecords(&RecordBatch{Version: 2})
+		frb.RecordsSet = []*Records{&records}
 	}
 	}
 	batch := frb.RecordsSet[0].recordBatch
 	batch := frb.RecordsSet[0].recordBatch
 	batch.addRecord(rec)
 	batch.addRecord(rec)
@@ -351,7 +371,8 @@ func (r *FetchResponse) AddRecord(topic string, partition int32, key, value Enco
 func (r *FetchResponse) SetLastOffsetDelta(topic string, partition int32, offset int32) {
 func (r *FetchResponse) SetLastOffsetDelta(topic string, partition int32, offset int32) {
 	frb := r.getOrCreateBlock(topic, partition)
 	frb := r.getOrCreateBlock(topic, partition)
 	if len(frb.RecordsSet) == 0 {
 	if len(frb.RecordsSet) == 0 {
-		frb.RecordsSet = []*Records{&Records{recordBatch: &RecordBatch{Version: 2}}}
+		records := newDefaultRecords(&RecordBatch{Version: 2})
+		frb.RecordsSet = []*Records{&records}
 	}
 	}
 	batch := frb.RecordsSet[0].recordBatch
 	batch := frb.RecordsSet[0].recordBatch
 	batch.LastOffsetDelta = offset
 	batch.LastOffsetDelta = offset

+ 24 - 6
fetch_response_test.go

@@ -117,12 +117,18 @@ func TestOneMessageFetchResponse(t *testing.T) {
 	if block.HighWaterMarkOffset != 0x10101010 {
 	if block.HighWaterMarkOffset != 0x10101010 {
 		t.Error("Decoding didn't produce correct high water mark offset.")
 		t.Error("Decoding didn't produce correct high water mark offset.")
 	}
 	}
-	partial := block.RecordsSet[0].isPartial()
+	partial, err := block.isPartial()
+	if err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
 	if partial {
 	if partial {
 		t.Error("Decoding detected a partial trailing message where there wasn't one.")
 		t.Error("Decoding detected a partial trailing message where there wasn't one.")
 	}
 	}
 
 
-	n := block.RecordsSet[0].numRecords()
+	n, err := block.numRecords()
+	if err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
 	if n != 1 {
 	if n != 1 {
 		t.Fatal("Decoding produced incorrect number of messages.")
 		t.Fatal("Decoding produced incorrect number of messages.")
 	}
 	}
@@ -164,12 +170,18 @@ func TestOneRecordFetchResponse(t *testing.T) {
 	if block.HighWaterMarkOffset != 0x10101010 {
 	if block.HighWaterMarkOffset != 0x10101010 {
 		t.Error("Decoding didn't produce correct high water mark offset.")
 		t.Error("Decoding didn't produce correct high water mark offset.")
 	}
 	}
-	partial := block.RecordsSet[0].isPartial()
+	partial, err := block.isPartial()
+	if err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
 	if partial {
 	if partial {
 		t.Error("Decoding detected a partial trailing record where there wasn't one.")
 		t.Error("Decoding detected a partial trailing record where there wasn't one.")
 	}
 	}
 
 
-	n := block.RecordsSet[0].numRecords()
+	n, err := block.numRecords()
+	if err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
 	if n != 1 {
 	if n != 1 {
 		t.Fatal("Decoding produced incorrect number of records.")
 		t.Fatal("Decoding produced incorrect number of records.")
 	}
 	}
@@ -204,12 +216,18 @@ func TestOneMessageFetchResponseV4(t *testing.T) {
 	if block.HighWaterMarkOffset != 0x10101010 {
 	if block.HighWaterMarkOffset != 0x10101010 {
 		t.Error("Decoding didn't produce correct high water mark offset.")
 		t.Error("Decoding didn't produce correct high water mark offset.")
 	}
 	}
-	partial := block.RecordsSet[0].isPartial()
+	partial, err := block.isPartial()
+	if err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
 	if partial {
 	if partial {
 		t.Error("Decoding detected a partial trailing record where there wasn't one.")
 		t.Error("Decoding detected a partial trailing record where there wasn't one.")
 	}
 	}
 
 
-	n := block.RecordsSet[0].numRecords()
+	n, err := block.numRecords()
+	if err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
 	if n != 1 {
 	if n != 1 {
 		t.Fatal("Decoding produced incorrect number of records.")
 		t.Fatal("Decoding produced incorrect number of records.")
 	}
 	}

+ 3 - 3
produce_request.go

@@ -235,7 +235,7 @@ func (r *ProduceRequest) AddMessage(topic string, partition int32, msg *Message)
 
 
 	if set == nil {
 	if set == nil {
 		set = new(MessageSet)
 		set = new(MessageSet)
-		r.records[topic][partition] = Records{msgSet: set}
+		r.records[topic][partition] = newLegacyRecords(set)
 	}
 	}
 
 
 	set.addMessage(msg)
 	set.addMessage(msg)
@@ -243,10 +243,10 @@ func (r *ProduceRequest) AddMessage(topic string, partition int32, msg *Message)
 
 
 func (r *ProduceRequest) AddSet(topic string, partition int32, set *MessageSet) {
 func (r *ProduceRequest) AddSet(topic string, partition int32, set *MessageSet) {
 	r.ensureRecords(topic, partition)
 	r.ensureRecords(topic, partition)
-	r.records[topic][partition] = Records{msgSet: set}
+	r.records[topic][partition] = newLegacyRecords(set)
 }
 }
 
 
 func (r *ProduceRequest) AddBatch(topic string, partition int32, batch *RecordBatch) {
 func (r *ProduceRequest) AddBatch(topic string, partition int32, batch *RecordBatch) {
 	r.ensureRecords(topic, partition)
 	r.ensureRecords(topic, partition)
-	r.records[topic][partition] = Records{recordBatch: batch}
+	r.records[topic][partition] = newDefaultRecords(batch)
 }
 }

+ 2 - 2
produce_set.go

@@ -64,10 +64,10 @@ func (ps *produceSet) add(msg *ProducerMessage) error {
 				ProducerID:     -1, /* No producer id */
 				ProducerID:     -1, /* No producer id */
 				Codec:          ps.parent.conf.Producer.Compression,
 				Codec:          ps.parent.conf.Producer.Compression,
 			}
 			}
-			set = &partitionSet{recordsToSend: Records{recordBatch: batch}}
+			set = &partitionSet{recordsToSend: newDefaultRecords(batch)}
 			size = recordBatchOverhead
 			size = recordBatchOverhead
 		} else {
 		} else {
-			set = &partitionSet{recordsToSend: Records{msgSet: &MessageSet{}}}
+			set = &partitionSet{recordsToSend: newLegacyRecords(new(MessageSet))}
 		}
 		}
 		partitions[msg.Partition] = set
 		partitions[msg.Partition] = set
 	}
 	}

+ 126 - 25
records.go

@@ -1,65 +1,166 @@
 package sarama
 package sarama
 
 
+import "fmt"
+
 const (
 const (
+	unknownRecords = iota
+	legacyRecords
+	defaultRecords
+
 	magicOffset = 16
 	magicOffset = 16
 	magicLength = 1
 	magicLength = 1
 )
 )
 
 
 // Records implements a union type containing either a RecordBatch or a legacy MessageSet.
 // Records implements a union type containing either a RecordBatch or a legacy MessageSet.
 type Records struct {
 type Records struct {
+	recordsType int
 	msgSet      *MessageSet
 	msgSet      *MessageSet
 	recordBatch *RecordBatch
 	recordBatch *RecordBatch
 }
 }
 
 
-func (c *Records) numRecords() int {
-	if c.msgSet != nil {
-		return len(c.msgSet.Messages)
-	}
+func newLegacyRecords(msgSet *MessageSet) Records {
+	return Records{recordsType: legacyRecords, msgSet: msgSet}
+}
 
 
-	if c.recordBatch != nil {
-		return len(c.recordBatch.Records)
-	}
+func newDefaultRecords(batch *RecordBatch) Records {
+	return Records{recordsType: defaultRecords, recordBatch: batch}
+}
 
 
-	return 0
+// setTypeFromFields sets type of Records depending on which of msgSet or recordBatch is not nil.
+// The first return value indicates whether both fields are nil (and the type is not set).
+// If both fields are not nil, it returns an error.
+func (r *Records) setTypeFromFields() (bool, error) {
+	if r.msgSet == nil && r.recordBatch == nil {
+		return true, nil
+	}
+	if r.msgSet != nil && r.recordBatch != nil {
+		return false, fmt.Errorf("both msgSet and recordBatch are set, but record type is unknown")
+	}
+	r.recordsType = defaultRecords
+	if r.msgSet != nil {
+		r.recordsType = legacyRecords
+	}
+	return false, nil
 }
 }
 
 
-func (c *Records) isPartial() bool {
-	if c.msgSet != nil {
-		return c.msgSet.PartialTrailingMessage
+func (r *Records) encode(pe packetEncoder) error {
+	if r.recordsType == unknownRecords {
+		if empty, err := r.setTypeFromFields(); err != nil || empty {
+			return err
+		}
 	}
 	}
 
 
-	if c.recordBatch != nil {
-		return c.recordBatch.PartialTrailingRecord
+	switch r.recordsType {
+	case legacyRecords:
+		if r.msgSet == nil {
+			return nil
+		}
+		return r.msgSet.encode(pe)
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return nil
+		}
+		return r.recordBatch.encode(pe)
 	}
 	}
 
 
-	return false
+	return fmt.Errorf("unknown records type: %v", r.recordsType)
 }
 }
 
 
-func (c *Records) decode(pd packetDecoder) (err error) {
+func (r *Records) setTypeFromMagic(pd packetDecoder) error {
 	magic, err := magicValue(pd)
 	magic, err := magicValue(pd)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
+	r.recordsType = defaultRecords
 	if magic < 2 {
 	if magic < 2 {
-		c.msgSet = &MessageSet{}
-		return c.msgSet.decode(pd)
+		r.recordsType = legacyRecords
 	}
 	}
 
 
-	c.recordBatch = &RecordBatch{}
-	return c.recordBatch.decode(pd)
+	return nil
 }
 }
 
 
-func (c *Records) encode(pe packetEncoder) (err error) {
-	if c.msgSet != nil {
-		return c.msgSet.encode(pe)
+func (r *Records) decode(pd packetDecoder) error {
+	if r.recordsType == unknownRecords {
+		if err := r.setTypeFromMagic(pd); err != nil {
+			return err
+		}
 	}
 	}
 
 
-	if c.recordBatch != nil {
-		return c.recordBatch.encode(pe)
+	switch r.recordsType {
+	case legacyRecords:
+		r.msgSet = &MessageSet{}
+		return r.msgSet.decode(pd)
+	case defaultRecords:
+		r.recordBatch = &RecordBatch{}
+		return r.recordBatch.decode(pd)
 	}
 	}
+	return fmt.Errorf("unknown records type: %v", r.recordsType)
+}
 
 
-	return nil
+func (r *Records) numRecords() (int, error) {
+	if r.recordsType == unknownRecords {
+		if empty, err := r.setTypeFromFields(); err != nil || empty {
+			return 0, err
+		}
+	}
+
+	switch r.recordsType {
+	case legacyRecords:
+		if r.msgSet == nil {
+			return 0, nil
+		}
+		return len(r.msgSet.Messages), nil
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return 0, nil
+		}
+		return len(r.recordBatch.Records), nil
+	}
+	return 0, fmt.Errorf("unknown records type: %v", r.recordsType)
+}
+
+func (r *Records) isPartial() (bool, error) {
+	if r.recordsType == unknownRecords {
+		if empty, err := r.setTypeFromFields(); err != nil || empty {
+			return false, err
+		}
+	}
+
+	switch r.recordsType {
+	case unknownRecords:
+		return false, nil
+	case legacyRecords:
+		if r.msgSet == nil {
+			return false, nil
+		}
+		return r.msgSet.PartialTrailingMessage, nil
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return false, nil
+		}
+		return r.recordBatch.PartialTrailingRecord, nil
+	}
+	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
+}
+
+func (r *Records) isControl() (bool, error) {
+	if r.recordsType == unknownRecords {
+		if empty, err := r.setTypeFromFields(); err != nil || empty {
+			return false, err
+		}
+	}
+
+	switch r.recordsType {
+	case legacyRecords:
+		return false, nil
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return false, nil
+		}
+		return r.recordBatch.Control, nil
+	}
+	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
 }
 }
 
 
 func magicValue(pd packetDecoder) (int8, error) {
 func magicValue(pd packetDecoder) (int8, error) {

+ 31 - 7
records_test.go

@@ -16,7 +16,7 @@ func TestLegacyRecords(t *testing.T) {
 			},
 			},
 		},
 		},
 	}
 	}
-	r := Records{msgSet: set}
+	r := newLegacyRecords(set)
 
 
 	exp, err := encode(set, nil)
 	exp, err := encode(set, nil)
 	if err != nil {
 	if err != nil {
@@ -42,19 +42,36 @@ func TestLegacyRecords(t *testing.T) {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
+	if r.recordsType != legacyRecords {
+		t.Fatalf("Wrong records type %v, expected %v", r.recordsType, legacyRecords)
+	}
 	if !reflect.DeepEqual(set, r.msgSet) {
 	if !reflect.DeepEqual(set, r.msgSet) {
 		t.Errorf("Wrong decoding for legacy records, wanted %#+v, got %#+v", set, r.msgSet)
 		t.Errorf("Wrong decoding for legacy records, wanted %#+v, got %#+v", set, r.msgSet)
 	}
 	}
 
 
-	n := r.numRecords()
+	n, err := r.numRecords()
+	if err != nil {
+		t.Fatal(err)
+	}
 	if n != 1 {
 	if n != 1 {
 		t.Errorf("Wrong number of records, wanted 1, got %d", n)
 		t.Errorf("Wrong number of records, wanted 1, got %d", n)
 	}
 	}
 
 
-	p := r.isPartial()
+	p, err := r.isPartial()
+	if err != nil {
+		t.Fatal(err)
+	}
 	if p {
 	if p {
 		t.Errorf("MessageSet shouldn't have a partial trailing message")
 		t.Errorf("MessageSet shouldn't have a partial trailing message")
 	}
 	}
+
+	c, err := r.isControl()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if c {
+		t.Errorf("MessageSet can't be a control batch")
+	}
 }
 }
 
 
 func TestDefaultRecords(t *testing.T) {
 func TestDefaultRecords(t *testing.T) {
@@ -67,7 +84,7 @@ func TestDefaultRecords(t *testing.T) {
 		},
 		},
 	}
 	}
 
 
-	r := Records{recordBatch: batch}
+	r := newDefaultRecords(batch)
 
 
 	exp, err := encode(batch, nil)
 	exp, err := encode(batch, nil)
 	if err != nil {
 	if err != nil {
@@ -93,11 +110,14 @@ func TestDefaultRecords(t *testing.T) {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
+	if r.recordsType != defaultRecords {
+		t.Fatalf("Wrong records type %v, expected %v", r.recordsType, defaultRecords)
+	}
 	if !reflect.DeepEqual(batch, r.recordBatch) {
 	if !reflect.DeepEqual(batch, r.recordBatch) {
 		t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatch)
 		t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatch)
 	}
 	}
 
 
-	n := r.numRecords()
+	n, err := r.numRecords()
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -105,7 +125,7 @@ func TestDefaultRecords(t *testing.T) {
 		t.Errorf("Wrong number of records, wanted 1, got %d", n)
 		t.Errorf("Wrong number of records, wanted 1, got %d", n)
 	}
 	}
 
 
-	p := r.isPartial()
+	p, err := r.isPartial()
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -113,7 +133,11 @@ func TestDefaultRecords(t *testing.T) {
 		t.Errorf("RecordBatch shouldn't have a partial trailing record")
 		t.Errorf("RecordBatch shouldn't have a partial trailing record")
 	}
 	}
 
 
-	if r.recordBatch.Control {
+	c, err := r.isControl()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if c {
 		t.Errorf("RecordBatch shouldn't be a control batch")
 		t.Errorf("RecordBatch shouldn't be a control batch")
 	}
 	}
 }
 }