Browse Source

add create partitions

Robin 8 years ago
parent
commit
b3f149d4d2

+ 119 - 0
create_partitions_request.go

@@ -0,0 +1,119 @@
+package sarama
+
+import "time"
+
+type CreatePartitionsRequest struct {
+	TopicPartitions map[string]*TopicPartition
+	Timeout         time.Duration
+	ValidateOnly    bool
+}
+
+func (c *CreatePartitionsRequest) encode(pe packetEncoder) error {
+	if err := pe.putArrayLength(len(c.TopicPartitions)); err != nil {
+		return err
+	}
+
+	for topic, partition := range c.TopicPartitions {
+		if err := pe.putString(topic); err != nil {
+			return err
+		}
+		if err := partition.encode(pe); err != nil {
+			return err
+		}
+	}
+
+	pe.putInt32(int32(c.Timeout / time.Millisecond))
+
+	pe.putBool(c.ValidateOnly)
+
+	return nil
+}
+
+func (c *CreatePartitionsRequest) decode(pd packetDecoder, version int16) (err error) {
+	n, err := pd.getArrayLength()
+	if err != nil {
+		return err
+	}
+	c.TopicPartitions = make(map[string]*TopicPartition, n)
+	for i := 0; i < n; i++ {
+		topic, err := pd.getString()
+		if err != nil {
+			return err
+		}
+		c.TopicPartitions[topic] = new(TopicPartition)
+		if err := c.TopicPartitions[topic].decode(pd, version); err != nil {
+			return err
+		}
+	}
+
+	timeout, err := pd.getInt32()
+	if err != nil {
+		return err
+	}
+	c.Timeout = time.Duration(timeout) * time.Millisecond
+
+	if c.ValidateOnly, err = pd.getBool(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (r *CreatePartitionsRequest) key() int16 {
+	return 37
+}
+
+func (r *CreatePartitionsRequest) version() int16 {
+	return 0
+}
+
+func (r *CreatePartitionsRequest) requiredVersion() KafkaVersion {
+	return V1_0_0_0
+}
+
+type TopicPartition struct {
+	Count      int32
+	Assignment [][]int32
+}
+
+func (t *TopicPartition) encode(pe packetEncoder) error {
+	pe.putInt32(t.Count)
+
+	if len(t.Assignment) == 0 {
+		pe.putInt32(-1)
+		return nil
+	}
+
+	pe.putInt32(int32(len(t.Assignment)))
+
+	for _, assign := range t.Assignment {
+		if err := pe.putInt32Array(assign); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (t *TopicPartition) decode(pd packetDecoder, version int16) (err error) {
+	if t.Count, err = pd.getInt32(); err != nil {
+		return err
+	}
+
+	n, err := pd.getInt32()
+	if err != nil {
+		return err
+	}
+	if n <= 0 {
+		return nil
+	}
+	t.Assignment = make([][]int32, n)
+
+	for i := 0; i < int(n); i++ {
+		if t.Assignment[i], err = pd.getInt32Array(); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}

+ 50 - 0
create_partitions_request_test.go

@@ -0,0 +1,50 @@
+package sarama
+
+import (
+	"testing"
+	"time"
+)
+
+var (
+	createPartitionRequestNoAssignment = []byte{
+		0, 0, 0, 1, // one topic
+		0, 5, 't', 'o', 'p', 'i', 'c',
+		0, 0, 0, 3, // 3 partitions
+		255, 255, 255, 255, // no assignments
+		0, 0, 0, 100, // timeout
+		0, // validate only = false
+	}
+
+	createPartitionRequestAssignment = []byte{
+		0, 0, 0, 1,
+		0, 5, 't', 'o', 'p', 'i', 'c',
+		0, 0, 0, 3, // 3 partitions
+		0, 0, 0, 2,
+		0, 0, 0, 2,
+		0, 0, 0, 2, 0, 0, 0, 3,
+		0, 0, 0, 2,
+		0, 0, 0, 3, 0, 0, 0, 1,
+		0, 0, 0, 100,
+		1, // validate only = true
+	}
+)
+
+func TestCreatePartitionsRequest(t *testing.T) {
+	req := &CreatePartitionsRequest{
+		TopicPartitions: map[string]*TopicPartition{
+			"topic": &TopicPartition{
+				Count: 3,
+			},
+		},
+		Timeout: 100 * time.Millisecond,
+	}
+
+	buf := testRequestEncode(t, "no assignment", req, createPartitionRequestNoAssignment)
+	testRequestDecode(t, "no assignment", req, buf)
+
+	req.ValidateOnly = true
+	req.TopicPartitions["topic"].Assignment = [][]int32{{2, 3}, {3, 1}}
+
+	buf = testRequestEncode(t, "assignment", req, createPartitionRequestAssignment)
+	testRequestDecode(t, "assignment", req, buf)
+}

+ 94 - 0
create_partitions_response.go

@@ -0,0 +1,94 @@
+package sarama
+
+import "time"
+
+type CreatePartitionsResponse struct {
+	ThrottleTime         time.Duration
+	TopicPartitionErrors map[string]*TopicPartitionError
+}
+
+func (c *CreatePartitionsResponse) encode(pe packetEncoder) error {
+	pe.putInt32(int32(c.ThrottleTime / time.Millisecond))
+	if err := pe.putArrayLength(len(c.TopicPartitionErrors)); err != nil {
+		return err
+	}
+
+	for topic, partitionError := range c.TopicPartitionErrors {
+		if err := pe.putString(topic); err != nil {
+			return err
+		}
+		if err := partitionError.encode(pe); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (c *CreatePartitionsResponse) decode(pd packetDecoder, version int16) (err error) {
+	throttleTime, err := pd.getInt32()
+	if err != nil {
+		return err
+	}
+	c.ThrottleTime = time.Duration(throttleTime) * time.Millisecond
+
+	n, err := pd.getArrayLength()
+	if err != nil {
+		return err
+	}
+
+	c.TopicPartitionErrors = make(map[string]*TopicPartitionError, n)
+	for i := 0; i < n; i++ {
+		topic, err := pd.getString()
+		if err != nil {
+			return err
+		}
+		c.TopicPartitionErrors[topic] = new(TopicPartitionError)
+		if err := c.TopicPartitionErrors[topic].decode(pd, version); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (r *CreatePartitionsResponse) key() int16 {
+	return 37
+}
+
+func (r *CreatePartitionsResponse) version() int16 {
+	return 0
+}
+
+func (r *CreatePartitionsResponse) requiredVersion() KafkaVersion {
+	return V1_0_0_0
+}
+
+type TopicPartitionError struct {
+	Err    KError
+	ErrMsg *string
+}
+
+func (t *TopicPartitionError) encode(pe packetEncoder) error {
+	pe.putInt16(int16(t.Err))
+
+	if err := pe.putNullableString(t.ErrMsg); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (t *TopicPartitionError) decode(pd packetDecoder, version int16) (err error) {
+	kerr, err := pd.getInt16()
+	if err != nil {
+		return err
+	}
+	t.Err = KError(kerr)
+
+	if t.ErrMsg, err = pd.getNullableString(); err != nil {
+		return err
+	}
+
+	return nil
+}

+ 52 - 0
create_partitions_response_test.go

@@ -0,0 +1,52 @@
+package sarama
+
+import (
+	"reflect"
+	"testing"
+	"time"
+)
+
+var (
+	createPartitionResponseSuccess = []byte{
+		0, 0, 0, 100, // throttleTimeMs
+		0, 0, 0, 1,
+		0, 5, 't', 'o', 'p', 'i', 'c',
+		0, 0, // no error
+		255, 255, // no error message
+	}
+
+	createPartitionResponseFail = []byte{
+		0, 0, 0, 100, // throttleTimeMs
+		0, 0, 0, 1,
+		0, 5, 't', 'o', 'p', 'i', 'c',
+		0, 37, // partition error
+		0, 5, 'e', 'r', 'r', 'o', 'r',
+	}
+)
+
+func TestCreatePartitionsResponse(t *testing.T) {
+	resp := &CreatePartitionsResponse{
+		ThrottleTime: 100 * time.Millisecond,
+		TopicPartitionErrors: map[string]*TopicPartitionError{
+			"topic": &TopicPartitionError{},
+		},
+	}
+
+	testResponse(t, "success", resp, createPartitionResponseSuccess)
+	decodedresp := new(CreatePartitionsResponse)
+	testVersionDecodable(t, "success", decodedresp, createPartitionResponseSuccess, 0)
+	if !reflect.DeepEqual(decodedresp, resp) {
+		t.Errorf("Decoding error: expected %v but got %v", decodedresp, resp)
+	}
+
+	errMsg := "error"
+	resp.TopicPartitionErrors["topic"].Err = ErrInvalidPartitions
+	resp.TopicPartitionErrors["topic"].ErrMsg = &errMsg
+
+	testResponse(t, "with errors", resp, createPartitionResponseFail)
+	decodedresp = new(CreatePartitionsResponse)
+	testVersionDecodable(t, "with errors", decodedresp, createPartitionResponseFail, 0)
+	if !reflect.DeepEqual(decodedresp, resp) {
+		t.Errorf("Decoding error: expected %v but got %v", decodedresp, resp)
+	}
+}

+ 1 - 0
packet_decoder.go

@@ -11,6 +11,7 @@ type packetDecoder interface {
 	getInt64() (int64, error)
 	getVarint() (int64, error)
 	getArrayLength() (int, error)
+	getBool() (bool, error)
 
 	// Collections
 	getBytes() ([]byte, error)

+ 1 - 0
packet_encoder.go

@@ -13,6 +13,7 @@ type packetEncoder interface {
 	putInt64(in int64)
 	putVarint(in int64)
 	putArrayLength(in int) error
+	putBool(in bool)
 
 	// Collections
 	putBytes(in []byte) error

+ 4 - 0
prep_encoder.go

@@ -44,6 +44,10 @@ func (pe *prepEncoder) putArrayLength(in int) error {
 	return nil
 }
 
+func (pe *prepEncoder) putBool(in bool) {
+	pe.length++
+}
+
 // arrays
 
 func (pe *prepEncoder) putBytes(in []byte) error {

+ 14 - 3
real_decoder.go

@@ -11,6 +11,7 @@ var errInvalidByteSliceLengthType = PacketDecodingError{"invalid byteslice lengt
 var errInvalidStringLength = PacketDecodingError{"invalid string length"}
 var errInvalidSubsetSize = PacketDecodingError{"invalid subset size"}
 var errVarintOverflow = PacketDecodingError{"varint overflow"}
+var errInvalidBool = PacketDecodingError{"invalid bool"}
 
 type realDecoder struct {
 	raw   []byte
@@ -90,6 +91,17 @@ func (rd *realDecoder) getArrayLength() (int, error) {
 	return tmp, nil
 }
 
+func (rd *realDecoder) getBool() (bool, error) {
+	b, err := rd.getInt8()
+	if err != nil || b == 0 {
+		return false, err
+	}
+	if b != 1 {
+		return false, errInvalidBool
+	}
+	return true, nil
+}
+
 // collections
 
 func (rd *realDecoder) getBytes() ([]byte, error) {
@@ -143,11 +155,10 @@ func (rd *realDecoder) getString() (string, error) {
 }
 
 func (rd *realDecoder) getNullableString() (*string, error) {
-	tmp, err := rd.getInt16()
-	if err != nil || tmp == -1 {
+	str, err := rd.getString()
+	if err != nil || str == "" {
 		return nil, err
 	}
-	str, err := rd.getString()
 	return &str, err
 }
 

+ 8 - 0
real_encoder.go

@@ -44,6 +44,14 @@ func (re *realEncoder) putArrayLength(in int) error {
 	return nil
 }
 
+func (re *realEncoder) putBool(in bool) {
+	if in {
+		re.putInt8(1)
+		return
+	}
+	re.putInt8(0)
+}
+
 // collection
 
 func (re *realEncoder) putRawBytes(in []byte) error {

+ 2 - 0
request.go

@@ -114,6 +114,8 @@ func allocateBody(key, version int16) protocolBody {
 		return &SaslHandshakeRequest{}
 	case 18:
 		return &ApiVersionsRequest{}
+	case 37:
+		return &CreatePartitionsRequest{}
 	}
 	return nil
 }