Browse Source

Merge pull request #2957 from yichengq/fix-pipeline-test

rafthttp: fix TestStopBlockedPipeline
Yicheng Qin 10 years ago
parent
commit
e46fa0a213
2 changed files with 24 additions and 15 deletions
  1. 4 0
      rafthttp/pipeline.go
  2. 20 15
      rafthttp/pipeline_test.go

+ 4 - 0
rafthttp/pipeline.go

@@ -153,6 +153,7 @@ func (p *pipeline) post(data []byte) (err error) {
 		select {
 		case <-done:
 		case <-p.stopc:
+			waitSchedule()
 			stopped = true
 			if cancel, ok := p.tr.(canceler); ok {
 				cancel.CancelRequest(req)
@@ -199,3 +200,6 @@ func (p *pipeline) post(data []byte) (err error) {
 		return fmt.Errorf("unexpected http status %s while posting to %q", http.StatusText(resp.StatusCode), req.URL.String())
 	}
 }
+
+// waitSchedule waits other goroutines to be scheduled for a while
+func waitSchedule() { time.Sleep(time.Millisecond) }

+ 20 - 15
rafthttp/pipeline_test.go

@@ -212,34 +212,39 @@ func TestStopBlockedPipeline(t *testing.T) {
 }
 
 type roundTripperBlocker struct {
-	c         chan error
-	mu        sync.Mutex
-	unblocked bool
+	unblockc chan struct{}
+	mu       sync.Mutex
+	cancel   map[*http.Request]chan struct{}
 }
 
 func newRoundTripperBlocker() *roundTripperBlocker {
-	return &roundTripperBlocker{c: make(chan error)}
+	return &roundTripperBlocker{
+		unblockc: make(chan struct{}),
+		cancel:   make(map[*http.Request]chan struct{}),
+	}
 }
 func (t *roundTripperBlocker) RoundTrip(req *http.Request) (*http.Response, error) {
-	err := <-t.c
-	if err != nil {
-		return nil, err
+	c := make(chan struct{}, 1)
+	t.mu.Lock()
+	t.cancel[req] = c
+	t.mu.Unlock()
+	select {
+	case <-t.unblockc:
+		return &http.Response{StatusCode: http.StatusNoContent, Body: &nopReadCloser{}}, nil
+	case <-c:
+		return nil, errors.New("request canceled")
 	}
-	return &http.Response{StatusCode: http.StatusNoContent, Body: &nopReadCloser{}}, nil
 }
 func (t *roundTripperBlocker) unblock() {
-	t.mu.Lock()
-	t.unblocked = true
-	t.mu.Unlock()
-	close(t.c)
+	close(t.unblockc)
 }
 func (t *roundTripperBlocker) CancelRequest(req *http.Request) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
-	if t.unblocked {
-		return
+	if c, ok := t.cancel[req]; ok {
+		c <- struct{}{}
+		delete(t.cancel, req)
 	}
-	t.c <- errors.New("request canceled")
 }
 
 type respRoundTripper struct {