Bläddra i källkod

checkpoint wip update protocol for new encoding api

Evan Huus 12 år sedan
förälder
incheckning
a53b1b1736

+ 2 - 1
encoding/crc32_field.go

@@ -18,9 +18,10 @@ func (c *CRC32Field) ReserveLength() int {
 	return 4
 }
 
-func (c *CRC32Field) Run(curOffset int, buf []byte) {
+func (c *CRC32Field) Run(curOffset int, buf []byte) error {
 	crc := crc32.ChecksumIEEE(buf[c.startOffset+4 : curOffset])
 	binary.BigEndian.PutUint32(buf[c.startOffset:], crc)
+	return nil
 }
 
 func (c *CRC32Field) Check(curOffset int, buf []byte) error {

+ 2 - 1
encoding/length_field.go

@@ -15,8 +15,9 @@ func (l *LengthField) ReserveLength() int {
 	return 4
 }
 
-func (l *LengthField) Run(curOffset int, buf []byte) {
+func (l *LengthField) Run(curOffset int, buf []byte) error {
 	binary.BigEndian.PutUint32(buf[l.startOffset:], uint32(curOffset-l.startOffset-4))
+	return nil
 }
 
 func (l *LengthField) Check(curOffset int, buf []byte) error {

+ 3 - 0
encoding/packet_decoder.go

@@ -16,6 +16,9 @@ type PacketDecoder interface {
 	GetString() (string, error)
 	GetInt32Array() ([]int32, error)
 	GetInt64Array() ([]int64, error)
+
+	// Subsets
+	Remaining() int
 	GetSubset(length int) (PacketDecoder, error)
 
 	// Stacks, see PushDecoder

+ 6 - 5
encoding/packet_encoder.go

@@ -5,10 +5,11 @@ package encoding
 // not about how a string is represented in Kafka.
 type PacketEncoder interface {
 	// Primitives
-	PutInt8(in int8) error
-	PutInt16(in int16) error
-	PutInt32(in int32) error
-	PutInt64(in int64) error
+	PutInt8(in int8)
+	PutInt16(in int16)
+	PutInt32(in int32)
+	PutInt64(in int64)
+	PutArrayLength(in int) error
 
 	// Collections
 	PutBytes(in []byte) error
@@ -16,7 +17,7 @@ type PacketEncoder interface {
 	PutInt32Array(in []int32) error
 
 	// Stacks, see PushEncoder
-	Push(in PushEncoder) error
+	Push(in PushEncoder)
 	Pop() error
 }
 

+ 16 - 10
encoding/prep_encoder.go

@@ -8,23 +8,27 @@ type prepEncoder struct {
 
 // primitives
 
-func (pe *prepEncoder) PutInt8(in int8) error {
+func (pe *prepEncoder) PutInt8(in int8) {
 	pe.length += 1
-	return nil
 }
 
-func (pe *prepEncoder) PutInt16(in int16) error {
+func (pe *prepEncoder) PutInt16(in int16) {
 	pe.length += 2
-	return nil
 }
 
-func (pe *prepEncoder) PutInt32(in int32) error {
+func (pe *prepEncoder) PutInt32(in int32) {
 	pe.length += 4
-	return nil
 }
 
-func (pe *prepEncoder) PutInt64(in int64) error {
+func (pe *prepEncoder) PutInt64(in int64) {
 	pe.length += 8
+}
+
+func (pe *prepEncoder) PutArrayLength(in int) error {
+	if in > math.MaxInt32 {
+		return EncodingError
+	}
+	pe.length += 4
 	return nil
 }
 
@@ -52,16 +56,18 @@ func (pe *prepEncoder) PutString(in string) error {
 }
 
 func (pe *prepEncoder) PutInt32Array(in []int32) error {
-	pe.length += 4
+	err := pe.PutArrayLength(len(in))
+	if err != nil {
+		return err
+	}
 	pe.length += 4 * len(in)
 	return nil
 }
 
 // stackable
 
-func (pe *prepEncoder) Push(in PushEncoder) error {
+func (pe *prepEncoder) Push(in PushEncoder) {
 	pe.length += in.ReserveLength()
-	return nil
 }
 
 func (pe *prepEncoder) Pop() error {

+ 20 - 18
encoding/real_decoder.go

@@ -11,14 +11,10 @@ type realDecoder struct {
 	stack []PushDecoder
 }
 
-func (rd *realDecoder) remaining() int {
-	return len(rd.raw) - rd.off
-}
-
 // primitives
 
 func (rd *realDecoder) GetInt8() (int8, error) {
-	if rd.remaining() < 1 {
+	if rd.Remaining() < 1 {
 		return -1, InsufficientData
 	}
 	tmp := int8(rd.raw[rd.off])
@@ -27,7 +23,7 @@ func (rd *realDecoder) GetInt8() (int8, error) {
 }
 
 func (rd *realDecoder) GetInt16() (int16, error) {
-	if rd.remaining() < 2 {
+	if rd.Remaining() < 2 {
 		return -1, InsufficientData
 	}
 	tmp := int16(binary.BigEndian.Uint16(rd.raw[rd.off:]))
@@ -36,7 +32,7 @@ func (rd *realDecoder) GetInt16() (int16, error) {
 }
 
 func (rd *realDecoder) GetInt32() (int32, error) {
-	if rd.remaining() < 4 {
+	if rd.Remaining() < 4 {
 		return -1, InsufficientData
 	}
 	tmp := int32(binary.BigEndian.Uint32(rd.raw[rd.off:]))
@@ -45,7 +41,7 @@ func (rd *realDecoder) GetInt32() (int32, error) {
 }
 
 func (rd *realDecoder) GetInt64() (int64, error) {
-	if rd.remaining() < 8 {
+	if rd.Remaining() < 8 {
 		return -1, InsufficientData
 	}
 	tmp := int64(binary.BigEndian.Uint64(rd.raw[rd.off:]))
@@ -54,12 +50,12 @@ func (rd *realDecoder) GetInt64() (int64, error) {
 }
 
 func (rd *realDecoder) GetArrayLength() (int, error) {
-	if rd.remaining() < 4 {
+	if rd.Remaining() < 4 {
 		return -1, InsufficientData
 	}
 	tmp := int(binary.BigEndian.Uint32(rd.raw[rd.off:]))
 	rd.off += 4
-	if tmp > rd.remaining() {
+	if tmp > rd.Remaining() {
 		return -1, InsufficientData
 	} else if tmp > 2*math.MaxUint16 {
 		return -1, DecodingError
@@ -85,7 +81,7 @@ func (rd *realDecoder) GetBytes() ([]byte, error) {
 		return nil, nil
 	case n == 0:
 		return make([]byte, 0), nil
-	case n > rd.remaining():
+	case n > rd.Remaining():
 		return nil, InsufficientData
 	default:
 		tmp := rd.raw[rd.off : rd.off+n]
@@ -110,7 +106,7 @@ func (rd *realDecoder) GetString() (string, error) {
 		return "", nil
 	case n == 0:
 		return "", nil
-	case n > rd.remaining():
+	case n > rd.Remaining():
 		return "", InsufficientData
 	default:
 		tmp := string(rd.raw[rd.off : rd.off+n])
@@ -120,14 +116,14 @@ func (rd *realDecoder) GetString() (string, error) {
 }
 
 func (rd *realDecoder) GetInt32Array() ([]int32, error) {
-	if rd.remaining() < 4 {
+	if rd.Remaining() < 4 {
 		return nil, InsufficientData
 	}
 	n := int(binary.BigEndian.Uint32(rd.raw[rd.off:]))
 	rd.off += 4
 
 	var ret []int32 = nil
-	if rd.remaining() < 4*n {
+	if rd.Remaining() < 4*n {
 		return nil, InsufficientData
 	} else if n > 0 {
 		ret = make([]int32, n)
@@ -140,14 +136,14 @@ func (rd *realDecoder) GetInt32Array() ([]int32, error) {
 }
 
 func (rd *realDecoder) GetInt64Array() ([]int64, error) {
-	if rd.remaining() < 4 {
+	if rd.Remaining() < 4 {
 		return nil, InsufficientData
 	}
 	n := int(binary.BigEndian.Uint32(rd.raw[rd.off:]))
 	rd.off += 4
 
 	var ret []int64 = nil
-	if rd.remaining() < 8*n {
+	if rd.Remaining() < 8*n {
 		return nil, InsufficientData
 	} else if n > 0 {
 		ret = make([]int64, n)
@@ -159,8 +155,14 @@ func (rd *realDecoder) GetInt64Array() ([]int64, error) {
 	return ret, nil
 }
 
+// subsets
+
+func (rd *realDecoder) Remaining() int {
+	return len(rd.raw) - rd.off
+}
+
 func (rd *realDecoder) GetSubset(length int) (PacketDecoder, error) {
-	if length > rd.remaining() {
+	if length > rd.Remaining() {
 		return nil, InsufficientData
 	}
 
@@ -173,7 +175,7 @@ func (rd *realDecoder) Push(in PushDecoder) error {
 	in.SaveOffset(rd.off)
 
 	reserve := in.ReserveLength()
-	if rd.remaining() < reserve {
+	if rd.Remaining() < reserve {
 		return DecodingError
 	}
 

+ 10 - 10
encoding/real_encoder.go

@@ -10,27 +10,28 @@ type realEncoder struct {
 
 // primitives
 
-func (re *realEncoder) PutInt8(in int8) error {
+func (re *realEncoder) PutInt8(in int8) {
 	re.raw[re.off] = byte(in)
 	re.off += 1
-	return nil
 }
 
-func (re *realEncoder) PutInt16(in int16) error {
+func (re *realEncoder) PutInt16(in int16) {
 	binary.BigEndian.PutUint16(re.raw[re.off:], uint16(in))
 	re.off += 2
-	return nil
 }
 
-func (re *realEncoder) PutInt32(in int32) error {
+func (re *realEncoder) PutInt32(in int32) {
 	binary.BigEndian.PutUint32(re.raw[re.off:], uint32(in))
 	re.off += 4
-	return nil
 }
 
-func (re *realEncoder) PutInt64(in int64) error {
+func (re *realEncoder) PutInt64(in int64) {
 	binary.BigEndian.PutUint64(re.raw[re.off:], uint64(in))
 	re.off += 8
+}
+
+func (re *realEncoder) PutArrayLength(in int) error {
+	re.PutInt32(int32(in))
 	return nil
 }
 
@@ -55,7 +56,7 @@ func (re *realEncoder) PutString(in string) error {
 }
 
 func (re *realEncoder) PutInt32Array(in []int32) error {
-	re.PutInt32(int32(len(in)))
+	re.PutArrayLength(len(in))
 	for _, val := range in {
 		re.PutInt32(val)
 	}
@@ -64,11 +65,10 @@ func (re *realEncoder) PutInt32Array(in []int32) error {
 
 // stacks
 
-func (re *realEncoder) Push(in PushEncoder) error {
+func (re *realEncoder) Push(in PushEncoder) {
 	in.SaveOffset(re.off)
 	re.off += in.ReserveLength()
 	re.stack = append(re.stack, in)
-	return nil
 }
 
 func (re *realEncoder) Pop() error {

+ 2 - 2
kafka/client.go

@@ -1,7 +1,7 @@
 /*
-Package kafka provides a high-level API for writing Kafka 0.8 clients.
+Package kafka provides a high-level API for writing Kafka clients.
 
-It is built strictly on sister package sarama/protocol.
+It is built on sister package sarama/protocol.
 */
 package kafka
 

+ 12 - 10
protocol/broker.go

@@ -1,5 +1,7 @@
 package protocol
 
+import enc "sarama/encoding"
+import "sarama/types"
 import (
 	"io"
 	"net"
@@ -129,7 +131,7 @@ func (b *Broker) Produce(clientID string, request *ProduceRequest) (*ProduceResp
 	var response *ProduceResponse
 	var err error
 
-	if request.RequiredAcks == NO_RESPONSE {
+	if request.RequiredAcks == types.NO_RESPONSE {
 		err = b.sendAndReceive(clientID, request, nil)
 	} else {
 		response = new(ProduceResponse)
@@ -188,7 +190,7 @@ func (b *Broker) send(clientID string, req requestEncoder, promiseResponse bool)
 	}
 
 	fullRequest := request{b.correlation_id, clientID, req}
-	buf, err := encode(&fullRequest)
+	buf, err := enc.Encode(&fullRequest)
 	if err != nil {
 		return nil, err
 	}
@@ -209,7 +211,7 @@ func (b *Broker) send(clientID string, req requestEncoder, promiseResponse bool)
 	return &promise, nil
 }
 
-func (b *Broker) sendAndReceive(clientID string, req requestEncoder, res decoder) error {
+func (b *Broker) sendAndReceive(clientID string, req requestEncoder, res enc.Decoder) error {
 	promise, err := b.send(clientID, req, res != nil)
 
 	if err != nil {
@@ -222,24 +224,24 @@ func (b *Broker) sendAndReceive(clientID string, req requestEncoder, res decoder
 
 	select {
 	case buf := <-promise.packets:
-		return decode(buf, res)
+		return enc.Decode(buf, res)
 	case err = <-promise.errors:
 		return err
 	}
 }
 
-func (b *Broker) decode(pd packetDecoder) (err error) {
-	b.id, err = pd.getInt32()
+func (b *Broker) Decode(pd enc.PacketDecoder) (err error) {
+	b.id, err = pd.GetInt32()
 	if err != nil {
 		return err
 	}
 
-	b.host, err = pd.getString()
+	b.host, err = pd.GetString()
 	if err != nil {
 		return err
 	}
 
-	b.port, err = pd.getInt32()
+	b.port, err = pd.GetInt32()
 	if err != nil {
 		return err
 	}
@@ -257,13 +259,13 @@ func (b *Broker) responseReceiver() {
 		}
 
 		decodedHeader := responseHeader{}
-		err = decode(header, &decodedHeader)
+		err = enc.Decode(header, &decodedHeader)
 		if err != nil {
 			response.errors <- err
 			continue
 		}
 		if decodedHeader.correlation_id != response.correlation_id {
-			response.errors <- DecodingError("Mismatched correlation id.")
+			response.errors <- enc.DecodingError
 			continue
 		}
 

+ 28 - 12
protocol/fetch_request.go

@@ -1,13 +1,16 @@
 package protocol
 
+import enc "sarama/encoding"
+
 type fetchRequestBlock struct {
 	fetchOffset int64
 	maxBytes    int32
 }
 
-func (f *fetchRequestBlock) encode(pe packetEncoder) {
-	pe.putInt64(f.fetchOffset)
-	pe.putInt32(f.maxBytes)
+func (f *fetchRequestBlock) Encode(pe enc.PacketEncoder) error {
+	pe.PutInt64(f.fetchOffset)
+	pe.PutInt32(f.maxBytes)
+	return nil
 }
 
 type FetchRequest struct {
@@ -16,19 +19,32 @@ type FetchRequest struct {
 	blocks      map[string]map[int32]*fetchRequestBlock
 }
 
-func (f *FetchRequest) encode(pe packetEncoder) {
-	pe.putInt32(-1) // replica ID is always -1 for clients
-	pe.putInt32(f.MaxWaitTime)
-	pe.putInt32(f.MinBytes)
-	pe.putArrayCount(len(f.blocks))
+func (f *FetchRequest) Encode(pe enc.PacketEncoder) (err error) {
+	pe.PutInt32(-1) // replica ID is always -1 for clients
+	pe.PutInt32(f.MaxWaitTime)
+	pe.PutInt32(f.MinBytes)
+	err = pe.PutArrayLength(len(f.blocks))
+	if err != nil {
+		return err
+	}
 	for topic, blocks := range f.blocks {
-		pe.putString(topic)
-		pe.putArrayCount(len(blocks))
+		err = pe.PutString(topic)
+		if err != nil {
+			return err
+		}
+		err = pe.PutArrayLength(len(blocks))
+		if err != nil {
+			return err
+		}
 		for partition, block := range blocks {
-			pe.putInt32(partition)
-			block.encode(pe)
+			pe.PutInt32(partition)
+			err = block.Encode(pe)
+			if err != nil {
+				return err
+			}
 		}
 	}
+	return nil
 }
 
 func (f *FetchRequest) key() int16 {

+ 13 - 11
protocol/fetch_response.go

@@ -1,32 +1,34 @@
 package protocol
 
+import enc "sarama/encoding"
+
 type FetchResponseBlock struct {
 	Err                 KError
 	HighWaterMarkOffset int64
 	MsgSet              MessageSet
 }
 
-func (pr *FetchResponseBlock) decode(pd packetDecoder) (err error) {
-	pr.Err, err = pd.getError()
+func (pr *FetchResponseBlock) Decode(pd enc.PacketDecoder) (err error) {
+	pr.Err, err = pd.GetError()
 	if err != nil {
 		return err
 	}
 
-	pr.HighWaterMarkOffset, err = pd.getInt64()
+	pr.HighWaterMarkOffset, err = pd.GetInt64()
 	if err != nil {
 		return err
 	}
 
-	msgSetSize, err := pd.getInt32()
+	msgSetSize, err := pd.GetInt32()
 	if err != nil {
 		return err
 	}
 
-	msgSetDecoder, err := pd.getSubset(int(msgSetSize))
+	msgSetDecoder, err := pd.GetSubset(int(msgSetSize))
 	if err != nil {
 		return err
 	}
-	err = (&pr.MsgSet).decode(msgSetDecoder)
+	err = (&pr.MsgSet).Decode(msgSetDecoder)
 
 	return err
 }
@@ -35,20 +37,20 @@ type FetchResponse struct {
 	Blocks map[string]map[int32]*FetchResponseBlock
 }
 
-func (fr *FetchResponse) decode(pd packetDecoder) (err error) {
-	numTopics, err := pd.getArrayCount()
+func (fr *FetchResponse) Decode(pd enc.PacketDecoder) (err error) {
+	numTopics, err := pd.GetArrayLength()
 	if err != nil {
 		return err
 	}
 
 	fr.Blocks = make(map[string]map[int32]*FetchResponseBlock, numTopics)
 	for i := 0; i < numTopics; i++ {
-		name, err := pd.getString()
+		name, err := pd.GetString()
 		if err != nil {
 			return err
 		}
 
-		numBlocks, err := pd.getArrayCount()
+		numBlocks, err := pd.GetArrayLength()
 		if err != nil {
 			return err
 		}
@@ -56,7 +58,7 @@ func (fr *FetchResponse) decode(pd packetDecoder) (err error) {
 		fr.Blocks[name] = make(map[int32]*FetchResponseBlock, numBlocks)
 
 		for j := 0; j < numBlocks; j++ {
-			id, err := pd.getInt32()
+			id, err := pd.GetInt32()
 			if err != nil {
 				return err
 			}

+ 32 - 25
protocol/message.go

@@ -1,5 +1,6 @@
 package protocol
 
+import enc "sarama/encoding"
 import (
 	"bytes"
 	"compress/gzip"
@@ -13,26 +14,29 @@ const message_format int8 = 0
 
 type Message struct {
 	Codec types.CompressionCodec // codec used to compress the message contents
-	Key   []byte           // the message key, may be nil
-	Value []byte           // the message contents
+	Key   []byte                 // the message key, may be nil
+	Value []byte                 // the message contents
 }
 
-func (m *Message) encode(pe packetEncoder) {
-	pe.pushCRC32()
+func (m *Message) Encode(pe enc.PacketEncoder) error {
+	pe.Push(&enc.CRC32Field{})
 
-	pe.putInt8(message_format)
+	pe.PutInt8(message_format)
 
 	var attributes int8 = 0
 	attributes |= m.Codec & 0x07
-	pe.putInt8(attributes)
+	pe.PutInt8(attributes)
 
-	pe.putBytes(m.Key)
+	err := pe.PutBytes(m.Key)
+	if err != nil {
+		return err
+	}
 
 	var body []byte
 	switch m.Codec {
-	case COMPRESSION_NONE:
+	case types.COMPRESSION_NONE:
 		body = m.Value
-	case COMPRESSION_GZIP:
+	case types.COMPRESSION_GZIP:
 		if m.Value != nil {
 			var buf bytes.Buffer
 			writer := gzip.NewWriter(&buf)
@@ -40,50 +44,53 @@ func (m *Message) encode(pe packetEncoder) {
 			writer.Close()
 			body = buf.Bytes()
 		}
-	case COMPRESSION_SNAPPY:
+	case types.COMPRESSION_SNAPPY:
 		// TODO
 	}
-	pe.putBytes(body)
+	err = pe.PutBytes(body)
+	if err != nil {
+		return err
+	}
 
-	pe.pop()
+	return pe.Pop()
 }
 
-func (m *Message) decode(pd packetDecoder) (err error) {
-	err = pd.pushCRC32()
+func (m *Message) Decode(pd enc.PacketDecoder) (err error) {
+	err = pd.Push(&CRC32Field{})
 	if err != nil {
 		return err
 	}
 
-	format, err := pd.getInt8()
+	format, err := pd.GetInt8()
 	if err != nil {
 		return err
 	}
 	if format != message_format {
-		return DecodingError("Message format mismatch.")
+		return enc.DecodingError
 	}
 
-	attribute, err := pd.getInt8()
+	attribute, err := pd.GetInt8()
 	if err != nil {
 		return err
 	}
 	m.Codec = attribute & 0x07
 
-	m.Key, err = pd.getBytes()
+	m.Key, err = pd.GetBytes()
 	if err != nil {
 		return err
 	}
 
-	m.Value, err = pd.getBytes()
+	m.Value, err = pd.GetBytes()
 	if err != nil {
 		return err
 	}
 
 	switch m.Codec {
-	case COMPRESSION_NONE:
+	case types.COMPRESSION_NONE:
 		// nothing to do
-	case COMPRESSION_GZIP:
+	case types.COMPRESSION_GZIP:
 		if m.Value == nil {
-			return DecodingError("Nil contents cannot be compressed.")
+			return enc.DecodingError
 		}
 		reader, err := gzip.NewReader(bytes.NewReader(m.Value))
 		if err != nil {
@@ -93,13 +100,13 @@ func (m *Message) decode(pd packetDecoder) (err error) {
 		if err != nil {
 			return err
 		}
-	case COMPRESSION_SNAPPY:
+	case types.COMPRESSION_SNAPPY:
 		// TODO
 	default:
-		return DecodingError("Unknown compression codec.")
+		return enc.DecodingError
 	}
 
-	err = pd.pop()
+	err = pd.Pop()
 	if err != nil {
 		return err
 	}

+ 24 - 16
protocol/message_set.go

@@ -1,35 +1,40 @@
 package protocol
 
+import enc "sarama/encoding"
+
 type MessageBlock struct {
 	Offset int64
 	Msg    *Message
 }
 
-func (msb *MessageBlock) encode(pe packetEncoder) {
-	pe.putInt64(msb.Offset)
-	pe.pushLength32()
-	msb.Msg.encode(pe)
-	pe.pop()
+func (msb *MessageBlock) Encode(pe enc.PacketEncoder) error {
+	pe.PutInt64(msb.Offset)
+	pe.Push(&enc.LengthField{})
+	err := msb.Msg.Encode(pe)
+	if err != nil {
+		return err
+	}
+	pe.Pop()
 }
 
-func (msb *MessageBlock) decode(pd packetDecoder) (err error) {
-	msb.Offset, err = pd.getInt64()
+func (msb *MessageBlock) Decode(pd enc.PacketDecoder) (err error) {
+	msb.Offset, err = pd.GetInt64()
 	if err != nil {
 		return err
 	}
 
-	err = pd.pushLength32()
+	err = pd.Push(&enc.LengthField{})
 	if err != nil {
 		return err
 	}
 
 	msb.Msg = new(Message)
-	err = msb.Msg.decode(pd)
+	err = msb.Msg.Decode(pd)
 	if err != nil {
 		return err
 	}
 
-	err = pd.pop()
+	err = pd.Pop()
 	if err != nil {
 		return err
 	}
@@ -42,22 +47,25 @@ type MessageSet struct {
 	Messages               []*MessageBlock
 }
 
-func (ms *MessageSet) encode(pe packetEncoder) {
+func (ms *MessageSet) Encode(pe enc.PacketEncoder) error {
 	for i := range ms.Messages {
-		ms.Messages[i].encode(pe)
+		err := ms.Messages[i].Encode(pe)
+		if err != nil {
+			return err
+		}
 	}
 }
 
-func (ms *MessageSet) decode(pd packetDecoder) (err error) {
+func (ms *MessageSet) Decode(pd enc.PacketDecoder) (err error) {
 	ms.Messages = nil
 
-	for pd.remaining() > 0 {
+	for pd.Remaining() > 0 {
 		msb := new(MessageBlock)
-		err = msb.decode(pd)
+		err = msb.Decode(pd)
 		switch err.(type) {
 		case nil:
 			ms.Messages = append(ms.Messages, msb)
-		case InsufficientData:
+		case enc.InsufficientData:
 			// As an optimization the server is allowed to return a partial message at the
 			// end of the message set. Clients should handle this case. So we just ignore such things.
 			ms.PartialTrailingMessage = true

+ 12 - 3
protocol/metadata_request.go

@@ -1,13 +1,22 @@
 package protocol
 
+import enc "sarama/encoding"
+
 type MetadataRequest struct {
 	Topics []string
 }
 
-func (mr *MetadataRequest) encode(pe packetEncoder) {
-	pe.putArrayCount(len(mr.Topics))
+func (mr *MetadataRequest) Encode(pe enc.PacketEncoder) error {
+	err := pe.PutArrayLength(len(mr.Topics))
+	if err != nil {
+		return err
+	}
+
 	for i := range mr.Topics {
-		pe.putString(mr.Topics[i])
+		err = pe.PutString(mr.Topics[i])
+		if err != nil {
+			return err
+		}
 	}
 }
 

+ 18 - 16
protocol/metadata_response.go

@@ -1,5 +1,7 @@
 package protocol
 
+import enc "sarama/encoding"
+
 type PartitionMetadata struct {
 	Err      KError
 	Id       int32
@@ -8,28 +10,28 @@ type PartitionMetadata struct {
 	Isr      []int32
 }
 
-func (pm *PartitionMetadata) decode(pd packetDecoder) (err error) {
-	pm.Err, err = pd.getError()
+func (pm *PartitionMetadata) decode(pd enc.PacketDecoder) (err error) {
+	pm.Err, err = pd.GetError()
 	if err != nil {
 		return err
 	}
 
-	pm.Id, err = pd.getInt32()
+	pm.Id, err = pd.GetInt32()
 	if err != nil {
 		return err
 	}
 
-	pm.Leader, err = pd.getInt32()
+	pm.Leader, err = pd.GetInt32()
 	if err != nil {
 		return err
 	}
 
-	pm.Replicas, err = pd.getInt32Array()
+	pm.Replicas, err = pd.GetInt32Array()
 	if err != nil {
 		return err
 	}
 
-	pm.Isr, err = pd.getInt32Array()
+	pm.Isr, err = pd.GetInt32Array()
 	if err != nil {
 		return err
 	}
@@ -43,25 +45,25 @@ type TopicMetadata struct {
 	Partitions []*PartitionMetadata
 }
 
-func (tm *TopicMetadata) decode(pd packetDecoder) (err error) {
-	tm.Err, err = pd.getError()
+func (tm *TopicMetadata) Decode(pd enc.PacketDecoder) (err error) {
+	tm.Err, err = pd.GetError()
 	if err != nil {
 		return err
 	}
 
-	tm.Name, err = pd.getString()
+	tm.Name, err = pd.GetString()
 	if err != nil {
 		return err
 	}
 
-	n, err := pd.getArrayCount()
+	n, err := pd.GetArrayLength()
 	if err != nil {
 		return err
 	}
 	tm.Partitions = make([]*PartitionMetadata, n)
 	for i := 0; i < n; i++ {
 		tm.Partitions[i] = new(PartitionMetadata)
-		err = tm.Partitions[i].decode(pd)
+		err = tm.Partitions[i].Decode(pd)
 		if err != nil {
 			return err
 		}
@@ -75,8 +77,8 @@ type MetadataResponse struct {
 	Topics  []*TopicMetadata
 }
 
-func (m *MetadataResponse) decode(pd packetDecoder) (err error) {
-	n, err := pd.getArrayCount()
+func (m *MetadataResponse) Decode(pd enc.PacketDecoder) (err error) {
+	n, err := pd.GetArrayLength()
 	if err != nil {
 		return err
 	}
@@ -84,13 +86,13 @@ func (m *MetadataResponse) decode(pd packetDecoder) (err error) {
 	m.Brokers = make([]*Broker, n)
 	for i := 0; i < n; i++ {
 		m.Brokers[i] = new(Broker)
-		err = m.Brokers[i].decode(pd)
+		err = m.Brokers[i].Decode(pd)
 		if err != nil {
 			return err
 		}
 	}
 
-	n, err = pd.getArrayCount()
+	n, err = pd.GetArrayLength()
 	if err != nil {
 		return err
 	}
@@ -98,7 +100,7 @@ func (m *MetadataResponse) decode(pd packetDecoder) (err error) {
 	m.Topics = make([]*TopicMetadata, n)
 	for i := 0; i < n; i++ {
 		m.Topics[i] = new(TopicMetadata)
-		err = m.Topics[i].decode(pd)
+		err = m.Topics[i].Decode(pd)
 		if err != nil {
 			return err
 		}

+ 27 - 10
protocol/offset_commit_request.go

@@ -1,13 +1,15 @@
 package protocol
 
+import enc "sarama/encoding"
+
 type offsetCommitRequestBlock struct {
 	offset   int64
 	metadata string
 }
 
-func (r *offsetCommitRequestBlock) encode(pe packetEncoder) {
-	pe.putInt64(r.offset)
-	pe.putString(r.metadata)
+func (r *offsetCommitRequestBlock) Encode(pe enc.PacketEncoder) error {
+	pe.PutInt64(r.offset)
+	return pe.PutString(r.metadata)
 }
 
 type OffsetCommitRequest struct {
@@ -15,15 +17,30 @@ type OffsetCommitRequest struct {
 	blocks        map[string]map[int32]*offsetCommitRequestBlock
 }
 
-func (r *OffsetCommitRequest) encode(pe packetEncoder) {
-	pe.putString(r.ConsumerGroup)
-	pe.putArrayCount(len(r.blocks))
+func (r *OffsetCommitRequest) Encode(pe enc.PacketEncoder) error {
+	err := pe.PutString(r.ConsumerGroup)
+	if err != nil {
+		return err
+	}
+	err = pe.PutArrayLength(len(r.blocks))
+	if err != nil {
+		return err
+	}
 	for topic, partitions := range r.blocks {
-		pe.putString(topic)
-		pe.putArrayCount(len(partitions))
+		err = pe.PutString(topic)
+		if err != nil {
+			return err
+		}
+		err = pe.PutArrayLength(len(partitions))
+		if err != nil {
+			return err
+		}
 		for partition, block := range partitions {
-			pe.putInt32(partition)
-			block.encode(pe)
+			pe.PutInt32(partition)
+			err = block.Encode(pe)
+			if err != nil {
+				return err
+			}
 		}
 	}
 }

+ 9 - 7
protocol/offset_commit_response.go

@@ -1,29 +1,31 @@
 package protocol
 
+import enc "sarama/encoding"
+
 type OffsetCommitResponse struct {
 	ClientID string
 	Errors   map[string]map[int32]KError
 }
 
-func (r *OffsetCommitResponse) decode(pd packetDecoder) (err error) {
-	r.ClientID, err = pd.getString()
+func (r *OffsetCommitResponse) Decode(pd enc.PacketDecoder) (err error) {
+	r.ClientID, err = pd.GetString()
 	if err != nil {
 		return err
 	}
 
-	numTopics, err := pd.getArrayCount()
+	numTopics, err := pd.GetArrayLength()
 	if err != nil {
 		return err
 	}
 
 	r.Errors = make(map[string]map[int32]KError, numTopics)
 	for i := 0; i < numTopics; i++ {
-		name, err := pd.getString()
+		name, err := pd.GetString()
 		if err != nil {
 			return err
 		}
 
-		numErrors, err := pd.getArrayCount()
+		numErrors, err := pd.GetArrayLength()
 		if err != nil {
 			return err
 		}
@@ -31,12 +33,12 @@ func (r *OffsetCommitResponse) decode(pd packetDecoder) (err error) {
 		r.Errors[name] = make(map[int32]KError, numErrors)
 
 		for j := 0; j < numErrors; j++ {
-			id, err := pd.getInt32()
+			id, err := pd.GetInt32()
 			if err != nil {
 				return err
 			}
 
-			tmp, err := pd.getError()
+			tmp, err := pd.GetError()
 			if err != nil {
 				return err
 			}

+ 16 - 5
protocol/offset_fetch_request.go

@@ -1,16 +1,27 @@
 package protocol
 
+import enc "sarama/encoding"
+
 type OffsetFetchRequest struct {
 	ConsumerGroup string
 	partitions    map[string][]int32
 }
 
-func (r *OffsetFetchRequest) encode(pe packetEncoder) {
-	pe.putString(r.ConsumerGroup)
-	pe.putArrayCount(len(r.partitions))
+func (r *OffsetFetchRequest) Encode(pe enc.PacketEncoder) error {
+	err := pe.PutString(r.ConsumerGroup)
+	if err != nil {
+		return err
+	}
+	err = pe.PutArrayLength(len(r.partitions))
+	if err != nil {
+		return err
+	}
 	for topic, partitions := range r.partitions {
-		pe.putString(topic)
-		pe.putInt32Array(partitions)
+		err = pe.PutString(topic)
+		if err != nil {
+			return err
+		}
+		pe.PutInt32Array(partitions)
 	}
 }
 

+ 13 - 11
protocol/offset_fetch_response.go

@@ -1,23 +1,25 @@
 package protocol
 
+import enc "sarama/encoding"
+
 type OffsetFetchResponseBlock struct {
 	Offset   int64
 	Metadata string
 	Err      KError
 }
 
-func (r *OffsetFetchResponseBlock) decode(pd packetDecoder) (err error) {
-	r.Offset, err = pd.getInt64()
+func (r *OffsetFetchResponseBlock) Decode(pd enc.PacketDecoder) (err error) {
+	r.Offset, err = pd.GetInt64()
 	if err != nil {
 		return err
 	}
 
-	r.Metadata, err = pd.getString()
+	r.Metadata, err = pd.GetString()
 	if err != nil {
 		return err
 	}
 
-	r.Err, err = pd.getError()
+	r.Err, err = pd.GetError()
 
 	return err
 }
@@ -27,25 +29,25 @@ type OffsetFetchResponse struct {
 	Blocks   map[string]map[int32]*OffsetFetchResponseBlock
 }
 
-func (r *OffsetFetchResponse) decode(pd packetDecoder) (err error) {
-	r.ClientID, err = pd.getString()
+func (r *OffsetFetchResponse) Decode(pd enc.PacketDecoder) (err error) {
+	r.ClientID, err = pd.GetString()
 	if err != nil {
 		return err
 	}
 
-	numTopics, err := pd.getArrayCount()
+	numTopics, err := pd.GetArrayLength()
 	if err != nil {
 		return err
 	}
 
 	r.Blocks = make(map[string]map[int32]*OffsetFetchResponseBlock, numTopics)
 	for i := 0; i < numTopics; i++ {
-		name, err := pd.getString()
+		name, err := pd.GetString()
 		if err != nil {
 			return err
 		}
 
-		numBlocks, err := pd.getArrayCount()
+		numBlocks, err := pd.GetArrayLength()
 		if err != nil {
 			return err
 		}
@@ -53,13 +55,13 @@ func (r *OffsetFetchResponse) decode(pd packetDecoder) (err error) {
 		r.Blocks[name] = make(map[int32]*OffsetFetchResponseBlock, numBlocks)
 
 		for j := 0; j < numBlocks; j++ {
-			id, err := pd.getInt32()
+			id, err := pd.GetInt32()
 			if err != nil {
 				return err
 			}
 
 			block := new(OffsetFetchResponseBlock)
-			err = block.decode(pd)
+			err = block.Decode(pd)
 			if err != nil {
 				return err
 			}

+ 23 - 8
protocol/offset_request.go

@@ -1,28 +1,43 @@
 package protocol
 
+import enc "sarama/encoding"
+
 type offsetRequestBlock struct {
 	time       int64
 	maxOffsets int32
 }
 
-func (r *offsetRequestBlock) encode(pe packetEncoder) {
+func (r *offsetRequestBlock) Encode(pe enc.PacketEncoder) error {
 	pe.putInt64(r.time)
 	pe.putInt32(r.maxOffsets)
+	return nil
 }
 
 type OffsetRequest struct {
 	blocks map[string]map[int32]*offsetRequestBlock
 }
 
-func (r *OffsetRequest) encode(pe packetEncoder) {
-	pe.putInt32(-1) // replica ID is always -1 for clients
-	pe.putArrayCount(len(r.blocks))
+func (r *OffsetRequest) Encode(pe enc.PacketEncoder) error {
+	pe.PutInt32(-1) // replica ID is always -1 for clients
+	err := pe.PutArrayLength(len(r.blocks))
+	if err != nil {
+		return err
+	}
 	for topic, partitions := range r.blocks {
-		pe.putString(topic)
-		pe.putArrayCount(len(partitions))
+		err = pe.PutString(topic)
+		if err != nil {
+			return err
+		}
+		err = pe.PutArrayLength(len(partitions))
+		if err != nil {
+			return err
+		}
 		for partition, block := range partitions {
-			pe.putInt32(partition)
-			block.encode(pe)
+			pe.PutInt32(partition)
+			err = block.Encode(pe)
+			if err != nil {
+				return err
+			}
 		}
 	}
 }

+ 11 - 9
protocol/offset_response.go

@@ -1,17 +1,19 @@
 package protocol
 
+import enc "sarama/encoding"
+
 type OffsetResponseBlock struct {
 	Err     KError
 	Offsets []int64
 }
 
-func (r *OffsetResponseBlock) decode(pd packetDecoder) (err error) {
-	r.Err, err = pd.getError()
+func (r *OffsetResponseBlock) Decode(pd enc.PacketDecoder) (err error) {
+	r.Err, err = pd.GetError()
 	if err != nil {
 		return err
 	}
 
-	r.Offsets, err = pd.getInt64Array()
+	r.Offsets, err = pd.GetInt64Array()
 
 	return err
 }
@@ -20,20 +22,20 @@ type OffsetResponse struct {
 	Blocks map[string]map[int32]*OffsetResponseBlock
 }
 
-func (r *OffsetResponse) decode(pd packetDecoder) (err error) {
-	numTopics, err := pd.getArrayCount()
+func (r *OffsetResponse) Decode(pd enc.PacketDecoder) (err error) {
+	numTopics, err := pd.GetArrayLength()
 	if err != nil {
 		return err
 	}
 
 	r.Blocks = make(map[string]map[int32]*OffsetResponseBlock, numTopics)
 	for i := 0; i < numTopics; i++ {
-		name, err := pd.getString()
+		name, err := pd.GetString()
 		if err != nil {
 			return err
 		}
 
-		numBlocks, err := pd.getArrayCount()
+		numBlocks, err := pd.GetArrayLength()
 		if err != nil {
 			return err
 		}
@@ -41,13 +43,13 @@ func (r *OffsetResponse) decode(pd packetDecoder) (err error) {
 		r.Blocks[name] = make(map[int32]*OffsetResponseBlock, numBlocks)
 
 		for j := 0; j < numBlocks; j++ {
-			id, err := pd.getInt32()
+			id, err := pd.GetInt32()
 			if err != nil {
 				return err
 			}
 
 			block := new(OffsetResponseBlock)
-			err = block.decode(pd)
+			err = block.Decode(pd)
 			if err != nil {
 				return err
 			}

+ 29 - 11
protocol/produce_request.go

@@ -1,23 +1,41 @@
 package protocol
 
+import enc "sarama/encoding"
+import "sarama/types"
+
 type ProduceRequest struct {
-	RequiredAcks int16
+	RequiredAcks types.RequiredAcks
 	Timeout      int32
 	msgSets      map[string]map[int32]*MessageSet
 }
 
-func (p *ProduceRequest) encode(pe packetEncoder) {
-	pe.putInt16(p.RequiredAcks)
-	pe.putInt32(p.Timeout)
-	pe.putArrayCount(len(p.msgSets))
+func (p *ProduceRequest) Encode(pe enc.PacketEncoder) error {
+	pe.PutInt16(p.RequiredAcks)
+	pe.PutInt32(p.Timeout)
+	err := pe.PutArrayLength(len(p.msgSets))
+	if err != nil {
+		return err
+	}
 	for topic, partitions := range p.msgSets {
-		pe.putString(topic)
-		pe.putArrayCount(len(partitions))
+		err = pe.PutString(topic)
+		if err != nil {
+			return err
+		}
+		err = pe.PutArrayLength(len(partitions))
+		if err != nil {
+			return err
+		}
 		for id, msgSet := range partitions {
-			pe.putInt32(id)
-			pe.pushLength32()
-			msgSet.encode(pe)
-			pe.pop()
+			pe.PutInt32(id)
+			pe.PushLength32()
+			err = msgSet.Encode(pe)
+			if err != nil {
+				return err
+			}
+			err = pe.Pop()
+			if err != nil {
+				return err
+			}
 		}
 	}
 }

+ 11 - 9
protocol/produce_response.go

@@ -1,17 +1,19 @@
 package protocol
 
+import enc "sarama/encoding"
+
 type ProduceResponseBlock struct {
 	Err    KError
 	Offset int64
 }
 
-func (pr *ProduceResponseBlock) decode(pd packetDecoder) (err error) {
-	pr.Err, err = pd.getError()
+func (pr *ProduceResponseBlock) Decode(pd enc.PacketDecoder) (err error) {
+	pr.Err, err = pd.GetError()
 	if err != nil {
 		return err
 	}
 
-	pr.Offset, err = pd.getInt64()
+	pr.Offset, err = pd.GetInt64()
 	if err != nil {
 		return err
 	}
@@ -23,20 +25,20 @@ type ProduceResponse struct {
 	Blocks map[string]map[int32]*ProduceResponseBlock
 }
 
-func (pr *ProduceResponse) decode(pd packetDecoder) (err error) {
-	numTopics, err := pd.getArrayCount()
+func (pr *ProduceResponse) Decode(pd enc.PacketDecoder) (err error) {
+	numTopics, err := pd.GetArrayLength()
 	if err != nil {
 		return err
 	}
 
 	pr.Blocks = make(map[string]map[int32]*ProduceResponseBlock, numTopics)
 	for i := 0; i < numTopics; i++ {
-		name, err := pd.getString()
+		name, err := pd.GetString()
 		if err != nil {
 			return err
 		}
 
-		numBlocks, err := pd.getArrayCount()
+		numBlocks, err := pd.GetArrayLength()
 		if err != nil {
 			return err
 		}
@@ -44,13 +46,13 @@ func (pr *ProduceResponse) decode(pd packetDecoder) (err error) {
 		pr.Blocks[name] = make(map[int32]*ProduceResponseBlock, numBlocks)
 
 		for j := 0; j < numBlocks; j++ {
-			id, err := pd.getInt32()
+			id, err := pd.GetInt32()
 			if err != nil {
 				return err
 			}
 
 			block := new(ProduceResponseBlock)
-			err = block.decode(pd)
+			err = block.Decode(pd)
 			if err != nil {
 				return err
 			}

+ 17 - 9
protocol/request.go

@@ -1,7 +1,9 @@
 package protocol
 
+import enc "sarama/encoding"
+
 type requestEncoder interface {
-	encoder
+	enc.Encoder
 	key() int16
 	version() int16
 }
@@ -12,12 +14,18 @@ type request struct {
 	body           requestEncoder
 }
 
-func (r *request) encode(pe packetEncoder) {
-	pe.pushLength32()
-	pe.putInt16(r.body.key())
-	pe.putInt16(r.body.version())
-	pe.putInt32(r.correlation_id)
-	pe.putString(r.id)
-	r.body.encode(pe)
-	pe.pop()
+func (r *request) Encode(pe enc.PacketEncoder) error {
+	pe.Push(&LengthField{})
+	pe.PutInt16(r.body.key())
+	pe.PutInt16(r.body.version())
+	pe.PutInt32(r.correlation_id)
+	err = pe.PutString(r.id)
+	if err != nil {
+		return err
+	}
+	err = r.body.Encode(pe)
+	if err != nil {
+		return err
+	}
+	return pe.Pop()
 }

+ 5 - 4
protocol/response_header.go

@@ -1,21 +1,22 @@
 package protocol
 
 import "math"
+import enc "sarama/encoding"
 
 type responseHeader struct {
 	length         int32
 	correlation_id int32
 }
 
-func (r *responseHeader) decode(pd packetDecoder) (err error) {
-	r.length, err = pd.getInt32()
+func (r *responseHeader) Decode(pd enc.PacketDecoder) (err error) {
+	r.length, err = pd.GetInt32()
 	if err != nil {
 		return err
 	}
 	if r.length <= 4 || r.length > 2*math.MaxUint16 {
-		return DecodingError("Malformed length field.")
+		return enc.DecodingError("Malformed length field.")
 	}
 
-	r.correlation_id, err = pd.getInt32()
+	r.correlation_id, err = pd.GetInt32()
 	return err
 }