Prechádzať zdrojové kódy

Support multiple record batches, closes #1022

Ivan Babrou pred 8 rokmi
rodič
commit
068e0b77c8
9 zmenených súborov, 103 pridaní a 75 odobraní
  1. 16 5
      consumer.go
  2. 9 11
      fetch_response.go
  3. 1 1
      fetch_response_test.go
  4. 2 2
      produce_request.go
  5. 5 5
      produce_set.go
  6. 1 1
      produce_set_test.go
  7. 41 0
      record_batch.go
  8. 24 34
      records.go
  9. 4 16
      records_test.go

+ 16 - 5
consumer.go

@@ -601,14 +601,25 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
 	child.fetchSize = child.conf.Consumer.Fetch.Default
 	atomic.StoreInt64(&child.highWaterMarkOffset, block.HighWaterMarkOffset)
 
-	if control, err := block.Records.isControl(); err != nil || control {
-		return nil, err
-	}
-
 	if block.Records.recordsType == legacyRecords {
 		return child.parseMessages(block.Records.msgSet)
 	}
-	return child.parseRecords(block.Records.recordBatch)
+
+	messages := []*ConsumerMessage{}
+	for _, recordBatch := range block.Records.recordBatchSet.batches {
+		if recordBatch.Control {
+			continue
+		}
+
+		recordBatchMessages, err := child.parseRecords(recordBatch)
+		messages = append(messages, recordBatchMessages...)
+
+		if err != nil {
+			return messages, err
+		}
+	}
+
+	return messages, nil
 }
 
 // brokerConsumer

+ 9 - 11
fetch_response.go

@@ -1,6 +1,8 @@
 package sarama
 
-import "time"
+import (
+	"time"
+)
 
 type AbortedTransaction struct {
 	ProducerID  int64
@@ -301,22 +303,18 @@ func (r *FetchResponse) AddRecord(topic string, partition int32, key, value Enco
 	frb := r.getOrCreateBlock(topic, partition)
 	kb, vb := encodeKV(key, value)
 	rec := &Record{Key: kb, Value: vb, OffsetDelta: offset}
-	batch := frb.Records.recordBatch
-	if batch == nil {
-		batch = &RecordBatch{Version: 2}
-		frb.Records = newDefaultRecords(batch)
+	if frb.Records.recordBatchSet == nil {
+		frb.Records = newDefaultRecords([]*RecordBatch{&RecordBatch{Version: 2}})
 	}
-	batch.addRecord(rec)
+	frb.Records.recordBatchSet.batches[0].addRecord(rec)
 }
 
 func (r *FetchResponse) SetLastOffsetDelta(topic string, partition int32, offset int32) {
 	frb := r.getOrCreateBlock(topic, partition)
-	batch := frb.Records.recordBatch
-	if batch == nil {
-		batch = &RecordBatch{Version: 2}
-		frb.Records = newDefaultRecords(batch)
+	if frb.Records.recordBatchSet == nil {
+		frb.Records = newDefaultRecords([]*RecordBatch{&RecordBatch{Version: 2}})
 	}
-	batch.LastOffsetDelta = offset
+	frb.Records.recordBatchSet.batches[0].LastOffsetDelta = offset
 }
 
 func (r *FetchResponse) SetLastStableOffset(topic string, partition int32, offset int64) {

+ 1 - 1
fetch_response_test.go

@@ -185,7 +185,7 @@ func TestOneRecordFetchResponse(t *testing.T) {
 	if n != 1 {
 		t.Fatal("Decoding produced incorrect number of records.")
 	}
-	rec := block.Records.recordBatch.Records[0]
+	rec := block.Records.recordBatchSet.batches[0].Records[0]
 	if !bytes.Equal(rec.Key, []byte{0x01, 0x02, 0x03, 0x04}) {
 		t.Error("Decoding produced incorrect record key.")
 	}

+ 2 - 2
produce_request.go

@@ -113,7 +113,7 @@ func (r *ProduceRequest) encode(pe packetEncoder) error {
 			}
 			if metricRegistry != nil {
 				if r.Version >= 3 {
-					topicRecordCount += updateBatchMetrics(records.recordBatch, compressionRatioMetric, topicCompressionRatioMetric)
+					topicRecordCount += updateBatchMetrics(records.recordBatchSet.batches[0], compressionRatioMetric, topicCompressionRatioMetric)
 				} else {
 					topicRecordCount += updateMsgSetMetrics(records.msgSet, compressionRatioMetric, topicCompressionRatioMetric)
 				}
@@ -248,5 +248,5 @@ func (r *ProduceRequest) AddSet(topic string, partition int32, set *MessageSet)
 
 func (r *ProduceRequest) AddBatch(topic string, partition int32, batch *RecordBatch) {
 	r.ensureRecords(topic, partition)
-	r.records[topic][partition] = newDefaultRecords(batch)
+	r.records[topic][partition] = newDefaultRecords([]*RecordBatch{batch})
 }

+ 5 - 5
produce_set.go

@@ -64,7 +64,7 @@ func (ps *produceSet) add(msg *ProducerMessage) error {
 				ProducerID:     -1, /* No producer id */
 				Codec:          ps.parent.conf.Producer.Compression,
 			}
-			set = &partitionSet{recordsToSend: newDefaultRecords(batch)}
+			set = &partitionSet{recordsToSend: newDefaultRecords([]*RecordBatch{batch})}
 			size = recordBatchOverhead
 		} else {
 			set = &partitionSet{recordsToSend: newLegacyRecords(new(MessageSet))}
@@ -79,7 +79,7 @@ func (ps *produceSet) add(msg *ProducerMessage) error {
 		rec := &Record{
 			Key:            key,
 			Value:          val,
-			TimestampDelta: timestamp.Sub(set.recordsToSend.recordBatch.FirstTimestamp),
+			TimestampDelta: timestamp.Sub(set.recordsToSend.recordBatchSet.batches[0].FirstTimestamp),
 		}
 		size += len(key) + len(val)
 		if len(msg.Headers) > 0 {
@@ -89,7 +89,7 @@ func (ps *produceSet) add(msg *ProducerMessage) error {
 				size += len(rec.Headers[i].Key) + len(rec.Headers[i].Value) + 2*binary.MaxVarintLen32
 			}
 		}
-		set.recordsToSend.recordBatch.addRecord(rec)
+		set.recordsToSend.recordBatchSet.batches[0].addRecord(rec)
 	} else {
 		msgToSend := &Message{Codec: CompressionNone, Key: key, Value: val}
 		if ps.parent.conf.Version.IsAtLeast(V0_10_0_0) {
@@ -122,11 +122,11 @@ func (ps *produceSet) buildRequest() *ProduceRequest {
 	for topic, partitionSet := range ps.msgs {
 		for partition, set := range partitionSet {
 			if req.Version >= 3 {
-				for i, record := range set.recordsToSend.recordBatch.Records {
+				for i, record := range set.recordsToSend.recordBatchSet.batches[0].Records {
 					record.OffsetDelta = int64(i)
 				}
 
-				req.AddBatch(topic, partition, set.recordsToSend.recordBatch)
+				req.AddBatch(topic, partition, set.recordsToSend.recordBatchSet.batches[0])
 				continue
 			}
 			if ps.parent.conf.Producer.Compression == CompressionNone {

+ 1 - 1
produce_set_test.go

@@ -227,7 +227,7 @@ func TestProduceSetV3RequestBuilding(t *testing.T) {
 		t.Error("Wrong request version")
 	}
 
-	batch := req.records["t1"][0].recordBatch
+	batch := req.records["t1"][0].recordBatchSet.batches[0]
 	if batch.FirstTimestamp != now {
 		t.Errorf("Wrong first timestamp: %v", batch.FirstTimestamp)
 	}

+ 41 - 0
record_batch.go

@@ -35,6 +35,47 @@ func (e recordsArray) decode(pd packetDecoder) error {
 	return nil
 }
 
+type RecordBatchSet struct {
+	batches []*RecordBatch
+}
+
+func (rbs *RecordBatchSet) encode(pe packetEncoder) error {
+	for _, rb := range rbs.batches {
+		if err := rb.encode(pe); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (rbs *RecordBatchSet) decode(pd packetDecoder) error {
+	rbs.batches = []*RecordBatch{}
+
+	for {
+		if pd.remaining() == 0 {
+			break
+		}
+
+		rb := &RecordBatch{}
+		if err := rb.decode(pd); err != nil {
+			// If we have at least one decoded record batch, this is not an error
+			if err == ErrInsufficientData && len(rbs.batches) > 0 {
+				return nil
+			}
+			return err
+		}
+
+		// If we have at least one full record batch, we skip incomplete ones
+		if rb.PartialTrailingRecord && len(rbs.batches) > 0 {
+			return nil
+		}
+
+		rbs.batches = append(rbs.batches, rb)
+	}
+
+	return nil
+}
+
 type RecordBatch struct {
 	FirstOffset           int64
 	PartitionLeaderEpoch  int32

+ 24 - 34
records.go

@@ -1,6 +1,8 @@
 package sarama
 
-import "fmt"
+import (
+	"fmt"
+)
 
 const (
 	unknownRecords = iota
@@ -13,28 +15,28 @@ const (
 
 // Records implements a union type containing either a RecordBatchSet (one or more RecordBatches) or a legacy MessageSet.
 type Records struct {
-	recordsType int
-	msgSet      *MessageSet
-	recordBatch *RecordBatch
+	recordsType    int
+	msgSet         *MessageSet
+	recordBatchSet *RecordBatchSet
 }
 
 func newLegacyRecords(msgSet *MessageSet) Records {
 	return Records{recordsType: legacyRecords, msgSet: msgSet}
 }
 
-func newDefaultRecords(batch *RecordBatch) Records {
-	return Records{recordsType: defaultRecords, recordBatch: batch}
+func newDefaultRecords(batches []*RecordBatch) Records {
+	return Records{recordsType: defaultRecords, recordBatchSet: &RecordBatchSet{batches}}
 }
 
 // setTypeFromFields sets type of Records depending on which of msgSet or recordBatchSet is not nil.
 // The first return value indicates whether both fields are nil (and the type is not set).
 // If both fields are not nil, it returns an error.
 func (r *Records) setTypeFromFields() (bool, error) {
-	if r.msgSet == nil && r.recordBatch == nil {
+	if r.msgSet == nil && r.recordBatchSet == nil {
 		return true, nil
 	}
-	if r.msgSet != nil && r.recordBatch != nil {
-		return false, fmt.Errorf("both msgSet and recordBatch are set, but record type is unknown")
+	if r.msgSet != nil && r.recordBatchSet != nil {
+		return false, fmt.Errorf("both msgSet and recordBatchSet are set, but record type is unknown")
 	}
 	r.recordsType = defaultRecords
 	if r.msgSet != nil {
@@ -57,10 +59,10 @@ func (r *Records) encode(pe packetEncoder) error {
 		}
 		return r.msgSet.encode(pe)
 	case defaultRecords:
-		if r.recordBatch == nil {
+		if r.recordBatchSet == nil {
 			return nil
 		}
-		return r.recordBatch.encode(pe)
+		return r.recordBatchSet.encode(pe)
 	}
 	return fmt.Errorf("unknown records type: %v", r.recordsType)
 }
@@ -95,8 +97,8 @@ func (r *Records) decode(pd packetDecoder) error {
 		r.msgSet = &MessageSet{}
 		return r.msgSet.decode(pd)
 	case defaultRecords:
-		r.recordBatch = &RecordBatch{}
-		return r.recordBatch.decode(pd)
+		r.recordBatchSet = &RecordBatchSet{batches: []*RecordBatch{}}
+		return r.recordBatchSet.decode(pd)
 	}
 	return fmt.Errorf("unknown records type: %v", r.recordsType)
 }
@@ -115,10 +117,14 @@ func (r *Records) numRecords() (int, error) {
 		}
 		return len(r.msgSet.Messages), nil
 	case defaultRecords:
-		if r.recordBatch == nil {
+		if r.recordBatchSet == nil {
 			return 0, nil
 		}
-		return len(r.recordBatch.Records), nil
+		s := 0
+		for i := range r.recordBatchSet.batches {
+			s += len(r.recordBatchSet.batches[i].Records)
+		}
+		return s, nil
 	}
 	return 0, fmt.Errorf("unknown records type: %v", r.recordsType)
 }
@@ -139,29 +145,13 @@ func (r *Records) isPartial() (bool, error) {
 		}
 		return r.msgSet.PartialTrailingMessage, nil
 	case defaultRecords:
-		if r.recordBatch == nil {
+		if r.recordBatchSet == nil {
 			return false, nil
 		}
-		return r.recordBatch.PartialTrailingRecord, nil
-	}
-	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
-}
-
-func (r *Records) isControl() (bool, error) {
-	if r.recordsType == unknownRecords {
-		if empty, err := r.setTypeFromFields(); err != nil || empty {
-			return false, err
+		if len(r.recordBatchSet.batches) == 1 {
+			return r.recordBatchSet.batches[0].PartialTrailingRecord, nil
 		}
-	}
-
-	switch r.recordsType {
-	case legacyRecords:
 		return false, nil
-	case defaultRecords:
-		if r.recordBatch == nil {
-			return false, nil
-		}
-		return r.recordBatch.Control, nil
 	}
 	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
 }

+ 4 - 16
records_test.go

@@ -64,14 +64,6 @@ func TestLegacyRecords(t *testing.T) {
 	if p {
 		t.Errorf("MessageSet shouldn't have a partial trailing message")
 	}
-
-	c, err := r.isControl()
-	if err != nil {
-		t.Fatal(err)
-	}
-	if c {
-		t.Errorf("MessageSet can't be a control batch")
-	}
 }
 
 func TestDefaultRecords(t *testing.T) {
@@ -84,7 +76,7 @@ func TestDefaultRecords(t *testing.T) {
 		},
 	}
 
-	r := newDefaultRecords(batch)
+	r := newDefaultRecords([]*RecordBatch{batch})
 
 	exp, err := encode(batch, nil)
 	if err != nil {
@@ -113,8 +105,8 @@ func TestDefaultRecords(t *testing.T) {
 	if r.recordsType != defaultRecords {
 		t.Fatalf("Wrong records type %v, expected %v", r.recordsType, defaultRecords)
 	}
-	if !reflect.DeepEqual(batch, r.recordBatch) {
-		t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatch)
+	if !reflect.DeepEqual(batch, r.recordBatchSet.batches[0]) {
+		t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatchSet.batches[0])
 	}
 
 	n, err := r.numRecords()
@@ -133,11 +125,7 @@ func TestDefaultRecords(t *testing.T) {
 		t.Errorf("RecordBatch shouldn't have a partial trailing record")
 	}
 
-	c, err := r.isControl()
-	if err != nil {
-		t.Fatal(err)
-	}
-	if c {
+	if r.recordBatchSet.batches[0].Control {
 		t.Errorf("RecordBatch shouldn't be a control batch")
 	}
 }