Pārlūkot izejas kodu

Added versioning for join-group requests and responses

Dimitrij Denissenko 6 gadi atpakaļ
vecāks
revīzija
e1b85f3400
4 mainītis faili ar 159 papildinājumiem un 19 dzēšanām
  1. 22 2
      join_group_request.go
  2. 34 8
      join_group_request_test.go
  3. 22 2
      join_group_response.go
  4. 81 7
      join_group_response_test.go

+ 22 - 2
join_group_request.go

@@ -25,8 +25,10 @@ func (p *GroupProtocol) encode(pe packetEncoder) (err error) {
 }
 
 type JoinGroupRequest struct {
+	Version               int16
 	GroupId               string
 	SessionTimeout        int32
+	RebalanceTimeout      int32
 	MemberId              string
 	ProtocolType          string
 	GroupProtocols        map[string][]byte // deprecated; use OrderedGroupProtocols
@@ -38,6 +40,9 @@ func (r *JoinGroupRequest) encode(pe packetEncoder) error {
 		return err
 	}
 	pe.putInt32(r.SessionTimeout)
+	if r.Version >= 1 {
+		pe.putInt32(r.RebalanceTimeout)
+	}
 	if err := pe.putString(r.MemberId); err != nil {
 		return err
 	}
@@ -76,6 +81,8 @@ func (r *JoinGroupRequest) encode(pe packetEncoder) error {
 }
 
 func (r *JoinGroupRequest) decode(pd packetDecoder, version int16) (err error) {
+	r.Version = version
+
 	if r.GroupId, err = pd.getString(); err != nil {
 		return
 	}
@@ -84,6 +91,12 @@ func (r *JoinGroupRequest) decode(pd packetDecoder, version int16) (err error) {
 		return
 	}
 
+	if version >= 1 {
+		if r.RebalanceTimeout, err = pd.getInt32(); err != nil {
+			return err
+		}
+	}
+
 	if r.MemberId, err = pd.getString(); err != nil {
 		return
 	}
@@ -118,11 +131,18 @@ func (r *JoinGroupRequest) key() int16 {
 }
 
 func (r *JoinGroupRequest) version() int16 {
-	return 0
+	return r.Version
 }
 
 func (r *JoinGroupRequest) requiredVersion() KafkaVersion {
-	return V0_9_0_0
+	switch r.Version {
+	case 2:
+		return V0_11_0_0
+	case 1:
+		return V0_10_1_0
+	default:
+		return V0_9_0_0
+	}
 }
 
 func (r *JoinGroupRequest) AddGroupProtocol(name string, metadata []byte) {

+ 34 - 8
join_group_request_test.go

@@ -3,7 +3,7 @@ package sarama
 import "testing"
 
 var (
-	joinGroupRequestNoProtocols = []byte{
+	joinGroupRequestV0_NoProtocols = []byte{
 		0, 9, 'T', 'e', 's', 't', 'G', 'r', 'o', 'u', 'p', // Group ID
 		0, 0, 0, 100, // Session timeout
 		0, 0, // Member ID
@@ -11,7 +11,7 @@ var (
 		0, 0, 0, 0, // 0 protocol groups
 	}
 
-	joinGroupRequestOneProtocol = []byte{
+	joinGroupRequestV0_OneProtocol = []byte{
 		0, 9, 'T', 'e', 's', 't', 'G', 'r', 'o', 'u', 'p', // Group ID
 		0, 0, 0, 100, // Session timeout
 		0, 11, 'O', 'n', 'e', 'P', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Member ID
@@ -20,6 +20,17 @@ var (
 		0, 3, 'o', 'n', 'e', // Protocol name
 		0, 0, 0, 3, 0x01, 0x02, 0x03, // protocol metadata
 	}
+
+	joinGroupRequestV1 = []byte{
+		0, 9, 'T', 'e', 's', 't', 'G', 'r', 'o', 'u', 'p', // Group ID
+		0, 0, 0, 100, // Session timeout
+		0, 0, 0, 200, // Rebalance timeout
+		0, 11, 'O', 'n', 'e', 'P', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Member ID
+		0, 8, 'c', 'o', 'n', 's', 'u', 'm', 'e', 'r', // Protocol Type
+		0, 0, 0, 1, // 1 group protocol
+		0, 3, 'o', 'n', 'e', // Protocol name
+		0, 0, 0, 3, 0x01, 0x02, 0x03, // protocol metadata
+	}
 )
 
 func TestJoinGroupRequest(t *testing.T) {
@@ -27,20 +38,20 @@ func TestJoinGroupRequest(t *testing.T) {
 	request.GroupId = "TestGroup"
 	request.SessionTimeout = 100
 	request.ProtocolType = "consumer"
-	testRequest(t, "no protocols", request, joinGroupRequestNoProtocols)
+	testRequest(t, "V0: no protocols", request, joinGroupRequestV0_NoProtocols)
 }
 
-func TestJoinGroupRequestOneProtocol(t *testing.T) {
+func TestJoinGroupRequestV0_OneProtocol(t *testing.T) {
 	request := new(JoinGroupRequest)
 	request.GroupId = "TestGroup"
 	request.SessionTimeout = 100
 	request.MemberId = "OneProtocol"
 	request.ProtocolType = "consumer"
 	request.AddGroupProtocol("one", []byte{0x01, 0x02, 0x03})
-	packet := testRequestEncode(t, "one protocol", request, joinGroupRequestOneProtocol)
+	packet := testRequestEncode(t, "V0: one protocol", request, joinGroupRequestV0_OneProtocol)
 	request.GroupProtocols = make(map[string][]byte)
 	request.GroupProtocols["one"] = []byte{0x01, 0x02, 0x03}
-	testRequestDecode(t, "one protocol", request, packet)
+	testRequestDecode(t, "V0: one protocol", request, packet)
 }
 
 func TestJoinGroupRequestDeprecatedEncode(t *testing.T) {
@@ -51,7 +62,22 @@ func TestJoinGroupRequestDeprecatedEncode(t *testing.T) {
 	request.ProtocolType = "consumer"
 	request.GroupProtocols = make(map[string][]byte)
 	request.GroupProtocols["one"] = []byte{0x01, 0x02, 0x03}
-	packet := testRequestEncode(t, "one protocol", request, joinGroupRequestOneProtocol)
+	packet := testRequestEncode(t, "V0: one protocol", request, joinGroupRequestV0_OneProtocol)
 	request.AddGroupProtocol("one", []byte{0x01, 0x02, 0x03})
-	testRequestDecode(t, "one protocol", request, packet)
+	testRequestDecode(t, "V0: one protocol", request, packet)
+}
+
+func TestJoinGroupRequestV1(t *testing.T) {
+	request := new(JoinGroupRequest)
+	request.Version = 1
+	request.GroupId = "TestGroup"
+	request.SessionTimeout = 100
+	request.RebalanceTimeout = 200
+	request.MemberId = "OneProtocol"
+	request.ProtocolType = "consumer"
+	request.AddGroupProtocol("one", []byte{0x01, 0x02, 0x03})
+	packet := testRequestEncode(t, "V1", request, joinGroupRequestV1)
+	request.GroupProtocols = make(map[string][]byte)
+	request.GroupProtocols["one"] = []byte{0x01, 0x02, 0x03}
+	testRequestDecode(t, "V1", request, packet)
 }

+ 22 - 2
join_group_response.go

@@ -1,6 +1,8 @@
 package sarama
 
 type JoinGroupResponse struct {
+	Version       int16
+	ThrottleTime  int32
 	Err           KError
 	GenerationId  int32
 	GroupProtocol string
@@ -22,6 +24,9 @@ func (r *JoinGroupResponse) GetMembers() (map[string]ConsumerGroupMemberMetadata
 }
 
 func (r *JoinGroupResponse) encode(pe packetEncoder) error {
+	if r.Version >= 2 {
+		pe.putInt32(r.ThrottleTime)
+	}
 	pe.putInt16(int16(r.Err))
 	pe.putInt32(r.GenerationId)
 
@@ -53,6 +58,14 @@ func (r *JoinGroupResponse) encode(pe packetEncoder) error {
 }
 
 func (r *JoinGroupResponse) decode(pd packetDecoder, version int16) (err error) {
+	r.Version = version
+
+	if version >= 2 {
+		if r.ThrottleTime, err = pd.getInt32(); err != nil {
+			return
+		}
+	}
+
 	kerr, err := pd.getInt16()
 	if err != nil {
 		return err
@@ -107,9 +120,16 @@ func (r *JoinGroupResponse) key() int16 {
 }
 
 func (r *JoinGroupResponse) version() int16 {
-	return 0
+	return r.Version
 }
 
 func (r *JoinGroupResponse) requiredVersion() KafkaVersion {
-	return V0_9_0_0
+	switch r.Version {
+	case 2:
+		return V0_11_0_0
+	case 1:
+		return V0_10_1_0
+	default:
+		return V0_9_0_0
+	}
 }

+ 81 - 7
join_group_response_test.go

@@ -6,7 +6,7 @@ import (
 )
 
 var (
-	joinGroupResponseNoError = []byte{
+	joinGroupResponseV0_NoError = []byte{
 		0x00, 0x00, // No error
 		0x00, 0x01, 0x02, 0x03, // Generation ID
 		0, 8, 'p', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Protocol name chosen
@@ -15,7 +15,7 @@ var (
 		0, 0, 0, 0, // No member info
 	}
 
-	joinGroupResponseWithError = []byte{
+	joinGroupResponseV0_WithError = []byte{
 		0, 23, // Error: inconsistent group protocol
 		0x00, 0x00, 0x00, 0x00, // Generation ID
 		0, 0, // Protocol name chosen
@@ -24,7 +24,7 @@ var (
 		0, 0, 0, 0, // No member info
 	}
 
-	joinGroupResponseLeader = []byte{
+	joinGroupResponseV0_Leader = []byte{
 		0x00, 0x00, // No error
 		0x00, 0x01, 0x02, 0x03, // Generation ID
 		0, 8, 'p', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Protocol name chosen
@@ -34,13 +34,32 @@ var (
 		0, 3, 'f', 'o', 'o', // Member ID
 		0, 0, 0, 3, 0x01, 0x02, 0x03, // Member metadata
 	}
+
+	joinGroupResponseV1 = []byte{
+		0x00, 0x00, // No error
+		0x00, 0x01, 0x02, 0x03, // Generation ID
+		0, 8, 'p', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Protocol name chosen
+		0, 3, 'f', 'o', 'o', // Leader ID
+		0, 3, 'b', 'a', 'r', // Member ID
+		0, 0, 0, 0, // No member info
+	}
+
+	joinGroupResponseV2 = []byte{
+		0, 0, 0, 100,
+		0x00, 0x00, // No error
+		0x00, 0x01, 0x02, 0x03, // Generation ID
+		0, 8, 'p', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Protocol name chosen
+		0, 3, 'f', 'o', 'o', // Leader ID
+		0, 3, 'b', 'a', 'r', // Member ID
+		0, 0, 0, 0, // No member info
+	}
 )
 
-func TestJoinGroupResponse(t *testing.T) {
+func TestJoinGroupResponseV0(t *testing.T) {
 	var response *JoinGroupResponse
 
 	response = new(JoinGroupResponse)
-	testVersionDecodable(t, "no error", response, joinGroupResponseNoError, 0)
+	testVersionDecodable(t, "no error", response, joinGroupResponseV0_NoError, 0)
 	if response.Err != ErrNoError {
 		t.Error("Decoding Err failed: no error expected but found", response.Err)
 	}
@@ -58,7 +77,7 @@ func TestJoinGroupResponse(t *testing.T) {
 	}
 
 	response = new(JoinGroupResponse)
-	testVersionDecodable(t, "with error", response, joinGroupResponseWithError, 0)
+	testVersionDecodable(t, "with error", response, joinGroupResponseV0_WithError, 0)
 	if response.Err != ErrInconsistentGroupProtocol {
 		t.Error("Decoding Err failed: ErrInconsistentGroupProtocol expected but found", response.Err)
 	}
@@ -76,7 +95,7 @@ func TestJoinGroupResponse(t *testing.T) {
 	}
 
 	response = new(JoinGroupResponse)
-	testVersionDecodable(t, "with error", response, joinGroupResponseLeader, 0)
+	testVersionDecodable(t, "with error", response, joinGroupResponseV0_Leader, 0)
 	if response.Err != ErrNoError {
 		t.Error("Decoding Err failed: ErrNoError expected but found", response.Err)
 	}
@@ -96,3 +115,58 @@ func TestJoinGroupResponse(t *testing.T) {
 		t.Error("Decoding foo member failed, found:", response.Members["foo"])
 	}
 }
+
+func TestJoinGroupResponseV1(t *testing.T) {
+	response := new(JoinGroupResponse)
+	testVersionDecodable(t, "no error", response, joinGroupResponseV1, 1)
+	if response.Err != ErrNoError {
+		t.Error("Decoding Err failed: no error expected but found", response.Err)
+	}
+	if response.GenerationId != 66051 {
+		t.Error("Decoding GenerationId failed, found:", response.GenerationId)
+	}
+	if response.GroupProtocol != "protocol" {
+		t.Error("Decoding GroupProtocol failed, found:", response.GroupProtocol)
+	}
+	if response.LeaderId != "foo" {
+		t.Error("Decoding LeaderId failed, found:", response.LeaderId)
+	}
+	if response.MemberId != "bar" {
+		t.Error("Decoding MemberId failed, found:", response.MemberId)
+	}
+	if response.Version != 1 {
+		t.Error("Decoding Version failed, found:", response.Version)
+	}
+	if len(response.Members) != 0 {
+		t.Error("Decoding Members failed, found:", response.Members)
+	}
+}
+
+func TestJoinGroupResponseV2(t *testing.T) {
+	response := new(JoinGroupResponse)
+	testVersionDecodable(t, "no error", response, joinGroupResponseV2, 2)
+	if response.ThrottleTime != 100 {
+		t.Error("Decoding ThrottleTime failed, found:", response.ThrottleTime)
+	}
+	if response.Err != ErrNoError {
+		t.Error("Decoding Err failed: no error expected but found", response.Err)
+	}
+	if response.GenerationId != 66051 {
+		t.Error("Decoding GenerationId failed, found:", response.GenerationId)
+	}
+	if response.GroupProtocol != "protocol" {
+		t.Error("Decoding GroupProtocol failed, found:", response.GroupProtocol)
+	}
+	if response.LeaderId != "foo" {
+		t.Error("Decoding LeaderId failed, found:", response.LeaderId)
+	}
+	if response.MemberId != "bar" {
+		t.Error("Decoding MemberId failed, found:", response.MemberId)
+	}
+	if response.Version != 2 {
+		t.Error("Decoding Version failed, found:", response.Version)
+	}
+	if len(response.Members) != 0 {
+		t.Error("Decoding Members failed, found:", response.Members)
+	}
+}