Browse Source

Merge pull request #1023 from bobrik/multiple-record-batches

Support multiple record batches, closes #1022
Evan Huus 8 years ago
parent
commit
0f4f8caef9
5 changed files with 137 additions and 43 deletions
  1. 27 8
      consumer.go
  2. 79 19
      fetch_response.go
  3. 9 9
      fetch_response_test.go
  4. 9 0
      message_set.go
  5. 13 7
      records.go

+ 27 - 8
consumer.go

@@ -570,12 +570,12 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
 		return nil, block.Err
 	}
 
-	nRecs, err := block.Records.numRecords()
+	nRecs, err := block.numRecords()
 	if err != nil {
 		return nil, err
 	}
 	if nRecs == 0 {
-		partialTrailingMessage, err := block.Records.isPartial()
+		partialTrailingMessage, err := block.isPartial()
 		if err != nil {
 			return nil, err
 		}
@@ -601,14 +601,33 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
 	child.fetchSize = child.conf.Consumer.Fetch.Default
 	atomic.StoreInt64(&child.highWaterMarkOffset, block.HighWaterMarkOffset)
 
-	if control, err := block.Records.isControl(); err != nil || control {
-		return nil, err
-	}
+	messages := []*ConsumerMessage{}
+	for _, records := range block.RecordsSet {
+		if control, err := records.isControl(); err != nil || control {
+			continue
+		}
+
+		switch records.recordsType {
+		case legacyRecords:
+			messageSetMessages, err := child.parseMessages(records.msgSet)
+			if err != nil {
+				return nil, err
+			}
 
-	if block.Records.recordsType == legacyRecords {
-		return child.parseMessages(block.Records.msgSet)
+			messages = append(messages, messageSetMessages...)
+		case defaultRecords:
+			recordBatchMessages, err := child.parseRecords(records.recordBatch)
+			if err != nil {
+				return nil, err
+			}
+
+			messages = append(messages, recordBatchMessages...)
+		default:
+			return nil, fmt.Errorf("unknown records type: %v", records.recordsType)
+		}
 	}
-	return child.parseRecords(block.Records.recordBatch)
+
+	return messages, nil
 }
 
 // brokerConsumer

+ 79 - 19
fetch_response.go

@@ -1,6 +1,8 @@
 package sarama
 
-import "time"
+import (
+	"time"
+)
 
 type AbortedTransaction struct {
 	ProducerID  int64
@@ -31,7 +33,9 @@ type FetchResponseBlock struct {
 	HighWaterMarkOffset int64
 	LastStableOffset    int64
 	AbortedTransactions []*AbortedTransaction
-	Records             Records
+	Records             *Records // deprecated: use FetchResponseBlock.RecordsSet
+	RecordsSet          []*Records
+	Partial             bool
 }
 
 func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error) {
@@ -79,15 +83,69 @@ func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error)
 	if err != nil {
 		return err
 	}
-	if recordsSize > 0 {
-		if err = b.Records.decode(recordsDecoder); err != nil {
+
+	b.RecordsSet = []*Records{}
+
+	for recordsDecoder.remaining() > 0 {
+		records := &Records{}
+		if err := records.decode(recordsDecoder); err != nil {
+			// Running out of data is not an error: with nothing decoded yet the
+			// block is flagged as partial, otherwise the truncated tail is dropped.
+			if err == ErrInsufficientData {
+				if len(b.RecordsSet) == 0 {
+					b.Partial = true
+				}
+				break
+			}
+			return err
+		}
+
+		partial, err := records.isPartial()
+		if err != nil {
 			return err
 		}
+
+		// Once at least one complete Records has been decoded, skip an incomplete trailing one
+		if partial && len(b.RecordsSet) > 0 {
+			break
+		}
+
+		b.RecordsSet = append(b.RecordsSet, records)
+
+		if b.Records == nil {
+			b.Records = records
+		}
 	}
 
 	return nil
 }
 
+func (b *FetchResponseBlock) numRecords() (int, error) {
+	sum := 0
+
+	for _, records := range b.RecordsSet {
+		count, err := records.numRecords()
+		if err != nil {
+			return 0, err
+		}
+
+		sum += count
+	}
+
+	return sum, nil
+}
+
+func (b *FetchResponseBlock) isPartial() (bool, error) {
+	if b.Partial {
+		return true, nil
+	}
+
+	if len(b.RecordsSet) == 1 {
+		return b.RecordsSet[0].isPartial()
+	}
+
+	return false, nil
+}
+
 func (b *FetchResponseBlock) encode(pe packetEncoder, version int16) (err error) {
 	pe.putInt16(int16(b.Err))
 
@@ -107,9 +165,11 @@ func (b *FetchResponseBlock) encode(pe packetEncoder, version int16) (err error)
 	}
 
 	pe.push(&lengthField{})
-	err = b.Records.encode(pe)
-	if err != nil {
-		return err
+	for _, records := range b.RecordsSet {
+		err = records.encode(pe)
+		if err != nil {
+			return err
+		}
 	}
 	return pe.pop()
 }
@@ -289,11 +349,11 @@ func (r *FetchResponse) AddMessage(topic string, partition int32, key, value Enc
 	kb, vb := encodeKV(key, value)
 	msg := &Message{Key: kb, Value: vb}
 	msgBlock := &MessageBlock{Msg: msg, Offset: offset}
-	set := frb.Records.msgSet
-	if set == nil {
-		set = &MessageSet{}
-		frb.Records = newLegacyRecords(set)
+	if len(frb.RecordsSet) == 0 {
+		records := newLegacyRecords(&MessageSet{})
+		frb.RecordsSet = []*Records{&records}
 	}
+	set := frb.RecordsSet[0].msgSet
 	set.Messages = append(set.Messages, msgBlock)
 }
 
@@ -301,21 +361,21 @@ func (r *FetchResponse) AddRecord(topic string, partition int32, key, value Enco
 	frb := r.getOrCreateBlock(topic, partition)
 	kb, vb := encodeKV(key, value)
 	rec := &Record{Key: kb, Value: vb, OffsetDelta: offset}
-	batch := frb.Records.recordBatch
-	if batch == nil {
-		batch = &RecordBatch{Version: 2}
-		frb.Records = newDefaultRecords(batch)
+	if len(frb.RecordsSet) == 0 {
+		records := newDefaultRecords(&RecordBatch{Version: 2})
+		frb.RecordsSet = []*Records{&records}
 	}
+	batch := frb.RecordsSet[0].recordBatch
 	batch.addRecord(rec)
 }
 
 func (r *FetchResponse) SetLastOffsetDelta(topic string, partition int32, offset int32) {
 	frb := r.getOrCreateBlock(topic, partition)
-	batch := frb.Records.recordBatch
-	if batch == nil {
-		batch = &RecordBatch{Version: 2}
-		frb.Records = newDefaultRecords(batch)
+	if len(frb.RecordsSet) == 0 {
+		records := newDefaultRecords(&RecordBatch{Version: 2})
+		frb.RecordsSet = []*Records{&records}
 	}
+	batch := frb.RecordsSet[0].recordBatch
 	batch.LastOffsetDelta = offset
 }
 

+ 9 - 9
fetch_response_test.go

@@ -117,7 +117,7 @@ func TestOneMessageFetchResponse(t *testing.T) {
 	if block.HighWaterMarkOffset != 0x10101010 {
 		t.Error("Decoding didn't produce correct high water mark offset.")
 	}
-	partial, err := block.Records.isPartial()
+	partial, err := block.isPartial()
 	if err != nil {
 		t.Fatalf("Unexpected error: %v", err)
 	}
@@ -125,14 +125,14 @@ func TestOneMessageFetchResponse(t *testing.T) {
 		t.Error("Decoding detected a partial trailing message where there wasn't one.")
 	}
 
-	n, err := block.Records.numRecords()
+	n, err := block.numRecords()
 	if err != nil {
 		t.Fatalf("Unexpected error: %v", err)
 	}
 	if n != 1 {
 		t.Fatal("Decoding produced incorrect number of messages.")
 	}
-	msgBlock := block.Records.msgSet.Messages[0]
+	msgBlock := block.RecordsSet[0].msgSet.Messages[0]
 	if msgBlock.Offset != 0x550000 {
 		t.Error("Decoding produced incorrect message offset.")
 	}
@@ -170,7 +170,7 @@ func TestOneRecordFetchResponse(t *testing.T) {
 	if block.HighWaterMarkOffset != 0x10101010 {
 		t.Error("Decoding didn't produce correct high water mark offset.")
 	}
-	partial, err := block.Records.isPartial()
+	partial, err := block.isPartial()
 	if err != nil {
 		t.Fatalf("Unexpected error: %v", err)
 	}
@@ -178,14 +178,14 @@ func TestOneRecordFetchResponse(t *testing.T) {
 		t.Error("Decoding detected a partial trailing record where there wasn't one.")
 	}
 
-	n, err := block.Records.numRecords()
+	n, err := block.numRecords()
 	if err != nil {
 		t.Fatalf("Unexpected error: %v", err)
 	}
 	if n != 1 {
 		t.Fatal("Decoding produced incorrect number of records.")
 	}
-	rec := block.Records.recordBatch.Records[0]
+	rec := block.RecordsSet[0].recordBatch.Records[0]
 	if !bytes.Equal(rec.Key, []byte{0x01, 0x02, 0x03, 0x04}) {
 		t.Error("Decoding produced incorrect record key.")
 	}
@@ -216,7 +216,7 @@ func TestOneMessageFetchResponseV4(t *testing.T) {
 	if block.HighWaterMarkOffset != 0x10101010 {
 		t.Error("Decoding didn't produce correct high water mark offset.")
 	}
-	partial, err := block.Records.isPartial()
+	partial, err := block.isPartial()
 	if err != nil {
 		t.Fatalf("Unexpected error: %v", err)
 	}
@@ -224,14 +224,14 @@ func TestOneMessageFetchResponseV4(t *testing.T) {
 		t.Error("Decoding detected a partial trailing record where there wasn't one.")
 	}
 
-	n, err := block.Records.numRecords()
+	n, err := block.numRecords()
 	if err != nil {
 		t.Fatalf("Unexpected error: %v", err)
 	}
 	if n != 1 {
 		t.Fatal("Decoding produced incorrect number of records.")
 	}
-	msgBlock := block.Records.msgSet.Messages[0]
+	msgBlock := block.RecordsSet[0].msgSet.Messages[0]
 	if msgBlock.Offset != 0x550000 {
 		t.Error("Decoding produced incorrect message offset.")
 	}

+ 9 - 0
message_set.go

@@ -64,6 +64,15 @@ func (ms *MessageSet) decode(pd packetDecoder) (err error) {
 	ms.Messages = nil
 
 	for pd.remaining() > 0 {
+		magic, err := magicValue(pd)
+		if err != nil {
+			return err
+		}
+
+		if magic > 1 {
+			return nil
+		}
+
 		msb := new(MessageBlock)
 		err = msb.decode(pd)
 		switch err {

+ 13 - 7
records.go

@@ -62,16 +62,12 @@ func (r *Records) encode(pe packetEncoder) error {
 		}
 		return r.recordBatch.encode(pe)
 	}
+
 	return fmt.Errorf("unknown records type: %v", r.recordsType)
 }
 
 func (r *Records) setTypeFromMagic(pd packetDecoder) error {
-	dec, err := pd.peek(magicOffset, magicLength)
-	if err != nil {
-		return err
-	}
-
-	magic, err := dec.getInt8()
+	magic, err := magicValue(pd)
 	if err != nil {
 		return err
 	}
@@ -80,13 +76,14 @@ func (r *Records) setTypeFromMagic(pd packetDecoder) error {
 	if magic < 2 {
 		r.recordsType = legacyRecords
 	}
+
 	return nil
 }
 
 func (r *Records) decode(pd packetDecoder) error {
 	if r.recordsType == unknownRecords {
 		if err := r.setTypeFromMagic(pd); err != nil {
-			return nil
+			return err
 		}
 	}
 
@@ -165,3 +162,12 @@ func (r *Records) isControl() (bool, error) {
 	}
 	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
 }
+
+func magicValue(pd packetDecoder) (int8, error) {
+	dec, err := pd.peek(magicOffset, magicLength)
+	if err != nil {
+		return 0, err
+	}
+
+	return dec.getInt8()
+}