Browse Source

Consistent encode and decode helper functions.

Evan Huus 12 years ago
parent
commit
ba2bf0685c
5 changed files with 64 additions and 37 deletions
  1. 13 16
      broker.go
  2. 28 0
      encoder_decoder.go
  3. 2 2
      message.go
  4. 0 19
      packet_encoder.go
  5. 21 0
      response_header.go

+ 13 - 16
broker.go

@@ -52,7 +52,7 @@ func (b *Broker) Close() error {
 
 func (b *Broker) Send(clientID *string, req requestEncoder) (decoder, error) {
 	fullRequest := request{b.correlation_id, clientID, req}
-	packet, err := buildBytes(&fullRequest)
+	packet, err := encode(&fullRequest)
 	if err != nil {
 		return nil, err
 	}
@@ -66,16 +66,15 @@ func (b *Broker) Send(clientID *string, req requestEncoder) (decoder, error) {
 
 	select {
 	case buf := <-sendRequest.response.packets:
-		// Only try to decode if we got a response.
-		if buf != nil {
-			decoder := realDecoder{raw: buf}
-			err = response.decode(&decoder)
-			return response, err
-		}
+		err = decode(buf, response)
 	case err = <-sendRequest.response.errors:
 	}
 
-	return nil, err
+	if err != nil {
+		return nil, err
+	}
+
+	return response, nil
 }
 
 func (b *Broker) connect() (err error) {
@@ -154,20 +153,18 @@ func (b *Broker) rcvResponseLoop() {
 			continue
 		}
 
-		decoder := realDecoder{raw: header}
-		length, _ := decoder.getInt32()
-		if length <= 4 || length > 2*math.MaxUint16 {
-			response.errors <- DecodingError("Malformed length field.")
+		decodedHeader := responseHeader{}
+		err = decode(header, &decodedHeader)
+		if err != nil {
+			response.errors <- err
 			continue
 		}
-
-		corr_id, _ := decoder.getInt32()
-		if response.correlation_id != corr_id {
+		if decodedHeader.correlation_id != response.correlation_id {
 			response.errors <- DecodingError("Mismatched correlation id.")
 			continue
 		}
 
-		buf := make([]byte, length-4)
+		buf := make([]byte, decodedHeader.length-4)
 		_, err = io.ReadFull(b.conn, buf)
 		if err != nil {
 			response.errors <- err

+ 28 - 0
encoder_decoder.go

@@ -4,10 +4,38 @@ type encoder interface {
 	encode(pe packetEncoder)
 }
 
+func encode(in encoder) ([]byte, error) {
+	if in == nil {
+		return nil, nil
+	}
+
+	var prepEnc prepEncoder
+	var realEnc realEncoder
+
+	in.encode(&prepEnc)
+	if prepEnc.err != nil {
+		return nil, prepEnc.err
+	}
+
+	realEnc.raw = make([]byte, prepEnc.length)
+	in.encode(&realEnc)
+
+	return realEnc.raw, nil
+}
+
 type decoder interface {
 	decode(pd packetDecoder) error
 }
 
+func decode(buf []byte, in decoder) error {
+	if buf == nil {
+		return nil
+	}
+
+	helper := realDecoder{raw: buf}
+	return in.decode(&helper)
+}
+
 type encoderDecoder interface {
 	encoder
 	decoder

+ 2 - 2
message.go

@@ -117,12 +117,12 @@ func (m *message) decode(pd packetDecoder) (err error) {
 func newMessage(key, value encoder) (msg *message, err error) {
 	msg = new(message)
 
-	msg.key, err = buildBytes(key)
+	msg.key, err = encode(key)
 	if err != nil {
 		return nil, err
 	}
 
-	msg.value, err = buildBytes(value)
+	msg.value, err = encode(value)
 	if err != nil {
 		return nil, err
 	}

+ 0 - 19
packet_encoder.go

@@ -29,22 +29,3 @@ type pushEncoder interface {
 	reserveLength() int
 	run(curOffset int, buf []byte)
 }
-
-func buildBytes(in encoder) ([]byte, error) {
-	if in == nil {
-		return nil, nil
-	}
-
-	var prepEnc prepEncoder
-	var realEnc realEncoder
-
-	in.encode(&prepEnc)
-	if prepEnc.err != nil {
-		return nil, prepEnc.err
-	}
-
-	realEnc.raw = make([]byte, prepEnc.length)
-	in.encode(&realEnc)
-
-	return realEnc.raw, nil
-}

+ 21 - 0
response_header.go

@@ -0,0 +1,21 @@
+package kafka
+
+import "math"
+
+type responseHeader struct {
+	length         int32
+	correlation_id int32
+}
+
+func (r *responseHeader) decode(pd packetDecoder) (err error) {
+	r.length, err = pd.getInt32()
+	if err != nil {
+		return err
+	}
+	if r.length <= 4 || r.length > 2*math.MaxUint16 {
+		return DecodingError("Malformed length field.")
+	}
+
+	r.correlation_id, err = pd.getInt32()
+	return err
+}