|
|
@@ -14,7 +14,6 @@ type Broker struct {
|
|
|
correlation_id int32
|
|
|
|
|
|
conn net.Conn
|
|
|
- addr net.TCPAddr
|
|
|
|
|
|
requests chan requestToSend
|
|
|
responses chan responsePromise
|
|
|
@@ -44,17 +43,48 @@ func NewBroker(host string, port int32) (b *Broker, err error) {
|
|
|
return b, nil
|
|
|
}
|
|
|
|
|
|
+func (b *Broker) Close() error {
|
|
|
+ close(b.requests)
|
|
|
+ close(b.responses)
|
|
|
+
|
|
|
+ return b.conn.Close()
|
|
|
+}
|
|
|
+
|
|
|
+func (b *Broker) Send(clientID *string, req requestEncoder) (decoder, error) {
|
|
|
+ fullRequest := request{b.correlation_id, clientID, req}
|
|
|
+ packet, err := buildBytes(&fullRequest)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ response := req.responseDecoder()
|
|
|
+ sendRequest := requestToSend{responsePromise{b.correlation_id, make(chan []byte), make(chan error)}, response != nil}
|
|
|
+
|
|
|
+ b.requests <- sendRequest
|
|
|
+ sendRequest.response.packets <- *packet // we cheat to avoid poofing up more channels than necessary
|
|
|
+ b.correlation_id++
|
|
|
+
|
|
|
+ select {
|
|
|
+ case buf := <-sendRequest.response.packets:
|
|
|
+ // Only try to decode if we got a response.
|
|
|
+ if buf != nil {
|
|
|
+ decoder := realDecoder{raw: buf}
|
|
|
+ err = response.decode(&decoder)
|
|
|
+ return response, err
|
|
|
+ }
|
|
|
+ case err = <-sendRequest.response.errors:
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil, err
|
|
|
+}
|
|
|
+
|
|
|
func (b *Broker) connect() (err error) {
|
|
|
addr, err := net.ResolveIPAddr("ip", *b.host)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
- b.addr.IP = addr.IP
|
|
|
- b.addr.Zone = addr.Zone
|
|
|
- b.addr.Port = int(b.port)
|
|
|
-
|
|
|
- b.conn, err = net.DialTCP("tcp", nil, &b.addr)
|
|
|
+ b.conn, err = net.DialTCP("tcp", nil, &net.TCPAddr{addr.IP, int(b.port), addr.Zone})
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -68,17 +98,6 @@ func (b *Broker) connect() (err error) {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func (b *Broker) forceDisconnect(reqRes *responsePromise, err error) {
|
|
|
- reqRes.errors <- err
|
|
|
- close(reqRes.errors)
|
|
|
- close(reqRes.packets)
|
|
|
-
|
|
|
- close(b.requests)
|
|
|
- close(b.responses)
|
|
|
-
|
|
|
- b.conn.Close()
|
|
|
-}
|
|
|
-
|
|
|
func (b *Broker) encode(pe packetEncoder) {
|
|
|
pe.putInt32(b.id)
|
|
|
pe.putString(b.host)
|
|
|
@@ -117,14 +136,11 @@ func (b *Broker) sendRequestLoop() {
|
|
|
buf := <-request.response.packets
|
|
|
_, err := b.conn.Write(buf)
|
|
|
if err != nil {
|
|
|
- b.forceDisconnect(&request.response, err)
|
|
|
- return
|
|
|
+ request.response.errors <- err
|
|
|
+ continue
|
|
|
}
|
|
|
if request.expectResponse {
|
|
|
b.responses <- request.response
|
|
|
- } else {
|
|
|
- close(request.response.packets)
|
|
|
- close(request.response.errors)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -134,60 +150,30 @@ func (b *Broker) rcvResponseLoop() {
|
|
|
for response := range b.responses {
|
|
|
_, err := io.ReadFull(b.conn, header)
|
|
|
if err != nil {
|
|
|
- b.forceDisconnect(&response, err)
|
|
|
- return
|
|
|
+ response.errors <- err
|
|
|
+ continue
|
|
|
}
|
|
|
|
|
|
decoder := realDecoder{raw: header}
|
|
|
length, _ := decoder.getInt32()
|
|
|
if length <= 4 || length > 2*math.MaxUint16 {
|
|
|
- b.forceDisconnect(&response, DecodingError("Malformed length field."))
|
|
|
- return
|
|
|
+ response.errors <- DecodingError("Malformed length field.")
|
|
|
+ continue
|
|
|
}
|
|
|
|
|
|
corr_id, _ := decoder.getInt32()
|
|
|
if response.correlation_id != corr_id {
|
|
|
- b.forceDisconnect(&response, DecodingError("Mismatched correlation id."))
|
|
|
- return
|
|
|
+ response.errors <- DecodingError("Mismatched correlation id.")
|
|
|
+ continue
|
|
|
}
|
|
|
|
|
|
buf := make([]byte, length-4)
|
|
|
_, err = io.ReadFull(b.conn, buf)
|
|
|
if err != nil {
|
|
|
- b.forceDisconnect(&response, err)
|
|
|
- return
|
|
|
+ response.errors <- err
|
|
|
+ continue
|
|
|
}
|
|
|
|
|
|
response.packets <- buf
|
|
|
- close(response.packets)
|
|
|
- close(response.errors)
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
-func (b *Broker) SendAndReceive(clientID *string, req requestEncoder) (decoder, error) {
|
|
|
- fullRequest := request{b.correlation_id, clientID, req}
|
|
|
- packet, err := buildBytes(&fullRequest)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
-
|
|
|
- response := req.responseDecoder()
|
|
|
- sendRequest := requestToSend{responsePromise{b.correlation_id, make(chan []byte), make(chan error)}, response != nil}
|
|
|
-
|
|
|
- b.requests <- sendRequest
|
|
|
- sendRequest.response.packets <- *packet // we cheat to avoid poofing up more channels than necessary
|
|
|
- b.correlation_id++
|
|
|
-
|
|
|
- select {
|
|
|
- case buf := <-sendRequest.response.packets:
|
|
|
- // Only try to decode if we got a response.
|
|
|
- if buf != nil {
|
|
|
- decoder := realDecoder{raw: buf}
|
|
|
- err = response.decode(&decoder)
|
|
|
- return response, err
|
|
|
- }
|
|
|
- case err = <-sendRequest.response.errors:
|
|
|
- }
|
|
|
-
|
|
|
- return nil, err
|
|
|
-}
|