Browse Source

Fix race condition with graceful shutdown of MockBroker

- use Logger instead of Logf in unit test to stay consistent
- add MmockBroker.WaitForExpectations for graceful shutdown
Sebastien Launay 9 years ago
parent
commit
b7f401f9d4
2 changed files with 48 additions and 20 deletions
  1. 8 5
      broker_test.go
  2. 40 15
      mockbroker.go

+ 8 - 5
broker_test.go

@@ -3,6 +3,7 @@ package sarama
 import (
 	"fmt"
 	"testing"
+	"time"
 
 	"github.com/rcrowley/go-metrics"
 )
@@ -55,12 +56,9 @@ func TestBrokerAccessors(t *testing.T) {
 
 func TestSimpleBrokerCommunication(t *testing.T) {
 	for _, tt := range brokerTestTable {
-		t.Log("Testing broker communication for", tt.name)
+		Logger.Printf("Testing broker communication for %s", tt.name)
 		mb := NewMockBroker(t, 0)
-		// Do not add expectation for ProduceRequest (No Response)
-		if len(tt.response) != 0 {
-			mb.Returns(&mockEncoder{tt.response})
-		}
+		mb.Returns(&mockEncoder{tt.response})
 		broker := NewBroker(mb.Addr())
 		// Set the broker id in order to validate local broker metrics
 		broker.id = 0
@@ -77,6 +75,11 @@ func TestSimpleBrokerCommunication(t *testing.T) {
 		if err != nil {
 			t.Error(err)
 		}
+		// Wait up to 500 ms for the remote broker to process requests
+		// in order to have consistent metrics
+		if err := mb.WaitForExpectations(500 * time.Millisecond); err != nil {
+			t.Error(err)
+		}
 		mb.Close()
 		validateBrokerMetrics(t, broker, mb)
 	}

+ 40 - 15
mockbroker.go

@@ -3,6 +3,7 @@ package sarama
 import (
 	"bytes"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -50,10 +51,12 @@ type MockBroker struct {
 	closing      chan none
 	stopper      chan none
 	expectations chan encoder
+	done         sync.WaitGroup
 	listener     net.Listener
 	t            TestReporter
 	latency      time.Duration
 	handler      requestHandlerFunc
+	origHandler  bool
 	history      []*RequestResponse
 	lock         sync.Mutex
 }
@@ -116,6 +119,21 @@ func (b *MockBroker) Addr() string {
 	return b.listener.Addr().String()
 }
 
+// Wait for the remaining expectations to be consumed or that the timeout expires
+func (b *MockBroker) WaitForExpectations(timeout time.Duration) error {
+	c := make(chan none)
+	go func() {
+		b.done.Wait()
+		close(c)
+	}()
+	select {
+	case <-c:
+		return nil
+	case <-time.After(timeout):
+		return errors.New(fmt.Sprintf("Not all expectations have been honoured after %v", timeout))
+	}
+}
+
 // Close terminates the broker blocking until it stops internal goroutines and
 // releases all resources.
 func (b *MockBroker) Close() {
@@ -137,6 +155,7 @@ func (b *MockBroker) Close() {
 func (b *MockBroker) setHandler(handler requestHandlerFunc) {
 	b.lock.Lock()
 	b.handler = handler
+	b.origHandler = false
 	b.lock.Unlock()
 }
 
@@ -196,6 +215,7 @@ func (b *MockBroker) handleRequests(conn net.Conn, idx int, wg *sync.WaitGroup)
 		}
 
 		b.lock.Lock()
+		originalHandlerUsed := b.origHandler
 		res := b.handler(req)
 		requestResponse := RequestResponse{req.body, res, bytesRead, 0}
 		b.history = append(b.history, &requestResponse)
@@ -212,23 +232,25 @@ func (b *MockBroker) handleRequests(conn net.Conn, idx int, wg *sync.WaitGroup)
 			b.serverError(err)
 			break
 		}
-		if len(encodedRes) == 0 {
-			continue
-		}
-
-		binary.BigEndian.PutUint32(resHeader, uint32(len(encodedRes)+4))
-		binary.BigEndian.PutUint32(resHeader[4:], uint32(req.correlationID))
-		if _, err = conn.Write(resHeader); err != nil {
-			b.serverError(err)
-			break
+		if len(encodedRes) != 0 {
+			binary.BigEndian.PutUint32(resHeader, uint32(len(encodedRes)+4))
+			binary.BigEndian.PutUint32(resHeader[4:], uint32(req.correlationID))
+			if _, err = conn.Write(resHeader); err != nil {
+				b.serverError(err)
+				break
+			}
+			if _, err = conn.Write(encodedRes); err != nil {
+				b.serverError(err)
+				break
+			}
+			b.lock.Lock()
+			requestResponse.ResponseSize = len(resHeader) + len(encodedRes)
+			b.lock.Unlock()
 		}
-		if _, err = conn.Write(encodedRes); err != nil {
-			b.serverError(err)
-			break
+		// Prevent negative wait group in case we are using a custom handler
+		if originalHandlerUsed {
+			b.done.Done()
 		}
-		b.lock.Lock()
-		requestResponse.ResponseSize = len(resHeader) + len(encodedRes)
-		b.lock.Unlock()
 	}
 	Logger.Printf("*** mockbroker/%d/%d: connection closed, err=%v", b.BrokerID(), idx, err)
 }
@@ -280,8 +302,10 @@ func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker
 		t:            t,
 		brokerID:     brokerID,
 		expectations: make(chan encoder, 512),
+		done:         sync.WaitGroup{},
 	}
 	broker.handler = broker.defaultRequestHandler
+	broker.origHandler = true
 
 	broker.listener, err = net.Listen("tcp", addr)
 	if err != nil {
@@ -304,5 +328,6 @@ func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker
 }
 
 func (b *MockBroker) Returns(e encoder) {
+	b.done.Add(1)
 	b.expectations <- e
 }