浏览代码

Broker simplification

publish Close method, and bubble up all errors to caller rather than
auto-closing on most of them.
Evan Huus 12 年之前
父节点
当前提交
ded6c01f40
共有 3 个文件被更改,包括 48 次插入62 次删除
  1. 46 60
      broker.go
  2. 1 1
      metadata_cache.go
  3. 1 1
      producer.go

+ 46 - 60
broker.go

@@ -14,7 +14,6 @@ type Broker struct {
 	correlation_id int32
 
 	conn net.Conn
-	addr net.TCPAddr
 
 	requests  chan requestToSend
 	responses chan responsePromise
@@ -44,17 +43,48 @@ func NewBroker(host string, port int32) (b *Broker, err error) {
 	return b, nil
 }
 
+func (b *Broker) Close() error {
+	close(b.requests)
+	close(b.responses)
+
+	return b.conn.Close()
+}
+
+func (b *Broker) Send(clientID *string, req requestEncoder) (decoder, error) {
+	fullRequest := request{b.correlation_id, clientID, req}
+	packet, err := buildBytes(&fullRequest)
+	if err != nil {
+		return nil, err
+	}
+
+	response := req.responseDecoder()
+	sendRequest := requestToSend{responsePromise{b.correlation_id, make(chan []byte), make(chan error)}, response != nil}
+
+	b.requests <- sendRequest
+	sendRequest.response.packets <- *packet // we cheat to avoid poofing up more channels than necessary
+	b.correlation_id++
+
+	select {
+	case buf := <-sendRequest.response.packets:
+		// Only try to decode if we got a response.
+		if buf != nil {
+			decoder := realDecoder{raw: buf}
+			err = response.decode(&decoder)
+			return response, err
+		}
+	case err = <-sendRequest.response.errors:
+	}
+
+	return nil, err
+}
+
 func (b *Broker) connect() (err error) {
 	addr, err := net.ResolveIPAddr("ip", *b.host)
 	if err != nil {
 		return err
 	}
 
-	b.addr.IP = addr.IP
-	b.addr.Zone = addr.Zone
-	b.addr.Port = int(b.port)
-
-	b.conn, err = net.DialTCP("tcp", nil, &b.addr)
+	b.conn, err = net.DialTCP("tcp", nil, &net.TCPAddr{addr.IP, int(b.port), addr.Zone})
 	if err != nil {
 		return err
 	}
@@ -68,17 +98,6 @@ func (b *Broker) connect() (err error) {
 	return nil
 }
 
-func (b *Broker) forceDisconnect(reqRes *responsePromise, err error) {
-	reqRes.errors <- err
-	close(reqRes.errors)
-	close(reqRes.packets)
-
-	close(b.requests)
-	close(b.responses)
-
-	b.conn.Close()
-}
-
 func (b *Broker) encode(pe packetEncoder) {
 	pe.putInt32(b.id)
 	pe.putString(b.host)
@@ -117,14 +136,11 @@ func (b *Broker) sendRequestLoop() {
 		buf := <-request.response.packets
 		_, err := b.conn.Write(buf)
 		if err != nil {
-			b.forceDisconnect(&request.response, err)
-			return
+			request.response.errors <- err
+			continue
 		}
 		if request.expectResponse {
 			b.responses <- request.response
-		} else {
-			close(request.response.packets)
-			close(request.response.errors)
 		}
 	}
 }
@@ -134,60 +150,30 @@ func (b *Broker) rcvResponseLoop() {
 	for response := range b.responses {
 		_, err := io.ReadFull(b.conn, header)
 		if err != nil {
-			b.forceDisconnect(&response, err)
-			return
+			response.errors <- err
+			continue
 		}
 
 		decoder := realDecoder{raw: header}
 		length, _ := decoder.getInt32()
 		if length <= 4 || length > 2*math.MaxUint16 {
-			b.forceDisconnect(&response, DecodingError("Malformed length field."))
-			return
+			response.errors <- DecodingError("Malformed length field.")
+			continue
 		}
 
 		corr_id, _ := decoder.getInt32()
 		if response.correlation_id != corr_id {
-			b.forceDisconnect(&response, DecodingError("Mismatched correlation id."))
-			return
+			response.errors <- DecodingError("Mismatched correlation id.")
+			continue
 		}
 
 		buf := make([]byte, length-4)
 		_, err = io.ReadFull(b.conn, buf)
 		if err != nil {
-			b.forceDisconnect(&response, err)
-			return
+			response.errors <- err
+			continue
 		}
 
 		response.packets <- buf
-		close(response.packets)
-		close(response.errors)
 	}
 }
-
-func (b *Broker) SendAndReceive(clientID *string, req requestEncoder) (decoder, error) {
-	fullRequest := request{b.correlation_id, clientID, req}
-	packet, err := buildBytes(&fullRequest)
-	if err != nil {
-		return nil, err
-	}
-
-	response := req.responseDecoder()
-	sendRequest := requestToSend{responsePromise{b.correlation_id, make(chan []byte), make(chan error)}, response != nil}
-
-	b.requests <- sendRequest
-	sendRequest.response.packets <- *packet // we cheat to avoid poofing up more channels than necessary
-	b.correlation_id++
-
-	select {
-	case buf := <-sendRequest.response.packets:
-		// Only try to decode if we got a response.
-		if buf != nil {
-			decoder := realDecoder{raw: buf}
-			err = response.decode(&decoder)
-			return response, err
-		}
-	case err = <-sendRequest.response.errors:
-	}
-
-	return nil, err
-}

+ 1 - 1
metadata_cache.go

@@ -87,7 +87,7 @@ func (mc *metadataCache) refreshTopics(topics []*string) error {
 		return OutOfBrokers{}
 	}
 
-	decoder, err := broker.SendAndReceive(mc.client.id, &metadataRequest{topics})
+	decoder, err := broker.Send(mc.client.id, &metadataRequest{topics})
 	if err != nil {
 		return err
 	}

+ 1 - 1
producer.go

@@ -52,7 +52,7 @@ func (p *Producer) SendMessage(key, value encoder) (*ProduceResponse, error) {
 	request.requiredAcks = p.responseCondition
 	request.timeout = p.responseTimeout
 
-	decoder, err := broker.SendAndReceive(p.id, request)
+	decoder, err := broker.Send(p.id, request)
 	if err != nil {
 		return nil, err
 	}