Parcourir la source

Clean up better on connection error

Evan Huus il y a 12 ans
Parent
commit
d300ecbabe
1 fichiers modifiés avec 17 ajouts et 6 suppressions
  1. 17 6
      broker.go

+ 17 - 6
broker.go

@@ -58,6 +58,17 @@ func (b *broker) connect() (err error) {
 	return nil
 }
 
+func (b *broker) disconnect() {
+	close(b.requests)
+	b.requests = nil
+
+	close(b.responses)
+	b.responses = nil
+
+	b.conn.Close()
+	b.conn = nil
+}
+
 func (b *broker) encode(pe packetEncoder) {
 	pe.putInt32(b.nodeId)
 	pe.putString(b.host)
@@ -96,7 +107,7 @@ func (b *broker) sendRequestLoop() {
 		buf = <-request.packets
 		n, err = b.conn.Write(buf)
 		if err != nil || n != len(buf) {
-			close(b.requests)
+			b.disconnect()
 			return
 		}
 		b.responses <- request
@@ -112,29 +123,29 @@ func (b *broker) rcvResponseLoop() {
 	for response := range b.responses {
 		n, err = b.conn.Read(header)
 		if err != nil || n != 4 {
-			close(b.responses)
+			b.disconnect()
 			return
 		}
 		length = int32(binary.BigEndian.Uint32(header))
 		if length <= 4 || length > 2*math.MaxUint16 {
-			close(b.responses)
+			b.disconnect()
 			return
 		}
 
 		n, err = b.conn.Read(header)
 		if err != nil || n != 4 {
-			close(b.responses)
+			b.disconnect()
 			return
 		}
 		if response.correlation_id != int32(binary.BigEndian.Uint32(header)) {
-			close(b.responses)
+			b.disconnect()
 			return
 		}
 
 		buf = make([]byte, length-4)
 		n, err = b.conn.Read(buf)
 		if err != nil || n != int(length-4) {
-			close(b.responses)
+			b.disconnect()
 			return
 		}