Explorar el Código

Support multiple mixed Record types

Ivan Babrou hace 8 años
padre
commit
d7d2bb7b2b
Se han modificado 10 ficheros con 149 adiciones y 257 borrados
  1. 20 21
      consumer.go
  2. 56 16
      fetch_response.go
  3. 9 27
      fetch_response_test.go
  4. 9 0
      message_set.go
  5. 4 4
      produce_request.go
  6. 6 6
      produce_set.go
  7. 1 1
      produce_set_test.go
  8. 0 41
      record_batch.go
  9. 35 120
      records.go
  10. 9 21
      records_test.go

+ 20 - 21
consumer.go

@@ -570,18 +570,10 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
 		return nil, block.Err
 	}
 
-	nRecs, err := block.Records.numRecords()
-	if err != nil {
-		return nil, err
-	}
-	if nRecs == 0 {
-		partialTrailingMessage, err := block.Records.isPartial()
-		if err != nil {
-			return nil, err
-		}
+	if block.numRecords() == 0 {
 		// We got no messages. If we got a trailing one then we need to ask for more data.
 		// Otherwise we just poll again and wait for one to be produced...
-		if partialTrailingMessage {
+		if block.isPartial() {
 			if child.conf.Consumer.Fetch.Max > 0 && child.fetchSize == child.conf.Consumer.Fetch.Max {
 				// we can't ask for more data, we've hit the configured limit
 				child.sendError(ErrMessageTooLarge)
@@ -601,21 +593,28 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
 	child.fetchSize = child.conf.Consumer.Fetch.Default
 	atomic.StoreInt64(&child.highWaterMarkOffset, block.HighWaterMarkOffset)
 
-	if block.Records.recordsType == legacyRecords {
-		return child.parseMessages(block.Records.msgSet)
-	}
-
 	messages := []*ConsumerMessage{}
-	for _, recordBatch := range block.Records.recordBatchSet.batches {
-		if recordBatch.Control {
-			continue
+	for _, chunk := range block.RecordsSet {
+		if chunk.msgSet != nil {
+			messageSetMessages, err := child.parseMessages(chunk.msgSet)
+			if err != nil {
+				return nil, err
+			}
+
+			messages = append(messages, messageSetMessages...)
 		}
 
-		recordBatchMessages, err := child.parseRecords(recordBatch)
-		messages = append(messages, recordBatchMessages...)
+		if chunk.recordBatch != nil {
+			if chunk.recordBatch.Control {
+				continue
+			}
 
-		if err != nil {
-			return messages, err
+			recordBatchMessages, err := child.parseRecords(chunk.recordBatch)
+			if err != nil {
+				return nil, err
+			}
+
+			messages = append(messages, recordBatchMessages...)
 		}
 	}
 

+ 56 - 16
fetch_response.go

@@ -33,7 +33,8 @@ type FetchResponseBlock struct {
 	HighWaterMarkOffset int64
 	LastStableOffset    int64
 	AbortedTransactions []*AbortedTransaction
-	Records             Records
+	RecordsSet          []*Records
+	Partial             bool
 }
 
 func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error) {
@@ -81,15 +82,51 @@ func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error)
 	if err != nil {
 		return err
 	}
-	if recordsSize > 0 {
-		if err = b.Records.decode(recordsDecoder); err != nil {
+
+	b.RecordsSet = []*Records{}
+
+	for {
+		if recordsDecoder.remaining() == 0 {
+			break
+		}
+
+		chunk := &Records{}
+		if err := chunk.decode(recordsDecoder); err != nil {
+			// If we have at least one decoded record chunk, this is not an error
+			if err == ErrInsufficientData {
+				if len(b.RecordsSet) == 0 {
+					b.Partial = true
+				}
+				break
+			}
 			return err
 		}
+
+		// If we have at least one full record chunk, we skip incomplete ones
+		if chunk.isPartial() && len(b.RecordsSet) > 0 {
+			break
+		}
+
+		b.RecordsSet = append(b.RecordsSet, chunk)
 	}
 
 	return nil
 }
 
+func (b *FetchResponseBlock) numRecords() int {
+	s := 0
+
+	for _, chunk := range b.RecordsSet {
+		s += chunk.numRecords()
+	}
+
+	return s
+}
+
+func (b *FetchResponseBlock) isPartial() bool {
+	return b.Partial || len(b.RecordsSet) == 1 && b.RecordsSet[0].isPartial()
+}
+
 func (b *FetchResponseBlock) encode(pe packetEncoder, version int16) (err error) {
 	pe.putInt16(int16(b.Err))
 
@@ -109,9 +146,11 @@ func (b *FetchResponseBlock) encode(pe packetEncoder, version int16) (err error)
 	}
 
 	pe.push(&lengthField{})
-	err = b.Records.encode(pe)
-	if err != nil {
-		return err
+	for _, chunk := range b.RecordsSet {
+		err = chunk.encode(pe)
+		if err != nil {
+			return err
+		}
 	}
 	return pe.pop()
 }
@@ -291,11 +330,10 @@ func (r *FetchResponse) AddMessage(topic string, partition int32, key, value Enc
 	kb, vb := encodeKV(key, value)
 	msg := &Message{Key: kb, Value: vb}
 	msgBlock := &MessageBlock{Msg: msg, Offset: offset}
-	set := frb.Records.msgSet
-	if set == nil {
-		set = &MessageSet{}
-		frb.Records = newLegacyRecords(set)
+	if len(frb.RecordsSet) == 0 {
+		frb.RecordsSet = []*Records{&Records{msgSet: &MessageSet{}}}
 	}
+	set := frb.RecordsSet[0].msgSet
 	set.Messages = append(set.Messages, msgBlock)
 }
 
@@ -303,18 +341,20 @@ func (r *FetchResponse) AddRecord(topic string, partition int32, key, value Enco
 	frb := r.getOrCreateBlock(topic, partition)
 	kb, vb := encodeKV(key, value)
 	rec := &Record{Key: kb, Value: vb, OffsetDelta: offset}
-	if frb.Records.recordBatchSet == nil {
-		frb.Records = newDefaultRecords([]*RecordBatch{&RecordBatch{Version: 2}})
+	if len(frb.RecordsSet) == 0 {
+		frb.RecordsSet = []*Records{&Records{recordBatch: &RecordBatch{Version: 2}}}
 	}
-	frb.Records.recordBatchSet.batches[0].addRecord(rec)
+	batch := frb.RecordsSet[0].recordBatch
+	batch.addRecord(rec)
 }
 
 func (r *FetchResponse) SetLastOffsetDelta(topic string, partition int32, offset int32) {
 	frb := r.getOrCreateBlock(topic, partition)
-	if frb.Records.recordBatchSet == nil {
-		frb.Records = newDefaultRecords([]*RecordBatch{&RecordBatch{Version: 2}})
+	if len(frb.RecordsSet) == 0 {
+		frb.RecordsSet = []*Records{&Records{recordBatch: &RecordBatch{Version: 2}}}
 	}
-	frb.Records.recordBatchSet.batches[0].LastOffsetDelta = offset
+	batch := frb.RecordsSet[0].recordBatch
+	batch.LastOffsetDelta = offset
 }
 
 func (r *FetchResponse) SetLastStableOffset(topic string, partition int32, offset int64) {

+ 9 - 27
fetch_response_test.go

@@ -117,22 +117,16 @@ func TestOneMessageFetchResponse(t *testing.T) {
 	if block.HighWaterMarkOffset != 0x10101010 {
 		t.Error("Decoding didn't produce correct high water mark offset.")
 	}
-	partial, err := block.Records.isPartial()
-	if err != nil {
-		t.Fatalf("Unexpected error: %v", err)
-	}
+	partial := block.RecordsSet[0].isPartial()
 	if partial {
 		t.Error("Decoding detected a partial trailing message where there wasn't one.")
 	}
 
-	n, err := block.Records.numRecords()
-	if err != nil {
-		t.Fatalf("Unexpected error: %v", err)
-	}
+	n := block.RecordsSet[0].numRecords()
 	if n != 1 {
 		t.Fatal("Decoding produced incorrect number of messages.")
 	}
-	msgBlock := block.Records.msgSet.Messages[0]
+	msgBlock := block.RecordsSet[0].msgSet.Messages[0]
 	if msgBlock.Offset != 0x550000 {
 		t.Error("Decoding produced incorrect message offset.")
 	}
@@ -170,22 +164,16 @@ func TestOneRecordFetchResponse(t *testing.T) {
 	if block.HighWaterMarkOffset != 0x10101010 {
 		t.Error("Decoding didn't produce correct high water mark offset.")
 	}
-	partial, err := block.Records.isPartial()
-	if err != nil {
-		t.Fatalf("Unexpected error: %v", err)
-	}
+	partial := block.RecordsSet[0].isPartial()
 	if partial {
 		t.Error("Decoding detected a partial trailing record where there wasn't one.")
 	}
 
-	n, err := block.Records.numRecords()
-	if err != nil {
-		t.Fatalf("Unexpected error: %v", err)
-	}
+	n := block.RecordsSet[0].numRecords()
 	if n != 1 {
 		t.Fatal("Decoding produced incorrect number of records.")
 	}
-	rec := block.Records.recordBatchSet.batches[0].Records[0]
+	rec := block.RecordsSet[0].recordBatch.Records[0]
 	if !bytes.Equal(rec.Key, []byte{0x01, 0x02, 0x03, 0x04}) {
 		t.Error("Decoding produced incorrect record key.")
 	}
@@ -216,22 +204,16 @@ func TestOneMessageFetchResponseV4(t *testing.T) {
 	if block.HighWaterMarkOffset != 0x10101010 {
 		t.Error("Decoding didn't produce correct high water mark offset.")
 	}
-	partial, err := block.Records.isPartial()
-	if err != nil {
-		t.Fatalf("Unexpected error: %v", err)
-	}
+	partial := block.RecordsSet[0].isPartial()
 	if partial {
 		t.Error("Decoding detected a partial trailing record where there wasn't one.")
 	}
 
-	n, err := block.Records.numRecords()
-	if err != nil {
-		t.Fatalf("Unexpected error: %v", err)
-	}
+	n := block.RecordsSet[0].numRecords()
 	if n != 1 {
 		t.Fatal("Decoding produced incorrect number of records.")
 	}
-	msgBlock := block.Records.msgSet.Messages[0]
+	msgBlock := block.RecordsSet[0].msgSet.Messages[0]
 	if msgBlock.Offset != 0x550000 {
 		t.Error("Decoding produced incorrect message offset.")
 	}

+ 9 - 0
message_set.go

@@ -64,6 +64,15 @@ func (ms *MessageSet) decode(pd packetDecoder) (err error) {
 	ms.Messages = nil
 
 	for pd.remaining() > 0 {
+		magic, err := magicValue(pd)
+		if err != nil {
+			return err
+		}
+
+		if magic > 1 {
+			return nil
+		}
+
 		msb := new(MessageBlock)
 		err = msb.decode(pd)
 		switch err {

+ 4 - 4
produce_request.go

@@ -113,7 +113,7 @@ func (r *ProduceRequest) encode(pe packetEncoder) error {
 			}
 			if metricRegistry != nil {
 				if r.Version >= 3 {
-					topicRecordCount += updateBatchMetrics(records.recordBatchSet.batches[0], compressionRatioMetric, topicCompressionRatioMetric)
+					topicRecordCount += updateBatchMetrics(records.recordBatch, compressionRatioMetric, topicCompressionRatioMetric)
 				} else {
 					topicRecordCount += updateMsgSetMetrics(records.msgSet, compressionRatioMetric, topicCompressionRatioMetric)
 				}
@@ -235,7 +235,7 @@ func (r *ProduceRequest) AddMessage(topic string, partition int32, msg *Message)
 
 	if set == nil {
 		set = new(MessageSet)
-		r.records[topic][partition] = newLegacyRecords(set)
+		r.records[topic][partition] = Records{msgSet: set}
 	}
 
 	set.addMessage(msg)
@@ -243,10 +243,10 @@ func (r *ProduceRequest) AddMessage(topic string, partition int32, msg *Message)
 
 func (r *ProduceRequest) AddSet(topic string, partition int32, set *MessageSet) {
 	r.ensureRecords(topic, partition)
-	r.records[topic][partition] = newLegacyRecords(set)
+	r.records[topic][partition] = Records{msgSet: set}
 }
 
 func (r *ProduceRequest) AddBatch(topic string, partition int32, batch *RecordBatch) {
 	r.ensureRecords(topic, partition)
-	r.records[topic][partition] = newDefaultRecords([]*RecordBatch{batch})
+	r.records[topic][partition] = Records{recordBatch: batch}
 }

+ 6 - 6
produce_set.go

@@ -64,10 +64,10 @@ func (ps *produceSet) add(msg *ProducerMessage) error {
 				ProducerID:     -1, /* No producer id */
 				Codec:          ps.parent.conf.Producer.Compression,
 			}
-			set = &partitionSet{recordsToSend: newDefaultRecords([]*RecordBatch{batch})}
+			set = &partitionSet{recordsToSend: Records{recordBatch: batch}}
 			size = recordBatchOverhead
 		} else {
-			set = &partitionSet{recordsToSend: newLegacyRecords(new(MessageSet))}
+			set = &partitionSet{recordsToSend: Records{msgSet: &MessageSet{}}}
 		}
 		partitions[msg.Partition] = set
 	}
@@ -79,7 +79,7 @@ func (ps *produceSet) add(msg *ProducerMessage) error {
 		rec := &Record{
 			Key:            key,
 			Value:          val,
-			TimestampDelta: timestamp.Sub(set.recordsToSend.recordBatchSet.batches[0].FirstTimestamp),
+			TimestampDelta: timestamp.Sub(set.recordsToSend.recordBatch.FirstTimestamp),
 		}
 		size += len(key) + len(val)
 		if len(msg.Headers) > 0 {
@@ -89,7 +89,7 @@ func (ps *produceSet) add(msg *ProducerMessage) error {
 				size += len(rec.Headers[i].Key) + len(rec.Headers[i].Value) + 2*binary.MaxVarintLen32
 			}
 		}
-		set.recordsToSend.recordBatchSet.batches[0].addRecord(rec)
+		set.recordsToSend.recordBatch.addRecord(rec)
 	} else {
 		msgToSend := &Message{Codec: CompressionNone, Key: key, Value: val}
 		if ps.parent.conf.Version.IsAtLeast(V0_10_0_0) {
@@ -122,11 +122,11 @@ func (ps *produceSet) buildRequest() *ProduceRequest {
 	for topic, partitionSet := range ps.msgs {
 		for partition, set := range partitionSet {
 			if req.Version >= 3 {
-				for i, record := range set.recordsToSend.recordBatchSet.batches[0].Records {
+				for i, record := range set.recordsToSend.recordBatch.Records {
 					record.OffsetDelta = int64(i)
 				}
 
-				req.AddBatch(topic, partition, set.recordsToSend.recordBatchSet.batches[0])
+				req.AddBatch(topic, partition, set.recordsToSend.recordBatch)
 				continue
 			}
 			if ps.parent.conf.Producer.Compression == CompressionNone {

+ 1 - 1
produce_set_test.go

@@ -227,7 +227,7 @@ func TestProduceSetV3RequestBuilding(t *testing.T) {
 		t.Error("Wrong request version")
 	}
 
-	batch := req.records["t1"][0].recordBatchSet.batches[0]
+	batch := req.records["t1"][0].recordBatch
 	if batch.FirstTimestamp != now {
 		t.Errorf("Wrong first timestamp: %v", batch.FirstTimestamp)
 	}

+ 0 - 41
record_batch.go

@@ -35,47 +35,6 @@ func (e recordsArray) decode(pd packetDecoder) error {
 	return nil
 }
 
-type RecordBatchSet struct {
-	batches []*RecordBatch
-}
-
-func (rbs *RecordBatchSet) encode(pe packetEncoder) error {
-	for _, rb := range rbs.batches {
-		if err := rb.encode(pe); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-func (rbs *RecordBatchSet) decode(pd packetDecoder) error {
-	rbs.batches = []*RecordBatch{}
-
-	for {
-		if pd.remaining() == 0 {
-			break
-		}
-
-		rb := &RecordBatch{}
-		if err := rb.decode(pd); err != nil {
-			// If we have at least one decoded record batch, this is not an error
-			if err == ErrInsufficientData && len(rbs.batches) > 0 {
-				return nil
-			}
-			return err
-		}
-
-		// If we have at least one full record batch, we skip incomplete ones
-		if rb.PartialTrailingRecord && len(rbs.batches) > 0 {
-			return nil
-		}
-
-		rbs.batches = append(rbs.batches, rb)
-	}
-
-	return nil
-}
-
 type RecordBatch struct {
 	FirstOffset           int64
 	PartitionLeaderEpoch  int32

+ 35 - 120
records.go

@@ -1,157 +1,72 @@
 package sarama
 
-import (
-	"fmt"
-)
-
 const (
-	unknownRecords = iota
-	legacyRecords
-	defaultRecords
-
 	magicOffset = 16
 	magicLength = 1
 )
 
 // Records implements a union type containing either a RecordBatch or a legacy MessageSet.
 type Records struct {
-	recordsType    int
-	msgSet         *MessageSet
-	recordBatchSet *RecordBatchSet
-}
-
-func newLegacyRecords(msgSet *MessageSet) Records {
-	return Records{recordsType: legacyRecords, msgSet: msgSet}
-}
-
-func newDefaultRecords(batches []*RecordBatch) Records {
-	return Records{recordsType: defaultRecords, recordBatchSet: &RecordBatchSet{batches}}
+	msgSet      *MessageSet
+	recordBatch *RecordBatch
 }
 
-// setTypeFromFields sets type of Records depending on which of msgSet or recordBatch is not nil.
-// The first return value indicates whether both fields are nil (and the type is not set).
-// If both fields are not nil, it returns an error.
-func (r *Records) setTypeFromFields() (bool, error) {
-	if r.msgSet == nil && r.recordBatchSet == nil {
-		return true, nil
+func (c *Records) numRecords() int {
+	if c.msgSet != nil {
+		return len(c.msgSet.Messages)
 	}
-	if r.msgSet != nil && r.recordBatchSet != nil {
-		return false, fmt.Errorf("both msgSet and recordBatchSet are set, but record type is unknown")
-	}
-	r.recordsType = defaultRecords
-	if r.msgSet != nil {
-		r.recordsType = legacyRecords
+
+	if c.recordBatch != nil {
+		return len(c.recordBatch.Records)
 	}
-	return false, nil
+
+	return 0
 }
 
-func (r *Records) encode(pe packetEncoder) error {
-	if r.recordsType == unknownRecords {
-		if empty, err := r.setTypeFromFields(); err != nil || empty {
-			return err
-		}
+func (c *Records) isPartial() bool {
+	if c.msgSet != nil {
+		return c.msgSet.PartialTrailingMessage
 	}
 
-	switch r.recordsType {
-	case legacyRecords:
-		if r.msgSet == nil {
-			return nil
-		}
-		return r.msgSet.encode(pe)
-	case defaultRecords:
-		if r.recordBatchSet == nil {
-			return nil
-		}
-		return r.recordBatchSet.encode(pe)
+	if c.recordBatch != nil {
+		return c.recordBatch.PartialTrailingRecord
 	}
-	return fmt.Errorf("unknown records type: %v", r.recordsType)
-}
 
-func (r *Records) setTypeFromMagic(pd packetDecoder) error {
-	dec, err := pd.peek(magicOffset, magicLength)
-	if err != nil {
-		return err
-	}
+	return false
+}
 
-	magic, err := dec.getInt8()
+func (c *Records) decode(pd packetDecoder) (err error) {
+	magic, err := magicValue(pd)
 	if err != nil {
 		return err
 	}
 
-	r.recordsType = defaultRecords
 	if magic < 2 {
-		r.recordsType = legacyRecords
-	}
-	return nil
-}
-
-func (r *Records) decode(pd packetDecoder) error {
-	if r.recordsType == unknownRecords {
-		if err := r.setTypeFromMagic(pd); err != nil {
-			return nil
-		}
+		c.msgSet = &MessageSet{}
+		return c.msgSet.decode(pd)
 	}
 
-	switch r.recordsType {
-	case legacyRecords:
-		r.msgSet = &MessageSet{}
-		return r.msgSet.decode(pd)
-	case defaultRecords:
-		r.recordBatchSet = &RecordBatchSet{batches: []*RecordBatch{}}
-		return r.recordBatchSet.decode(pd)
-	}
-	return fmt.Errorf("unknown records type: %v", r.recordsType)
+	c.recordBatch = &RecordBatch{}
+	return c.recordBatch.decode(pd)
 }
 
-func (r *Records) numRecords() (int, error) {
-	if r.recordsType == unknownRecords {
-		if empty, err := r.setTypeFromFields(); err != nil || empty {
-			return 0, err
-		}
+func (c *Records) encode(pe packetEncoder) (err error) {
+	if c.msgSet != nil {
+		return c.msgSet.encode(pe)
 	}
 
-	switch r.recordsType {
-	case legacyRecords:
-		if r.msgSet == nil {
-			return 0, nil
-		}
-		return len(r.msgSet.Messages), nil
-	case defaultRecords:
-		if r.recordBatchSet == nil {
-			return 0, nil
-		}
-		s := 0
-		for i := range r.recordBatchSet.batches {
-			s += len(r.recordBatchSet.batches[i].Records)
-		}
-		return s, nil
+	if c.recordBatch != nil {
+		return c.recordBatch.encode(pe)
 	}
-	return 0, fmt.Errorf("unknown records type: %v", r.recordsType)
+
+	return nil
 }
 
-func (r *Records) isPartial() (bool, error) {
-	if r.recordsType == unknownRecords {
-		if empty, err := r.setTypeFromFields(); err != nil || empty {
-			return false, err
-		}
+func magicValue(pd packetDecoder) (int8, error) {
+	dec, err := pd.peek(magicOffset, magicLength)
+	if err != nil {
+		return 0, err
 	}
 
-	switch r.recordsType {
-	case unknownRecords:
-		return false, nil
-	case legacyRecords:
-		if r.msgSet == nil {
-			return false, nil
-		}
-		return r.msgSet.PartialTrailingMessage, nil
-	case defaultRecords:
-		if r.recordBatchSet == nil {
-			return false, nil
-		}
-		if len(r.recordBatchSet.batches) == 1 {
-			return r.recordBatchSet.batches[0].PartialTrailingRecord, nil
-		}
-		return false, nil
-	}
-	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
+	return dec.getInt8()
 }

+ 9 - 21
records_test.go

@@ -16,7 +16,7 @@ func TestLegacyRecords(t *testing.T) {
 			},
 		},
 	}
-	r := newLegacyRecords(set)
+	r := Records{msgSet: set}
 
 	exp, err := encode(set, nil)
 	if err != nil {
@@ -42,25 +42,16 @@ func TestLegacyRecords(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	if r.recordsType != legacyRecords {
-		t.Fatalf("Wrong records type %v, expected %v", r.recordsType, legacyRecords)
-	}
 	if !reflect.DeepEqual(set, r.msgSet) {
 		t.Errorf("Wrong decoding for legacy records, wanted %#+v, got %#+v", set, r.msgSet)
 	}
 
-	n, err := r.numRecords()
-	if err != nil {
-		t.Fatal(err)
-	}
+	n := r.numRecords()
 	if n != 1 {
 		t.Errorf("Wrong number of records, wanted 1, got %d", n)
 	}
 
-	p, err := r.isPartial()
-	if err != nil {
-		t.Fatal(err)
-	}
+	p := r.isPartial()
 	if p {
 		t.Errorf("MessageSet shouldn't have a partial trailing message")
 	}
@@ -76,7 +67,7 @@ func TestDefaultRecords(t *testing.T) {
 		},
 	}
 
-	r := newDefaultRecords([]*RecordBatch{batch})
+	r := Records{recordBatch: batch}
 
 	exp, err := encode(batch, nil)
 	if err != nil {
@@ -102,14 +93,11 @@ func TestDefaultRecords(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	if r.recordsType != defaultRecords {
-		t.Fatalf("Wrong records type %v, expected %v", r.recordsType, defaultRecords)
-	}
-	if !reflect.DeepEqual(batch, r.recordBatchSet.batches[0]) {
-		t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatchSet.batches[0])
+	if !reflect.DeepEqual(batch, r.recordBatch) {
+		t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatch)
 	}
 
-	n, err := r.numRecords()
+	n := r.numRecords()
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -117,7 +105,7 @@ func TestDefaultRecords(t *testing.T) {
 		t.Errorf("Wrong number of records, wanted 1, got %d", n)
 	}
 
-	p, err := r.isPartial()
+	p := r.isPartial()
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -125,7 +113,7 @@ func TestDefaultRecords(t *testing.T) {
 		t.Errorf("RecordBatch shouldn't have a partial trailing record")
 	}
 
-	if r.recordBatchSet.batches[0].Control {
+	if r.recordBatch.Control {
 		t.Errorf("RecordBatch shouldn't be a control batch")
 	}
 }