Browse Source

Permit protocol ordering in JoinGroup requests

The kafka protocol for JoinGroupRequest allows you to specify priority
on group protocols for seamless rollout of new protocols while a
consumer group is running. The old map was not supporting that use case.

Add an array field (OrderedGroupProtocols) and deprecate the map.
Evan Huus 8 years ago
parent
commit
4351c00d43
3 changed files with 89 additions and 31 deletions
  1. 59 24
      join_group_request.go
  2. 21 5
      join_group_request_test.go
  3. 9 2
      request_test.go

+ 59 - 24
join_group_request.go

@@ -1,11 +1,36 @@
 package sarama
 
+type GroupProtocol struct {
+	Name     string
+	Metadata []byte
+}
+
+func (p *GroupProtocol) decode(pd packetDecoder) (err error) {
+	p.Name, err = pd.getString()
+	if err != nil {
+		return err
+	}
+	p.Metadata, err = pd.getBytes()
+	return err
+}
+
+func (p *GroupProtocol) encode(pe packetEncoder) (err error) {
+	if err := pe.putString(p.Name); err != nil {
+		return err
+	}
+	if err := pe.putBytes(p.Metadata); err != nil {
+		return err
+	}
+	return nil
+}
+
 type JoinGroupRequest struct {
-	GroupId        string
-	SessionTimeout int32
-	MemberId       string
-	ProtocolType   string
-	GroupProtocols map[string][]byte
+	GroupId               string
+	SessionTimeout        int32
+	MemberId              string
+	ProtocolType          string
+	GroupProtocols        map[string][]byte // deprecated; use OrderedGroupProtocols
+	OrderedGroupProtocols []*GroupProtocol
 }
 
 func (r *JoinGroupRequest) encode(pe packetEncoder) error {
@@ -20,16 +45,31 @@ func (r *JoinGroupRequest) encode(pe packetEncoder) error {
 		return err
 	}
 
-	if err := pe.putArrayLength(len(r.GroupProtocols)); err != nil {
-		return err
-	}
-	for name, metadata := range r.GroupProtocols {
-		if err := pe.putString(name); err != nil {
+	if len(r.GroupProtocols) > 0 {
+		if len(r.OrderedGroupProtocols) > 0 {
+			return PacketDecodingError{"cannot specify both GroupProtocols and OrderedGroupProtocols on JoinGroupRequest"}
+		}
+
+		if err := pe.putArrayLength(len(r.GroupProtocols)); err != nil {
 			return err
 		}
-		if err := pe.putBytes(metadata); err != nil {
+		for name, metadata := range r.GroupProtocols {
+			if err := pe.putString(name); err != nil {
+				return err
+			}
+			if err := pe.putBytes(metadata); err != nil {
+				return err
+			}
+		}
+	} else {
+		if err := pe.putArrayLength(len(r.OrderedGroupProtocols)); err != nil {
 			return err
 		}
+		for _, protocol := range r.OrderedGroupProtocols {
+			if err := protocol.encode(pe); err != nil {
+				return err
+			}
+		}
 	}
 
 	return nil
@@ -62,16 +102,12 @@ func (r *JoinGroupRequest) decode(pd packetDecoder, version int16) (err error) {
 
 	r.GroupProtocols = make(map[string][]byte)
 	for i := 0; i < n; i++ {
-		name, err := pd.getString()
-		if err != nil {
-			return err
-		}
-		metadata, err := pd.getBytes()
-		if err != nil {
+		protocol := &GroupProtocol{}
+		if err := protocol.decode(pd); err != nil {
 			return err
 		}
-
-		r.GroupProtocols[name] = metadata
+		r.GroupProtocols[protocol.Name] = protocol.Metadata
+		r.OrderedGroupProtocols = append(r.OrderedGroupProtocols, protocol)
 	}
 
 	return nil
@@ -90,11 +126,10 @@ func (r *JoinGroupRequest) requiredVersion() KafkaVersion {
 }
 
 func (r *JoinGroupRequest) AddGroupProtocol(name string, metadata []byte) {
-	if r.GroupProtocols == nil {
-		r.GroupProtocols = make(map[string][]byte)
-	}
-
-	r.GroupProtocols[name] = metadata
+	r.OrderedGroupProtocols = append(r.OrderedGroupProtocols, &GroupProtocol{
+		Name:     name,
+		Metadata: metadata,
+	})
 }
 
 func (r *JoinGroupRequest) AddGroupProtocolMetadata(name string, metadata *ConsumerGroupMemberMetadata) error {

+ 21 - 5
join_group_request_test.go

@@ -23,19 +23,35 @@ var (
 )
 
 func TestJoinGroupRequest(t *testing.T) {
-	var request *JoinGroupRequest
-
-	request = new(JoinGroupRequest)
+	request := new(JoinGroupRequest)
 	request.GroupId = "TestGroup"
 	request.SessionTimeout = 100
 	request.ProtocolType = "consumer"
 	testRequest(t, "no protocols", request, joinGroupRequestNoProtocols)
+}
+
+func TestJoinGroupRequestOneProtocol(t *testing.T) {
+	request := new(JoinGroupRequest)
+	request.GroupId = "TestGroup"
+	request.SessionTimeout = 100
+	request.MemberId = "OneProtocol"
+	request.ProtocolType = "consumer"
+	request.AddGroupProtocol("one", []byte{0x01, 0x02, 0x03})
+	packet := testRequestEncode(t, "one protocol", request, joinGroupRequestOneProtocol)
+	request.GroupProtocols = make(map[string][]byte)
+	request.GroupProtocols["one"] = []byte{0x01, 0x02, 0x03}
+	testRequestDecode(t, "one protocol", request, packet)
+}
 
-	request = new(JoinGroupRequest)
+func TestJoinGroupRequestDeprecatedEncode(t *testing.T) {
+	request := new(JoinGroupRequest)
 	request.GroupId = "TestGroup"
 	request.SessionTimeout = 100
 	request.MemberId = "OneProtocol"
 	request.ProtocolType = "consumer"
+	request.GroupProtocols = make(map[string][]byte)
+	request.GroupProtocols["one"] = []byte{0x01, 0x02, 0x03}
+	packet := testRequestEncode(t, "one protocol", request, joinGroupRequestOneProtocol)
 	request.AddGroupProtocol("one", []byte{0x01, 0x02, 0x03})
-	testRequest(t, "one protocol", request, joinGroupRequestOneProtocol)
+	testRequestDecode(t, "one protocol", request, packet)
 }

+ 9 - 2
request_test.go

@@ -50,7 +50,11 @@ func testVersionDecodable(t *testing.T, name string, out versionedDecoder, in []
 }
 
 func testRequest(t *testing.T, name string, rb protocolBody, expected []byte) {
-	// Encoder request
+	packet := testRequestEncode(t, name, rb, expected)
+	testRequestDecode(t, name, rb, packet)
+}
+
+func testRequestEncode(t *testing.T, name string, rb protocolBody, expected []byte) []byte {
 	req := &request{correlationID: 123, clientID: "foo", body: rb}
 	packet, err := encode(req, nil)
 	headerSize := 14 + len("foo")
@@ -59,7 +63,10 @@ func testRequest(t *testing.T, name string, rb protocolBody, expected []byte) {
 	} else if !bytes.Equal(packet[headerSize:], expected) {
 		t.Error("Encoding", name, "failed\ngot ", packet[headerSize:], "\nwant", expected)
 	}
-	// Decoder request
+	return packet
+}
+
+func testRequestDecode(t *testing.T, name string, rb protocolBody, packet []byte) {
 	decoded, n, err := decodeRequest(bytes.NewReader(packet))
 	if err != nil {
 		t.Error("Failed to decode request", err)