broker.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. package kafka
  2. import (
  3. "io"
  4. "math"
  5. "net"
  6. )
  type (
  	// broker describes one broker (node ID, host, port) together with the
  	// connection and channel plumbing used to exchange requests with it.
  	broker struct {
  		id             int32                // node ID; newBroker starts it at -1 ("don't know it yet")
  		host           *string              // host name resolved by connect
  		port           int32                // TCP port dialled by connect
  		correlation_id int32                // ID stamped on the next request; sendRequest bumps it
  		conn           net.Conn             // open connection, set by connect
  		addr           net.TCPAddr          // resolved address that connect dials
  		requests       chan requestToSend   // queue consumed by sendRequestLoop
  		responses      chan responsePromise // promises awaiting replies, consumed by rcvResponseLoop
  	}

  	// responsePromise is the handle a caller holds while waiting for the
  	// outcome of one request. Field order matters: it is built positionally.
  	responsePromise struct {
  		correlation_id int32        // ID the reply header must carry
  		packets        chan []byte  // reply body (and, before sending, the encoded request)
  		errors         chan error   // failure reported instead of a reply
  	}

  	// requestToSend is one queued request. Its promise's channels double as
  	// the transport for the encoded request, so no extra channel is created.
  	// Field order matters: it is built positionally.
  	requestToSend struct {
  		response       responsePromise // promise whose packets channel carries the frame
  		expectResponse bool            // whether a reply should be awaited
  	}
  )
  27. func newBroker(host string, port int32) (b *broker, err error) {
  28. b = new(broker)
  29. b.id = -1 // don't know it yet
  30. b.host = &host
  31. b.port = port
  32. err = b.connect()
  33. if err != nil {
  34. return nil, err
  35. }
  36. return b, nil
  37. }
  38. func (b *broker) connect() (err error) {
  39. addr, err := net.ResolveIPAddr("ip", *b.host)
  40. if err != nil {
  41. return err
  42. }
  43. b.addr.IP = addr.IP
  44. b.addr.Zone = addr.Zone
  45. b.addr.Port = int(b.port)
  46. b.conn, err = net.DialTCP("tcp", nil, &b.addr)
  47. if err != nil {
  48. return err
  49. }
  50. b.requests = make(chan requestToSend)
  51. b.responses = make(chan responsePromise)
  52. go b.sendRequestLoop()
  53. go b.rcvResponseLoop()
  54. return nil
  55. }
  56. func (b *broker) forceDisconnect(reqRes *responsePromise, err error) {
  57. reqRes.errors <- err
  58. close(reqRes.errors)
  59. close(reqRes.packets)
  60. close(b.requests)
  61. close(b.responses)
  62. b.conn.Close()
  63. }
  64. func (b *broker) encode(pe packetEncoder) {
  65. pe.putInt32(b.id)
  66. pe.putString(b.host)
  67. pe.putInt32(b.port)
  68. }
  69. func (b *broker) decode(pd packetDecoder) (err error) {
  70. b.id, err = pd.getInt32()
  71. if err != nil {
  72. return err
  73. }
  74. b.host, err = pd.getString()
  75. if err != nil {
  76. return err
  77. }
  78. b.port, err = pd.getInt32()
  79. if err != nil {
  80. return err
  81. }
  82. if b.port > math.MaxUint16 {
  83. return DecodingError{"Broker port > 65536"}
  84. }
  85. err = b.connect()
  86. if err != nil {
  87. return err
  88. }
  89. return nil
  90. }
  91. func (b *broker) sendRequestLoop() {
  92. for request := range b.requests {
  93. buf := <-request.response.packets
  94. _, err := b.conn.Write(buf)
  95. if err != nil {
  96. b.forceDisconnect(&request.response, err)
  97. return
  98. }
  99. if request.expectResponse {
  100. b.responses <- request.response
  101. } else {
  102. close(request.response.packets)
  103. close(request.response.errors)
  104. }
  105. }
  106. }
  107. func (b *broker) rcvResponseLoop() {
  108. header := make([]byte, 8)
  109. for response := range b.responses {
  110. _, err := io.ReadFull(b.conn, header)
  111. if err != nil {
  112. b.forceDisconnect(&response, err)
  113. return
  114. }
  115. decoder := realDecoder{raw: header}
  116. length, _ := decoder.getInt32()
  117. if length <= 4 || length > 2*math.MaxUint16 {
  118. b.forceDisconnect(&response, DecodingError{})
  119. return
  120. }
  121. corr_id, _ := decoder.getInt32()
  122. if response.correlation_id != corr_id {
  123. b.forceDisconnect(&response, DecodingError{})
  124. return
  125. }
  126. buf := make([]byte, length-4)
  127. _, err = io.ReadFull(b.conn, buf)
  128. if err != nil {
  129. b.forceDisconnect(&response, err)
  130. return
  131. }
  132. response.packets <- buf
  133. close(response.packets)
  134. close(response.errors)
  135. }
  136. }
  137. func (b *broker) sendRequest(clientID *string, body requestEncoder) (*responsePromise, error) {
  138. var prepEnc prepEncoder
  139. var realEnc realEncoder
  140. req := request{b.correlation_id, clientID, body}
  141. req.encode(&prepEnc)
  142. if prepEnc.err != nil {
  143. return nil, prepEnc.err
  144. }
  145. realEnc.raw = make([]byte, prepEnc.length+4)
  146. realEnc.putInt32(int32(prepEnc.length))
  147. req.encode(&realEnc)
  148. request := requestToSend{responsePromise{b.correlation_id, make(chan []byte), make(chan error)}, body.expectResponse()}
  149. b.requests <- request
  150. request.response.packets <- realEnc.raw // we cheat to avoid poofing up more channels than necessary
  151. b.correlation_id++
  152. return &request.response, nil
  153. }
  154. // returns true if there was a response, even if there was an error decoding it (in
  155. // which case it will also return an error of some sort)
  156. func (b *broker) sendAndReceive(clientID *string, req requestEncoder, res decoder) (bool, error) {
  157. responseChan, err := b.sendRequest(clientID, req)
  158. if err != nil {
  159. return false, err
  160. }
  161. select {
  162. case buf := <-responseChan.packets:
  163. // Only try to decode if we got a response.
  164. if buf != nil {
  165. decoder := realDecoder{raw: buf}
  166. err = res.decode(&decoder)
  167. return true, err
  168. }
  169. case err = <-responseChan.errors:
  170. }
  171. return false, err
  172. }