Browse Source

Add decode method to request types

Maxim Vladimirsky 10 years ago
parent
commit
aa411f16fd

+ 5 - 5
broker.go

@@ -234,7 +234,7 @@ func (b *Broker) FetchOffset(request *OffsetFetchRequest) (*OffsetFetchResponse,
 	return response, nil
 }
 
-func (b *Broker) send(req requestEncoder, promiseResponse bool) (*responsePromise, error) {
+func (b *Broker) send(rb requestBody, promiseResponse bool) (*responsePromise, error) {
 	b.lock.Lock()
 	defer b.lock.Unlock()
 
@@ -245,8 +245,8 @@ func (b *Broker) send(req requestEncoder, promiseResponse bool) (*responsePromis
 		return nil, ErrNotConnected
 	}
 
-	fullRequest := request{b.correlationID, b.conf.ClientID, req}
-	buf, err := encode(&fullRequest)
+	req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
+	buf, err := encode(req)
 	if err != nil {
 		return nil, err
 	}
@@ -266,13 +266,13 @@ func (b *Broker) send(req requestEncoder, promiseResponse bool) (*responsePromis
 		return nil, nil
 	}
 
-	promise := responsePromise{fullRequest.correlationID, make(chan []byte), make(chan error)}
+	promise := responsePromise{req.correlationID, make(chan []byte), make(chan error)}
 	b.responses <- promise
 
 	return &promise, nil
 }
 
-func (b *Broker) sendAndReceive(req requestEncoder, res decoder) error {
+func (b *Broker) sendAndReceive(req requestBody, res decoder) error {
 	promise, err := b.send(req, res != nil)
 
 	if err != nil {

+ 5 - 0
consumer_metadata_request.go

@@ -8,6 +8,11 @@ func (r *ConsumerMetadataRequest) encode(pe packetEncoder) error {
 	return pe.putString(r.ConsumerGroup)
 }
 
+func (r *ConsumerMetadataRequest) decode(pd packetDecoder) (err error) {
+	r.ConsumerGroup, err = pd.getString()
+	return err
+}
+
 func (r *ConsumerMetadataRequest) key() int16 {
 	return 10
 }

+ 2 - 2
consumer_metadata_request_test.go

@@ -12,8 +12,8 @@ var (
 
 func TestConsumerMetadataRequest(t *testing.T) {
 	request := new(ConsumerMetadataRequest)
-	testEncodable(t, "empty string", request, consumerMetadataRequestEmpty)
+	testRequest(t, "empty string", request, consumerMetadataRequestEmpty)
 
 	request.ConsumerGroup = "foobar"
-	testEncodable(t, "with string", request, consumerMetadataRequestString)
+	testRequest(t, "with string", request, consumerMetadataRequestString)
 }

+ 4 - 4
encoder_decoder.go

@@ -9,15 +9,15 @@ type encoder interface {
 }
 
 // Encode takes an Encoder and turns it into bytes.
-func encode(in encoder) ([]byte, error) {
-	if in == nil {
+func encode(e encoder) ([]byte, error) {
+	if e == nil {
 		return nil, nil
 	}
 
 	var prepEnc prepEncoder
 	var realEnc realEncoder
 
-	err := in.encode(&prepEnc)
+	err := e.encode(&prepEnc)
 	if err != nil {
 		return nil, err
 	}
@@ -27,7 +27,7 @@ func encode(in encoder) ([]byte, error) {
 	}
 
 	realEnc.raw = make([]byte, prepEnc.length)
-	err = in.encode(&realEnc)
+	err = e.encode(&realEnc)
 	if err != nil {
 		return nil, err
 	}

+ 53 - 0
fetch_request.go

@@ -11,6 +11,16 @@ func (f *fetchRequestBlock) encode(pe packetEncoder) error {
 	return nil
 }
 
+func (f *fetchRequestBlock) decode(pd packetDecoder) (err error) {
+	if f.fetchOffset, err = pd.getInt64(); err != nil {
+		return err
+	}
+	if f.maxBytes, err = pd.getInt32(); err != nil {
+		return err
+	}
+	return nil
+}
+
 type FetchRequest struct {
 	MaxWaitTime int32
 	MinBytes    int32
@@ -45,6 +55,49 @@ func (f *FetchRequest) encode(pe packetEncoder) (err error) {
 	return nil
 }
 
+func (f *FetchRequest) decode(pd packetDecoder) (err error) {
+	if _, err = pd.getInt32(); err != nil {
+		return err
+	}
+	if f.MaxWaitTime, err = pd.getInt32(); err != nil {
+		return err
+	}
+	if f.MinBytes, err = pd.getInt32(); err != nil {
+		return err
+	}
+	topicCount, err := pd.getArrayLength()
+	if err != nil {
+		return err
+	}
+	if topicCount == 0 {
+		return nil
+	}
+	f.blocks = make(map[string]map[int32]*fetchRequestBlock)
+	for i := 0; i < topicCount; i++ {
+		topic, err := pd.getString()
+		if err != nil {
+			return err
+		}
+		partitionCount, err := pd.getArrayLength()
+		if err != nil {
+			return err
+		}
+		f.blocks[topic] = make(map[int32]*fetchRequestBlock)
+		for j := 0; j < partitionCount; j++ {
+			partition, err := pd.getInt32()
+			if err != nil {
+				return err
+			}
+			fetchBlock := &fetchRequestBlock{}
+			if err = fetchBlock.decode(pd); err != nil {
+				return nil
+			}
+			f.blocks[topic][partition] = fetchBlock
+		}
+	}
+	return nil
+}
+
 func (f *FetchRequest) key() int16 {
 	return 1
 }

+ 3 - 3
fetch_request_test.go

@@ -21,14 +21,14 @@ var (
 
 func TestFetchRequest(t *testing.T) {
 	request := new(FetchRequest)
-	testEncodable(t, "no blocks", request, fetchRequestNoBlocks)
+	testRequest(t, "no blocks", request, fetchRequestNoBlocks)
 
 	request.MaxWaitTime = 0x20
 	request.MinBytes = 0xEF
-	testEncodable(t, "with properties", request, fetchRequestWithProperties)
+	testRequest(t, "with properties", request, fetchRequestWithProperties)
 
 	request.MaxWaitTime = 0
 	request.MinBytes = 0
 	request.AddBlock("topic", 0x12, 0x34, 0x56)
-	testEncodable(t, "one block", request, fetchRequestOneBlock)
+	testRequest(t, "one block", request, fetchRequestOneBlock)
 }

+ 20 - 0
metadata_request.go

@@ -19,6 +19,26 @@ func (mr *MetadataRequest) encode(pe packetEncoder) error {
 	return nil
 }
 
+func (mr *MetadataRequest) decode(pd packetDecoder) error {
+	topicCount, err := pd.getArrayLength()
+	if err != nil {
+		return err
+	}
+	if topicCount == 0 {
+		return nil
+	}
+
+	mr.Topics = make([]string, topicCount)
+	for i := range mr.Topics {
+		topic, err := pd.getString()
+		if err != nil {
+			return err
+		}
+		mr.Topics[i] = topic
+	}
+	return nil
+}
+
 func (mr *MetadataRequest) key() int16 {
 	return 3
 }

+ 3 - 3
metadata_request_test.go

@@ -19,11 +19,11 @@ var (
 
 func TestMetadataRequest(t *testing.T) {
 	request := new(MetadataRequest)
-	testEncodable(t, "no topics", request, metadataRequestNoTopics)
+	testRequest(t, "no topics", request, metadataRequestNoTopics)
 
 	request.Topics = []string{"topic1"}
-	testEncodable(t, "one topic", request, metadataRequestOneTopic)
+	testRequest(t, "one topic", request, metadataRequestOneTopic)
 
 	request.Topics = []string{"foo", "bar", "baz"}
-	testEncodable(t, "three topics", request, metadataRequestThreeTopics)
+	testRequest(t, "three topics", request, metadataRequestThreeTopics)
 }

+ 66 - 0
offset_commit_request.go

@@ -22,6 +22,19 @@ func (r *offsetCommitRequestBlock) encode(pe packetEncoder, version int16) error
 	return pe.putString(r.metadata)
 }
 
+func (r *offsetCommitRequestBlock) decode(pd packetDecoder, version int16) (err error) {
+	if r.offset, err = pd.getInt64(); err != nil {
+		return err
+	}
+	if version == 1 {
+		if r.timestamp, err = pd.getInt64(); err != nil {
+			return err
+		}
+	}
+	r.metadata, err = pd.getString()
+	return err
+}
+
 type OffsetCommitRequest struct {
 	ConsumerGroup           string
 	ConsumerGroupGeneration int32  // v1 or later
@@ -85,6 +98,59 @@ func (r *OffsetCommitRequest) encode(pe packetEncoder) error {
 	return nil
 }
 
+func (r *OffsetCommitRequest) decode(pd packetDecoder) (err error) {
+	if r.ConsumerGroup, err = pd.getString(); err != nil {
+		return err
+	}
+
+	if r.Version >= 1 {
+		if r.ConsumerGroupGeneration, err = pd.getInt32(); err != nil {
+			return err
+		}
+		if r.ConsumerID, err = pd.getString(); err != nil {
+			return err
+		}
+	}
+
+	if r.Version >= 2 {
+		if r.RetentionTime, err = pd.getInt64(); err != nil {
+			return err
+		}
+	}
+
+	topicCount, err := pd.getArrayLength()
+	if err != nil {
+		return err
+	}
+	if topicCount == 0 {
+		return nil
+	}
+	r.blocks = make(map[string]map[int32]*offsetCommitRequestBlock)
+	for i := 0; i < topicCount; i++ {
+		topic, err := pd.getString()
+		if err != nil {
+			return err
+		}
+		partitionCount, err := pd.getArrayLength()
+		if err != nil {
+			return err
+		}
+		r.blocks[topic] = make(map[int32]*offsetCommitRequestBlock)
+		for j := 0; j < partitionCount; j++ {
+			partition, err := pd.getInt32()
+			if err != nil {
+				return err
+			}
+			block := &offsetCommitRequestBlock{}
+			if err := block.decode(pd, r.Version); err != nil {
+				return err
+			}
+			r.blocks[topic][partition] = block
+		}
+	}
+	return nil
+}
+
 func (r *OffsetCommitRequest) key() int16 {
 	return 8
 }

+ 6 - 6
offset_commit_request_test.go

@@ -58,10 +58,10 @@ func TestOffsetCommitRequestV0(t *testing.T) {
 	request := new(OffsetCommitRequest)
 	request.Version = 0
 	request.ConsumerGroup = "foobar"
-	testEncodable(t, "no blocks v0", request, offsetCommitRequestNoBlocksV0)
+	testRequest(t, "no blocks v0", request, offsetCommitRequestNoBlocksV0)
 
 	request.AddBlock("topic", 0x5221, 0xDEADBEEF, 0, "metadata")
-	testEncodable(t, "one block v0", request, offsetCommitRequestOneBlockV0)
+	testRequest(t, "one block v0", request, offsetCommitRequestOneBlockV0)
 }
 
 func TestOffsetCommitRequestV1(t *testing.T) {
@@ -70,10 +70,10 @@ func TestOffsetCommitRequestV1(t *testing.T) {
 	request.ConsumerID = "cons"
 	request.ConsumerGroupGeneration = 0x1122
 	request.Version = 1
-	testEncodable(t, "no blocks v1", request, offsetCommitRequestNoBlocksV1)
+	testRequest(t, "no blocks v1", request, offsetCommitRequestNoBlocksV1)
 
 	request.AddBlock("topic", 0x5221, 0xDEADBEEF, ReceiveTime, "metadata")
-	testEncodable(t, "one block v1", request, offsetCommitRequestOneBlockV1)
+	testRequest(t, "one block v1", request, offsetCommitRequestOneBlockV1)
 }
 
 func TestOffsetCommitRequestV2(t *testing.T) {
@@ -83,8 +83,8 @@ func TestOffsetCommitRequestV2(t *testing.T) {
 	request.ConsumerGroupGeneration = 0x1122
 	request.RetentionTime = 0x4433
 	request.Version = 2
-	testEncodable(t, "no blocks v2", request, offsetCommitRequestNoBlocksV2)
+	testRequest(t, "no blocks v2", request, offsetCommitRequestNoBlocksV2)
 
 	request.AddBlock("topic", 0x5221, 0xDEADBEEF, 0, "metadata")
-	testEncodable(t, "one block v2", request, offsetCommitRequestOneBlockV2)
+	testRequest(t, "one block v2", request, offsetCommitRequestOneBlockV2)
 }

+ 26 - 0
offset_fetch_request.go

@@ -28,6 +28,32 @@ func (r *OffsetFetchRequest) encode(pe packetEncoder) (err error) {
 	return nil
 }
 
+func (r *OffsetFetchRequest) decode(pd packetDecoder) (err error) {
+	if r.ConsumerGroup, err = pd.getString(); err != nil {
+		return err
+	}
+	partitionCount, err := pd.getArrayLength()
+	if err != nil {
+		return err
+	}
+	if partitionCount == 0 {
+		return nil
+	}
+	r.partitions = make(map[string][]int32)
+	for i := 0; i < partitionCount; i++ {
+		topic, err := pd.getString()
+		if err != nil {
+			return err
+		}
+		partitions, err := pd.getInt32Array()
+		if err != nil {
+			return err
+		}
+		r.partitions[topic] = partitions
+	}
+	return nil
+}
+
 func (r *OffsetFetchRequest) key() int16 {
 	return 9
 }

+ 3 - 3
offset_fetch_request_test.go

@@ -21,11 +21,11 @@ var (
 
 func TestOffsetFetchRequest(t *testing.T) {
 	request := new(OffsetFetchRequest)
-	testEncodable(t, "no group, no partitions", request, offsetFetchRequestNoGroupNoPartitions)
+	testRequest(t, "no group, no partitions", request, offsetFetchRequestNoGroupNoPartitions)
 
 	request.ConsumerGroup = "blah"
-	testEncodable(t, "no partitions", request, offsetFetchRequestNoPartitions)
+	testRequest(t, "no partitions", request, offsetFetchRequestNoPartitions)
 
 	request.AddPartition("topicTheFirst", 0x4F4F4F4F)
-	testEncodable(t, "one partition", request, offsetFetchRequestOnePartition)
+	testRequest(t, "one partition", request, offsetFetchRequestOnePartition)
 }

+ 48 - 1
offset_request.go

@@ -11,6 +11,16 @@ func (r *offsetRequestBlock) encode(pe packetEncoder) error {
 	return nil
 }
 
+func (r *offsetRequestBlock) decode(pd packetDecoder) (err error) {
+	if r.time, err = pd.getInt64(); err != nil {
+		return err
+	}
+	if r.maxOffsets, err = pd.getInt32(); err != nil {
+		return err
+	}
+	return nil
+}
+
 type OffsetRequest struct {
 	blocks map[string]map[int32]*offsetRequestBlock
 }
@@ -32,10 +42,47 @@ func (r *OffsetRequest) encode(pe packetEncoder) error {
 		}
 		for partition, block := range partitions {
 			pe.putInt32(partition)
-			err = block.encode(pe)
+			if err = block.encode(pe); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+func (r *OffsetRequest) decode(pd packetDecoder) error {
+	// Ignore replica ID
+	if _, err := pd.getInt32(); err != nil {
+		return err
+	}
+	blockCount, err := pd.getArrayLength()
+	if err != nil {
+		return err
+	}
+	if blockCount == 0 {
+		return nil
+	}
+	r.blocks = make(map[string]map[int32]*offsetRequestBlock)
+	for i := 0; i < blockCount; i++ {
+		topic, err := pd.getString()
+		if err != nil {
+			return err
+		}
+		partitionCount, err := pd.getArrayLength()
+		if err != nil {
+			return err
+		}
+		r.blocks[topic] = make(map[int32]*offsetRequestBlock)
+		for j := 0; j < partitionCount; j++ {
+			partition, err := pd.getInt32()
 			if err != nil {
 				return err
 			}
+			block := &offsetRequestBlock{}
+			if err := block.decode(pd); err != nil {
+				return err
+			}
+			r.blocks[topic][partition] = block
 		}
 	}
 	return nil

+ 2 - 2
offset_request_test.go

@@ -19,8 +19,8 @@ var (
 
 func TestOffsetRequest(t *testing.T) {
 	request := new(OffsetRequest)
-	testEncodable(t, "no blocks", request, offsetRequestNoBlocks)
+	testRequest(t, "no blocks", request, offsetRequestNoBlocks)
 
 	request.AddBlock("foo", 4, 1, 2)
-	testEncodable(t, "one block", request, offsetRequestOneBlock)
+	testRequest(t, "one block", request, offsetRequestOneBlock)
 }

+ 50 - 0
produce_request.go

@@ -54,6 +54,56 @@ func (p *ProduceRequest) encode(pe packetEncoder) error {
 	return nil
 }
 
+func (p *ProduceRequest) decode(pd packetDecoder) error {
+	requiredAcks, err := pd.getInt16()
+	if err != nil {
+		return err
+	}
+	p.RequiredAcks = RequiredAcks(requiredAcks)
+	if p.Timeout, err = pd.getInt32(); err != nil {
+		return err
+	}
+	topicCount, err := pd.getArrayLength()
+	if err != nil {
+		return err
+	}
+	if topicCount == 0 {
+		return nil
+	}
+	p.msgSets = make(map[string]map[int32]*MessageSet)
+	for i := 0; i < topicCount; i++ {
+		topic, err := pd.getString()
+		if err != nil {
+			return err
+		}
+		partitionCount, err := pd.getArrayLength()
+		if err != nil {
+			return err
+		}
+		p.msgSets[topic] = make(map[int32]*MessageSet)
+		for j := 0; j < partitionCount; j++ {
+			partition, err := pd.getInt32()
+			if err != nil {
+				return err
+			}
+			messageSetSize, err := pd.getInt32()
+			if err != nil {
+				return err
+			}
+			if messageSetSize == 0 {
+				continue
+			}
+			msgSet := &MessageSet{}
+			err = msgSet.decode(pd)
+			if err != nil {
+				return err
+			}
+			p.msgSets[topic][partition] = msgSet
+		}
+	}
+	return nil
+}
+
 func (p *ProduceRequest) key() int16 {
 	return 0
 }

+ 3 - 3
produce_request_test.go

@@ -36,12 +36,12 @@ var (
 
 func TestProduceRequest(t *testing.T) {
 	request := new(ProduceRequest)
-	testEncodable(t, "empty", request, produceRequestEmpty)
+	testRequest(t, "empty", request, produceRequestEmpty)
 
 	request.RequiredAcks = 0x123
 	request.Timeout = 0x444
-	testEncodable(t, "header", request, produceRequestHeader)
+	testRequest(t, "header", request, produceRequestHeader)
 
 	request.AddMessage("topic", 0xAD, &Message{Codec: CompressionNone, Key: nil, Value: []byte{0x00, 0xEE}})
-	testEncodable(t, "one message", request, produceRequestOneMessage)
+	testRequest(t, "one message", request, produceRequestOneMessage)
 }

+ 75 - 4
request.go

@@ -1,15 +1,22 @@
 package sarama
 
-type requestEncoder interface {
+import (
+	"encoding/binary"
+	"fmt"
+	"io"
+)
+
+type requestBody interface {
 	encoder
+	decoder
 	key() int16
 	version() int16
 }
 
 type request struct {
 	correlationID int32
-	id            string
-	body          requestEncoder
+	clientID      string
+	body          requestBody
 }
 
 func (r *request) encode(pe packetEncoder) (err error) {
@@ -17,7 +24,7 @@ func (r *request) encode(pe packetEncoder) (err error) {
 	pe.putInt16(r.body.key())
 	pe.putInt16(r.body.version())
 	pe.putInt32(r.correlationID)
-	err = pe.putString(r.id)
+	err = pe.putString(r.clientID)
 	if err != nil {
 		return err
 	}
@@ -27,3 +34,67 @@ func (r *request) encode(pe packetEncoder) (err error) {
 	}
 	return pe.pop()
 }
+
+func (r *request) decode(pd packetDecoder) (err error) {
+	var key int16
+	if key, err = pd.getInt16(); err != nil {
+		return err
+	}
+	var version int16
+	if version, err = pd.getInt16(); err != nil {
+		return err
+	}
+	if r.correlationID, err = pd.getInt32(); err != nil {
+		return err
+	}
+	r.clientID, err = pd.getString()
+
+	r.body = allocateBody(key, version)
+	if r.body == nil {
+		return PacketDecodingError{fmt.Sprintf("Unknown request key: %d", key)}
+	}
+	return r.body.decode(pd)
+}
+
+func decodeRequest(r io.Reader) (req *request, err error) {
+	lengthBytes := make([]byte, 4)
+	if _, err := io.ReadFull(r, lengthBytes); err != nil {
+		return nil, err
+	}
+
+	length := int32(binary.BigEndian.Uint32(lengthBytes))
+	if length <= 4 || length > MaxRequestSize {
+		return nil, PacketDecodingError{fmt.Sprintf("Message of length %d too large or too small", length)}
+	}
+
+	encodedReq := make([]byte, length)
+	if _, err := io.ReadFull(r, encodedReq); err != nil {
+		return nil, err
+	}
+
+	req = &request{}
+	if err := decode(encodedReq, req); err != nil {
+		return nil, err
+	}
+	return req, nil
+}
+
+func allocateBody(key, version int16) requestBody {
+	switch key {
+	case 0:
+		return &ProduceRequest{}
+	case 1:
+		return &FetchRequest{}
+	case 2:
+		return &OffsetRequest{}
+	case 3:
+		return &MetadataRequest{}
+	case 8:
+		return &OffsetCommitRequest{Version: version}
+	case 9:
+		return &OffsetFetchRequest{}
+	case 10:
+		return &ConsumerMetadataRequest{}
+	}
+	return nil
+}

+ 22 - 15
request_test.go

@@ -2,19 +2,10 @@ package sarama
 
 import (
 	"bytes"
+	"reflect"
 	"testing"
 )
 
-var (
-	requestSimple = []byte{
-		0x00, 0x00, 0x00, 0x17, // msglen
-		0x06, 0x66,
-		0x00, 0xD2,
-		0x00, 0x00, 0x12, 0x34,
-		0x00, 0x08, 'm', 'y', 'C', 'l', 'i', 'e', 'n', 't',
-		0x00, 0x03, 'a', 'b', 'c'}
-)
-
 type testRequestBody struct {
 }
 
@@ -30,11 +21,6 @@ func (s *testRequestBody) encode(pe packetEncoder) error {
 	return pe.putString("abc")
 }
 
-func TestRequest(t *testing.T) {
-	request := request{correlationID: 0x1234, id: "myClient", body: new(testRequestBody)}
-	testEncodable(t, "simple", &request, requestSimple)
-}
-
 // not specific to request tests, just helper functions for testing structures that
 // implement the encoder or decoder interfaces that needed somewhere to live
 
@@ -53,3 +39,24 @@ func testDecodable(t *testing.T, name string, out decoder, in []byte) {
 		t.Error("Decoding", name, "failed:", err)
 	}
 }
+
+func testRequest(t *testing.T, name string, rb requestBody, expected []byte) {
+	// Encoder request
+	req := &request{correlationID: 123, clientID: "foo", body: rb}
+	packet, err := encode(req)
+	headerSize := 14 + len("foo")
+	if err != nil {
+		t.Error(err)
+	} else if !bytes.Equal(packet[headerSize:], expected) {
+		t.Error("Encoding", name, "failed\ngot ", packet, "\nwant", expected)
+	}
+	// Decoder request
+	decoded, err := decodeRequest(bytes.NewReader(packet))
+	if err != nil {
+		t.Error("Failed to decode request", err)
+	} else if decoded.correlationID != 123 || decoded.clientID != "foo" {
+		t.Errorf("Decoded header is not valid: %v", decoded)
+	} else if !reflect.DeepEqual(rb, decoded.body) {
+		t.Errorf("Decoded request does not match the encoded one\n    encoded: %v\n    decoded: %v", rb, decoded)
+	}
+}