mockbroker.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. package sarama
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net"
  9. "reflect"
  10. "strconv"
  11. "sync"
  12. "time"
  13. "github.com/davecgh/go-spew/spew"
  14. )
  15. const (
  16. expectationTimeout = 500 * time.Millisecond
  17. )
  18. type requestHandlerFunc func(req *request) (res encoder)
  19. // MockBroker is a mock Kafka broker that is used in unit tests. It is exposed
  20. // to facilitate testing of higher level or specialized consumers and producers
  21. // built on top of Sarama. Note that it does not 'mimic' the Kafka API protocol,
  22. // but rather provides a facility to do that. It takes care of the TCP
  23. // transport, request unmarshaling, response marshaling, and makes it the test
  24. // writer responsibility to program correct according to the Kafka API protocol
  25. // MockBroker behaviour.
  26. //
  27. // MockBroker is implemented as a TCP server listening on a kernel-selected
  28. // localhost port that can accept many connections. It reads Kafka requests
  29. // from that connection and returns responses programmed by the SetHandlerByMap
  30. // function. If a MockBroker receives a request that it has no programmed
  31. // response for, then it returns nothing and the request times out.
  32. //
  33. // A set of MockRequest builders to define mappings used by MockBroker is
  34. // provided by Sarama. But users can develop MockRequests of their own and use
  35. // them along with or instead of the standard ones.
  36. //
  37. // When running tests with MockBroker it is strongly recommended to specify
  38. // a timeout to `go test` so that if the broker hangs waiting for a response,
  39. // the test panics.
  40. //
  41. // It is not necessary to prefix message length or correlation ID to your
  42. // response bytes, the server does that automatically as a convenience.
  43. type MockBroker struct {
  44. brokerID int32
  45. port int32
  46. closing chan none
  47. stopper chan none
  48. expectations chan encoder
  49. done sync.WaitGroup
  50. listener net.Listener
  51. t TestReporter
  52. latency time.Duration
  53. handler requestHandlerFunc
  54. origHandler bool
  55. history []*RequestResponse
  56. lock sync.Mutex
  57. }
  58. // RequestResponse represents a Request/Response pair processed by MockBroker.
  59. type RequestResponse struct {
  60. Request protocolBody
  61. Response encoder
  62. RequestSize int
  63. ResponseSize int
  64. }
  65. // SetLatency makes broker pause for the specified period every time before
  66. // replying.
  67. func (b *MockBroker) SetLatency(latency time.Duration) {
  68. b.latency = latency
  69. }
  70. // SetHandlerByMap defines mapping of Request types to MockResponses. When a
  71. // request is received by the broker, it looks up the request type in the map
  72. // and uses the found MockResponse instance to generate an appropriate reply.
  73. // If the request type is not found in the map then nothing is sent.
  74. func (b *MockBroker) SetHandlerByMap(handlerMap map[string]MockResponse) {
  75. b.setHandler(func(req *request) (res encoder) {
  76. reqTypeName := reflect.TypeOf(req.body).Elem().Name()
  77. mockResponse := handlerMap[reqTypeName]
  78. if mockResponse == nil {
  79. return nil
  80. }
  81. return mockResponse.For(req.body)
  82. })
  83. }
  84. // BrokerID returns broker ID assigned to the broker.
  85. func (b *MockBroker) BrokerID() int32 {
  86. return b.brokerID
  87. }
  88. // History returns a slice of RequestResponse pairs in the order they were
  89. // processed by the broker. Note that in case of multiple connections to the
  90. // broker the order expected by a test can be different from the order recorded
  91. // in the history, unless some synchronization is implemented in the test.
  92. func (b *MockBroker) History() []RequestResponse {
  93. b.lock.Lock()
  94. history := make([]RequestResponse, len(b.history))
  95. for i, rr := range b.history {
  96. history[i] = *rr
  97. }
  98. b.lock.Unlock()
  99. return history
  100. }
  101. // Port returns the TCP port number the broker is listening for requests on.
  102. func (b *MockBroker) Port() int32 {
  103. return b.port
  104. }
  105. // Addr returns the broker connection string in the form "<address>:<port>".
  106. func (b *MockBroker) Addr() string {
  107. return b.listener.Addr().String()
  108. }
  109. // Wait for the remaining expectations to be consumed or that the timeout expires
  110. func (b *MockBroker) WaitForExpectations(timeout time.Duration) error {
  111. c := make(chan none)
  112. go func() {
  113. b.done.Wait()
  114. close(c)
  115. }()
  116. select {
  117. case <-c:
  118. return nil
  119. case <-time.After(timeout):
  120. return errors.New(fmt.Sprintf("Not all expectations have been honoured after %v", timeout))
  121. }
  122. }
  123. // Close terminates the broker blocking until it stops internal goroutines and
  124. // releases all resources.
  125. func (b *MockBroker) Close() {
  126. close(b.expectations)
  127. if len(b.expectations) > 0 {
  128. buf := bytes.NewBufferString(fmt.Sprintf("mockbroker/%d: not all expectations were satisfied! Still waiting on:\n", b.BrokerID()))
  129. for e := range b.expectations {
  130. _, _ = buf.WriteString(spew.Sdump(e))
  131. }
  132. b.t.Error(buf.String())
  133. }
  134. close(b.closing)
  135. <-b.stopper
  136. }
  137. // setHandler sets the specified function as the request handler. Whenever
  138. // a mock broker reads a request from the wire it passes the request to the
  139. // function and sends back whatever the handler function returns.
  140. func (b *MockBroker) setHandler(handler requestHandlerFunc) {
  141. b.lock.Lock()
  142. b.handler = handler
  143. b.origHandler = false
  144. b.lock.Unlock()
  145. }
  146. func (b *MockBroker) serverLoop() {
  147. defer close(b.stopper)
  148. var err error
  149. var conn net.Conn
  150. go func() {
  151. <-b.closing
  152. err := b.listener.Close()
  153. if err != nil {
  154. b.t.Error(err)
  155. }
  156. }()
  157. wg := &sync.WaitGroup{}
  158. i := 0
  159. for conn, err = b.listener.Accept(); err == nil; conn, err = b.listener.Accept() {
  160. wg.Add(1)
  161. go b.handleRequests(conn, i, wg)
  162. i++
  163. }
  164. wg.Wait()
  165. Logger.Printf("*** mockbroker/%d: listener closed, err=%v", b.BrokerID(), err)
  166. }
  167. func (b *MockBroker) handleRequests(conn net.Conn, idx int, wg *sync.WaitGroup) {
  168. defer wg.Done()
  169. defer func() {
  170. _ = conn.Close()
  171. }()
  172. Logger.Printf("*** mockbroker/%d/%d: connection opened", b.BrokerID(), idx)
  173. var err error
  174. abort := make(chan none)
  175. defer close(abort)
  176. go func() {
  177. select {
  178. case <-b.closing:
  179. _ = conn.Close()
  180. case <-abort:
  181. }
  182. }()
  183. resHeader := make([]byte, 8)
  184. for {
  185. req, bytesRead, err := decodeRequest(conn)
  186. if err != nil {
  187. Logger.Printf("*** mockbroker/%d/%d: invalid request: err=%+v, %+v", b.brokerID, idx, err, spew.Sdump(req))
  188. b.serverError(err)
  189. break
  190. }
  191. if b.latency > 0 {
  192. time.Sleep(b.latency)
  193. }
  194. b.lock.Lock()
  195. originalHandlerUsed := b.origHandler
  196. res := b.handler(req)
  197. requestResponse := RequestResponse{req.body, res, bytesRead, 0}
  198. b.history = append(b.history, &requestResponse)
  199. b.lock.Unlock()
  200. if res == nil {
  201. Logger.Printf("*** mockbroker/%d/%d: ignored %v", b.brokerID, idx, spew.Sdump(req))
  202. continue
  203. }
  204. Logger.Printf("*** mockbroker/%d/%d: served %v -> %v", b.brokerID, idx, req, res)
  205. encodedRes, err := encode(res)
  206. if err != nil {
  207. b.serverError(err)
  208. break
  209. }
  210. if len(encodedRes) != 0 {
  211. binary.BigEndian.PutUint32(resHeader, uint32(len(encodedRes)+4))
  212. binary.BigEndian.PutUint32(resHeader[4:], uint32(req.correlationID))
  213. if _, err = conn.Write(resHeader); err != nil {
  214. b.serverError(err)
  215. break
  216. }
  217. if _, err = conn.Write(encodedRes); err != nil {
  218. b.serverError(err)
  219. break
  220. }
  221. b.lock.Lock()
  222. requestResponse.ResponseSize = len(resHeader) + len(encodedRes)
  223. b.lock.Unlock()
  224. }
  225. // Prevent negative wait group in case we are using a custom handler
  226. if originalHandlerUsed {
  227. b.done.Done()
  228. }
  229. }
  230. Logger.Printf("*** mockbroker/%d/%d: connection closed, err=%v", b.BrokerID(), idx, err)
  231. }
  232. func (b *MockBroker) defaultRequestHandler(req *request) (res encoder) {
  233. select {
  234. case res, ok := <-b.expectations:
  235. if !ok {
  236. return nil
  237. }
  238. return res
  239. case <-time.After(expectationTimeout):
  240. return nil
  241. }
  242. }
  243. func (b *MockBroker) serverError(err error) {
  244. isConnectionClosedError := false
  245. if _, ok := err.(*net.OpError); ok {
  246. isConnectionClosedError = true
  247. } else if err == io.EOF {
  248. isConnectionClosedError = true
  249. } else if err.Error() == "use of closed network connection" {
  250. isConnectionClosedError = true
  251. }
  252. if isConnectionClosedError {
  253. return
  254. }
  255. b.t.Errorf(err.Error())
  256. }
  257. // NewMockBroker launches a fake Kafka broker. It takes a TestReporter as provided by the
  258. // test framework and a channel of responses to use. If an error occurs it is
  259. // simply logged to the TestReporter and the broker exits.
  260. func NewMockBroker(t TestReporter, brokerID int32) *MockBroker {
  261. return NewMockBrokerAddr(t, brokerID, "localhost:0")
  262. }
  263. // NewMockBrokerAddr behaves like newMockBroker but listens on the address you give
  264. // it rather than just some ephemeral port.
  265. func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker {
  266. var err error
  267. broker := &MockBroker{
  268. closing: make(chan none),
  269. stopper: make(chan none),
  270. t: t,
  271. brokerID: brokerID,
  272. expectations: make(chan encoder, 512),
  273. done: sync.WaitGroup{},
  274. }
  275. broker.handler = broker.defaultRequestHandler
  276. broker.origHandler = true
  277. broker.listener, err = net.Listen("tcp", addr)
  278. if err != nil {
  279. t.Fatal(err)
  280. }
  281. Logger.Printf("*** mockbroker/%d listening on %s\n", brokerID, broker.listener.Addr().String())
  282. _, portStr, err := net.SplitHostPort(broker.listener.Addr().String())
  283. if err != nil {
  284. t.Fatal(err)
  285. }
  286. tmp, err := strconv.ParseInt(portStr, 10, 32)
  287. if err != nil {
  288. t.Fatal(err)
  289. }
  290. broker.port = int32(tmp)
  291. go broker.serverLoop()
  292. return broker
  293. }
  294. func (b *MockBroker) Returns(e encoder) {
  295. b.done.Add(1)
  296. b.expectations <- e
  297. }