ソースを参照

Merge pull request #990 from wladh/magic

Determine the records type based on the magic number not API version
Evan Huus 8 年 前
コミット
6a8d89d71d
9 ファイル変更211 行追加22 行削除
  1. 3 4
      consumer.go
  2. 43 0
      consumer_test.go
  3. 1 8
      fetch_response.go
  4. 75 2
      fetch_response_test.go
  5. 1 0
      packet_decoder.go
  6. 0 5
      produce_request.go
  7. 8 0
      real_decoder.go
  8. 72 1
      records.go
  9. 8 2
      records_test.go

+ 3 - 4
consumer.go

@@ -519,11 +519,10 @@ func (child *partitionConsumer) parseMessages(msgSet *MessageSet) ([]*ConsumerMe
 	return messages, nil
 }
 
-func (child *partitionConsumer) parseRecords(block *FetchResponseBlock) ([]*ConsumerMessage, error) {
+func (child *partitionConsumer) parseRecords(batch *RecordBatch) ([]*ConsumerMessage, error) {
 	var messages []*ConsumerMessage
 	var incomplete bool
 	prelude := true
-	batch := block.Records.recordBatch
 
 	for _, rec := range batch.Records {
 		offset := batch.FirstOffset + rec.OffsetDelta
@@ -599,10 +598,10 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
 		return nil, err
 	}
 
-	if response.Version < 4 {
+	if block.Records.recordsType == legacyRecords {
 		return child.parseMessages(block.Records.msgSet)
 	}
-	return child.parseRecords(block)
+	return child.parseRecords(block.Records.recordBatch)
 }
 
 // brokerConsumer

+ 43 - 0
consumer_test.go

@@ -435,6 +435,49 @@ func TestConsumerExtraOffsets(t *testing.T) {
 	}
 }
 
+func TestConsumeMessageWithNewerFetchAPIVersion(t *testing.T) {
+	// Given
+	fetchResponse1 := &FetchResponse{Version: 4}
+	fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 1)
+	fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 2)
+
+	cfg := NewConfig()
+	cfg.Version = V0_11_0_0
+
+	broker0 := NewMockBroker(t, 0)
+	fetchResponse2 := &FetchResponse{}
+	fetchResponse2.Version = 4
+	fetchResponse2.AddError("my_topic", 0, ErrNoError)
+	broker0.SetHandlerByMap(map[string]MockResponse{
+		"MetadataRequest": NewMockMetadataResponse(t).
+			SetBroker(broker0.Addr(), broker0.BrokerID()).
+			SetLeader("my_topic", 0, broker0.BrokerID()),
+		"OffsetRequest": NewMockOffsetResponse(t).
+			SetVersion(1).
+			SetOffset("my_topic", 0, OffsetNewest, 1234).
+			SetOffset("my_topic", 0, OffsetOldest, 0),
+		"FetchRequest": NewMockSequence(fetchResponse1, fetchResponse2),
+	})
+
+	master, err := NewConsumer([]string{broker0.Addr()}, cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// When
+	consumer, err := master.ConsumePartition("my_topic", 0, 1)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	assertMessageOffset(t, <-consumer.Messages(), 1)
+	assertMessageOffset(t, <-consumer.Messages(), 2)
+
+	safeClose(t, consumer)
+	safeClose(t, master)
+	broker0.Close()
+}
+
 // It is fine if offsets of fetched messages are not sequential (although
 // strictly increasing!).
 func TestConsumerNonSequentialOffsets(t *testing.T) {

+ 1 - 8
fetch_response.go

@@ -79,18 +79,11 @@ func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error)
 	if err != nil {
 		return err
 	}
-	var records Records
-	if version >= 4 {
-		records = newDefaultRecords(nil)
-	} else {
-		records = newLegacyRecords(nil)
-	}
 	if recordsSize > 0 {
-		if err = records.decode(recordsDecoder); err != nil {
+		if err = b.Records.decode(recordsDecoder); err != nil {
 			return err
 		}
 	}
-	b.Records = records
 
 	return nil
 }

+ 75 - 2
fetch_response_test.go

@@ -61,8 +61,28 @@ var (
 		0x06, 0x05, 0x06, 0x07,
 		0x02,
 		0x06, 0x08, 0x09, 0x0A,
-		0x04, 0x0B, 0x0C,
-	}
+		0x04, 0x0B, 0x0C}
+
+	oneMessageFetchResponseV4 = []byte{
+		0x00, 0x00, 0x00, 0x00, // ThrottleTime
+		0x00, 0x00, 0x00, 0x01, // Number of Topics
+		0x00, 0x05, 't', 'o', 'p', 'i', 'c', // Topic
+		0x00, 0x00, 0x00, 0x01, // Number of Partitions
+		0x00, 0x00, 0x00, 0x05, // Partition
+		0x00, 0x01, // Error
+		0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x10, 0x10, // High Watermark Offset
+		0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x10, 0x10, // Last Stable Offset
+		0x00, 0x00, 0x00, 0x00, // Number of Aborted Transactions
+		0x00, 0x00, 0x00, 0x1C,
+		// messageSet
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x55, 0x00, 0x00,
+		0x00, 0x00, 0x00, 0x10,
+		// message
+		0x23, 0x96, 0x4a, 0xf7, // CRC
+		0x00,
+		0x00,
+		0xFF, 0xFF, 0xFF, 0xFF,
+		0x00, 0x00, 0x00, 0x02, 0x00, 0xEE}
 )
 
 func TestEmptyFetchResponse(t *testing.T) {
@@ -173,3 +193,56 @@ func TestOneRecordFetchResponse(t *testing.T) {
 		t.Error("Decoding produced incorrect record value.")
 	}
 }
+
+func TestOneMessageFetchResponseV4(t *testing.T) {
+	response := FetchResponse{}
+	testVersionDecodable(t, "one message v4", &response, oneMessageFetchResponseV4, 4)
+
+	if len(response.Blocks) != 1 {
+		t.Fatal("Decoding produced incorrect number of topic blocks.")
+	}
+
+	if len(response.Blocks["topic"]) != 1 {
+		t.Fatal("Decoding produced incorrect number of partition blocks for topic.")
+	}
+
+	block := response.GetBlock("topic", 5)
+	if block == nil {
+		t.Fatal("GetBlock didn't return block.")
+	}
+	if block.Err != ErrOffsetOutOfRange {
+		t.Error("Decoding didn't produce correct error code.")
+	}
+	if block.HighWaterMarkOffset != 0x10101010 {
+		t.Error("Decoding didn't produce correct high water mark offset.")
+	}
+	partial, err := block.Records.isPartial()
+	if err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
+	if partial {
+		t.Error("Decoding detected a partial trailing record where there wasn't one.")
+	}
+
+	n, err := block.Records.numRecords()
+	if err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
+	if n != 1 {
+		t.Fatal("Decoding produced incorrect number of records.")
+	}
+	msgBlock := block.Records.msgSet.Messages[0]
+	if msgBlock.Offset != 0x550000 {
+		t.Error("Decoding produced incorrect message offset.")
+	}
+	msg := msgBlock.Msg
+	if msg.Codec != CompressionNone {
+		t.Error("Decoding produced incorrect message compression.")
+	}
+	if msg.Key != nil {
+		t.Error("Decoding produced message key where there was none.")
+	}
+	if !bytes.Equal(msg.Value, []byte{0x00, 0xEE}) {
+		t.Error("Decoding produced incorrect message value.")
+	}
+}

+ 1 - 0
packet_decoder.go

@@ -25,6 +25,7 @@ type packetDecoder interface {
 	// Subsets
 	remaining() int
 	getSubset(length int) (packetDecoder, error)
+	peek(offset, length int) (packetDecoder, error) // similar to getSubset, but it doesn't advance the offset
 
 	// Stacks, see PushDecoder
 	push(in pushDecoder) error

+ 0 - 5
produce_request.go

@@ -188,11 +188,6 @@ func (r *ProduceRequest) decode(pd packetDecoder, version int16) error {
 				return err
 			}
 			var records Records
-			if version >= 3 {
-				records = newDefaultRecords(nil)
-			} else {
-				records = newLegacyRecords(nil)
-			}
 			if err := records.decode(recordsDecoder); err != nil {
 				return err
 			}

+ 8 - 0
real_decoder.go

@@ -264,6 +264,14 @@ func (rd *realDecoder) getRawBytes(length int) ([]byte, error) {
 	return rd.raw[start:rd.off], nil
 }
 
+func (rd *realDecoder) peek(offset, length int) (packetDecoder, error) {
+	if rd.remaining() < offset+length {
+		return nil, ErrInsufficientData
+	}
+	off := rd.off + offset
+	return &realDecoder{raw: rd.raw[off : off+length]}, nil
+}
+
 // stacks
 
 func (rd *realDecoder) push(in pushDecoder) error {

+ 72 - 1
records.go

@@ -3,8 +3,12 @@ package sarama
 import "fmt"
 
 const (
-	legacyRecords = iota
+	unknownRecords = iota
+	legacyRecords
 	defaultRecords
+
+	magicOffset = 16
+	magicLength = 1
 )
 
 // Records implements a union type containing either a RecordBatch or a legacy MessageSet.
@@ -22,7 +26,30 @@ func newDefaultRecords(batch *RecordBatch) Records {
 	return Records{recordsType: defaultRecords, recordBatch: batch}
 }
 
+// setTypeFromFields sets type of Records depending on which of msgSet or recordBatch is not nil.
+// The first return value indicates whether both fields are nil (and the type is not set).
+// If both fields are not nil, it returns an error.
+func (r *Records) setTypeFromFields() (bool, error) {
+	if r.msgSet == nil && r.recordBatch == nil {
+		return true, nil
+	}
+	if r.msgSet != nil && r.recordBatch != nil {
+		return false, fmt.Errorf("both msgSet and recordBatch are set, but record type is unknown")
+	}
+	r.recordsType = defaultRecords
+	if r.msgSet != nil {
+		r.recordsType = legacyRecords
+	}
+	return false, nil
+}
+
 func (r *Records) encode(pe packetEncoder) error {
+	if r.recordsType == unknownRecords {
+		if empty, err := r.setTypeFromFields(); err != nil || empty {
+			return err
+		}
+	}
+
 	switch r.recordsType {
 	case legacyRecords:
 		if r.msgSet == nil {
@@ -38,7 +65,31 @@ func (r *Records) encode(pe packetEncoder) error {
 	return fmt.Errorf("unknown records type: %v", r.recordsType)
 }
 
+func (r *Records) setTypeFromMagic(pd packetDecoder) error {
+	dec, err := pd.peek(magicOffset, magicLength)
+	if err != nil {
+		return err
+	}
+
+	magic, err := dec.getInt8()
+	if err != nil {
+		return err
+	}
+
+	r.recordsType = defaultRecords
+	if magic < 2 {
+		r.recordsType = legacyRecords
+	}
+	return nil
+}
+
 func (r *Records) decode(pd packetDecoder) error {
+	if r.recordsType == unknownRecords {
+		if err := r.setTypeFromMagic(pd); err != nil {
+			return nil
+		}
+	}
+
 	switch r.recordsType {
 	case legacyRecords:
 		r.msgSet = &MessageSet{}
@@ -51,6 +102,12 @@ func (r *Records) decode(pd packetDecoder) error {
 }
 
 func (r *Records) numRecords() (int, error) {
+	if r.recordsType == unknownRecords {
+		if empty, err := r.setTypeFromFields(); err != nil || empty {
+			return 0, err
+		}
+	}
+
 	switch r.recordsType {
 	case legacyRecords:
 		if r.msgSet == nil {
@@ -67,7 +124,15 @@ func (r *Records) numRecords() (int, error) {
 }
 
 func (r *Records) isPartial() (bool, error) {
+	if r.recordsType == unknownRecords {
+		if empty, err := r.setTypeFromFields(); err != nil || empty {
+			return false, err
+		}
+	}
+
 	switch r.recordsType {
+	case unknownRecords:
+		return false, nil
 	case legacyRecords:
 		if r.msgSet == nil {
 			return false, nil
@@ -83,6 +148,12 @@ func (r *Records) isPartial() (bool, error) {
 }
 
 func (r *Records) isControl() (bool, error) {
+	if r.recordsType == unknownRecords {
+		if empty, err := r.setTypeFromFields(); err != nil || empty {
+			return false, err
+		}
+	}
+
 	switch r.recordsType {
 	case legacyRecords:
 		return false, nil

+ 8 - 2
records_test.go

@@ -31,7 +31,7 @@ func TestLegacyRecords(t *testing.T) {
 	}
 
 	set = &MessageSet{}
-	r = newLegacyRecords(nil)
+	r = Records{}
 
 	err = decode(exp, set)
 	if err != nil {
@@ -42,6 +42,9 @@ func TestLegacyRecords(t *testing.T) {
 		t.Fatal(err)
 	}
 
+	if r.recordsType != legacyRecords {
+		t.Fatalf("Wrong records type %v, expected %v", r.recordsType, legacyRecords)
+	}
 	if !reflect.DeepEqual(set, r.msgSet) {
 		t.Errorf("Wrong decoding for legacy records, wanted %#+v, got %#+v", set, r.msgSet)
 	}
@@ -96,7 +99,7 @@ func TestDefaultRecords(t *testing.T) {
 	}
 
 	batch = &RecordBatch{}
-	r = newDefaultRecords(nil)
+	r = Records{}
 
 	err = decode(exp, batch)
 	if err != nil {
@@ -107,6 +110,9 @@ func TestDefaultRecords(t *testing.T) {
 		t.Fatal(err)
 	}
 
+	if r.recordsType != defaultRecords {
+		t.Fatalf("Wrong records type %v, expected %v", r.recordsType, defaultRecords)
+	}
 	if !reflect.DeepEqual(batch, r.recordBatch) {
 		t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatch)
 	}