|
@@ -3,6 +3,7 @@ package sarama
|
|
|
import (
|
|
import (
|
|
|
"bytes"
|
|
"bytes"
|
|
|
"encoding/binary"
|
|
"encoding/binary"
|
|
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
"fmt"
|
|
|
"io"
|
|
"io"
|
|
|
"net"
|
|
"net"
|
|
@@ -50,10 +51,12 @@ type MockBroker struct {
|
|
|
closing chan none
|
|
closing chan none
|
|
|
stopper chan none
|
|
stopper chan none
|
|
|
expectations chan encoder
|
|
expectations chan encoder
|
|
|
|
|
+ done sync.WaitGroup
|
|
|
listener net.Listener
|
|
listener net.Listener
|
|
|
t TestReporter
|
|
t TestReporter
|
|
|
latency time.Duration
|
|
latency time.Duration
|
|
|
handler requestHandlerFunc
|
|
handler requestHandlerFunc
|
|
|
|
|
+ origHandler bool
|
|
|
history []*RequestResponse
|
|
history []*RequestResponse
|
|
|
lock sync.Mutex
|
|
lock sync.Mutex
|
|
|
}
|
|
}
|
|
@@ -116,6 +119,21 @@ func (b *MockBroker) Addr() string {
|
|
|
return b.listener.Addr().String()
|
|
return b.listener.Addr().String()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+// Wait for the remaining expectations to be consumed or that the timeout expires
|
|
|
|
|
+func (b *MockBroker) WaitForExpectations(timeout time.Duration) error {
|
|
|
|
|
+ c := make(chan none)
|
|
|
|
|
+ go func() {
|
|
|
|
|
+ b.done.Wait()
|
|
|
|
|
+ close(c)
|
|
|
|
|
+ }()
|
|
|
|
|
+ select {
|
|
|
|
|
+ case <-c:
|
|
|
|
|
+ return nil
|
|
|
|
|
+ case <-time.After(timeout):
|
|
|
|
|
+ return errors.New(fmt.Sprintf("Not all expectations have been honoured after %v", timeout))
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
// Close terminates the broker blocking until it stops internal goroutines and
|
|
// Close terminates the broker blocking until it stops internal goroutines and
|
|
|
// releases all resources.
|
|
// releases all resources.
|
|
|
func (b *MockBroker) Close() {
|
|
func (b *MockBroker) Close() {
|
|
@@ -137,6 +155,7 @@ func (b *MockBroker) Close() {
|
|
|
func (b *MockBroker) setHandler(handler requestHandlerFunc) {
|
|
func (b *MockBroker) setHandler(handler requestHandlerFunc) {
|
|
|
b.lock.Lock()
|
|
b.lock.Lock()
|
|
|
b.handler = handler
|
|
b.handler = handler
|
|
|
|
|
+ b.origHandler = false
|
|
|
b.lock.Unlock()
|
|
b.lock.Unlock()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -196,6 +215,7 @@ func (b *MockBroker) handleRequests(conn net.Conn, idx int, wg *sync.WaitGroup)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
b.lock.Lock()
|
|
b.lock.Lock()
|
|
|
|
|
+ originalHandlerUsed := b.origHandler
|
|
|
res := b.handler(req)
|
|
res := b.handler(req)
|
|
|
requestResponse := RequestResponse{req.body, res, bytesRead, 0}
|
|
requestResponse := RequestResponse{req.body, res, bytesRead, 0}
|
|
|
b.history = append(b.history, &requestResponse)
|
|
b.history = append(b.history, &requestResponse)
|
|
@@ -212,23 +232,25 @@ func (b *MockBroker) handleRequests(conn net.Conn, idx int, wg *sync.WaitGroup)
|
|
|
b.serverError(err)
|
|
b.serverError(err)
|
|
|
break
|
|
break
|
|
|
}
|
|
}
|
|
|
- if len(encodedRes) == 0 {
|
|
|
|
|
- continue
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- binary.BigEndian.PutUint32(resHeader, uint32(len(encodedRes)+4))
|
|
|
|
|
- binary.BigEndian.PutUint32(resHeader[4:], uint32(req.correlationID))
|
|
|
|
|
- if _, err = conn.Write(resHeader); err != nil {
|
|
|
|
|
- b.serverError(err)
|
|
|
|
|
- break
|
|
|
|
|
|
|
+ if len(encodedRes) != 0 {
|
|
|
|
|
+ binary.BigEndian.PutUint32(resHeader, uint32(len(encodedRes)+4))
|
|
|
|
|
+ binary.BigEndian.PutUint32(resHeader[4:], uint32(req.correlationID))
|
|
|
|
|
+ if _, err = conn.Write(resHeader); err != nil {
|
|
|
|
|
+ b.serverError(err)
|
|
|
|
|
+ break
|
|
|
|
|
+ }
|
|
|
|
|
+ if _, err = conn.Write(encodedRes); err != nil {
|
|
|
|
|
+ b.serverError(err)
|
|
|
|
|
+ break
|
|
|
|
|
+ }
|
|
|
|
|
+ b.lock.Lock()
|
|
|
|
|
+ requestResponse.ResponseSize = len(resHeader) + len(encodedRes)
|
|
|
|
|
+ b.lock.Unlock()
|
|
|
}
|
|
}
|
|
|
- if _, err = conn.Write(encodedRes); err != nil {
|
|
|
|
|
- b.serverError(err)
|
|
|
|
|
- break
|
|
|
|
|
|
|
+ // Prevent negative wait group in case we are using a custom handler
|
|
|
|
|
+ if originalHandlerUsed {
|
|
|
|
|
+ b.done.Done()
|
|
|
}
|
|
}
|
|
|
- b.lock.Lock()
|
|
|
|
|
- requestResponse.ResponseSize = len(resHeader) + len(encodedRes)
|
|
|
|
|
- b.lock.Unlock()
|
|
|
|
|
}
|
|
}
|
|
|
Logger.Printf("*** mockbroker/%d/%d: connection closed, err=%v", b.BrokerID(), idx, err)
|
|
Logger.Printf("*** mockbroker/%d/%d: connection closed, err=%v", b.BrokerID(), idx, err)
|
|
|
}
|
|
}
|
|
@@ -280,8 +302,10 @@ func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker
|
|
|
t: t,
|
|
t: t,
|
|
|
brokerID: brokerID,
|
|
brokerID: brokerID,
|
|
|
expectations: make(chan encoder, 512),
|
|
expectations: make(chan encoder, 512),
|
|
|
|
|
+ done: sync.WaitGroup{},
|
|
|
}
|
|
}
|
|
|
broker.handler = broker.defaultRequestHandler
|
|
broker.handler = broker.defaultRequestHandler
|
|
|
|
|
+ broker.origHandler = true
|
|
|
|
|
|
|
|
broker.listener, err = net.Listen("tcp", addr)
|
|
broker.listener, err = net.Listen("tcp", addr)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
@@ -304,5 +328,6 @@ func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (b *MockBroker) Returns(e encoder) {
|
|
func (b *MockBroker) Returns(e encoder) {
|
|
|
|
|
+ b.done.Add(1)
|
|
|
b.expectations <- e
|
|
b.expectations <- e
|
|
|
}
|
|
}
|