Bläddra i källkod

Merge pull request #6175 from heyitsanthony/fix-conn-race

rafthttp: fix race between streamReader.stop() and connection closer
Anthony Romano 9 år sedan
förälder
incheckning
9eb6ea34bd
2 ändrade filer med 65 tillägg och 1 borttagningar
  1. 10 1
      rafthttp/stream.go
  2. 55 0
      rafthttp/stream_test.go

+ 10 - 1
rafthttp/stream.go

@@ -332,7 +332,16 @@ func (cr *streamReader) decodeLoop(rc io.ReadCloser, t streamType) error {
 	default:
 		plog.Panicf("unhandled stream type %s", t)
 	}
-	cr.closer = rc
+	select {
+	case <-cr.stopc:
+		cr.mu.Unlock()
+		if err := rc.Close(); err != nil {
+			return err
+		}
+		return io.EOF
+	default:
+		cr.closer = rc
+	}
 	cr.mu.Unlock()
 
 	for {

+ 55 - 0
rafthttp/stream_test.go

@@ -17,6 +17,7 @@ package rafthttp
 import (
 	"errors"
 	"fmt"
+	"io"
 	"net/http"
 	"net/http/httptest"
 	"reflect"
@@ -180,6 +181,60 @@ func TestStreamReaderDialResult(t *testing.T) {
 	}
 }
 
+// TestStreamReaderStopOnDial tests a stream reader closes the connection on stop.
+func TestStreamReaderStopOnDial(t *testing.T) {
+	defer testutil.AfterTest(t)
+	h := http.Header{}
+	h.Add("X-Server-Version", version.Version)
+	tr := &respWaitRoundTripper{rrt: &respRoundTripper{code: http.StatusOK, header: h}}
+	sr := &streamReader{
+		peerID: types.ID(2),
+		tr:     &Transport{streamRt: tr, ClusterID: types.ID(1)},
+		picker: mustNewURLPicker(t, []string{"http://localhost:2380"}),
+		errorc: make(chan error, 1),
+		typ:    streamTypeMessage,
+		status: newPeerStatus(types.ID(2)),
+	}
+	tr.onResp = func() {
+		// stop() waits for the run() goroutine to exit, but that exit
+		// needs a response from RoundTrip() first; use goroutine
+		go sr.stop()
+		// wait so that stop() is blocked on run() exiting
+		time.Sleep(10 * time.Millisecond)
+		// sr.run() completes dialing then begins decoding while stopped
+	}
+	sr.start()
+	select {
+	case <-sr.done:
+	case <-time.After(time.Second):
+		t.Fatal("streamReader did not stop in time")
+	}
+}
+
+type respWaitRoundTripper struct {
+	rrt    *respRoundTripper
+	onResp func()
+}
+
+func (t *respWaitRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	resp, err := t.rrt.RoundTrip(req)
+	resp.Body = newWaitReadCloser()
+	t.onResp()
+	return resp, err
+}
+
+type waitReadCloser struct{ closec chan struct{} }
+
+func newWaitReadCloser() *waitReadCloser { return &waitReadCloser{make(chan struct{})} }
+func (wrc *waitReadCloser) Read(p []byte) (int, error) {
+	<-wrc.closec
+	return 0, io.EOF
+}
+func (wrc *waitReadCloser) Close() error {
+	close(wrc.closec)
+	return nil
+}
+
 // TestStreamReaderDialDetectUnsupport tests that dial func could find
 // out that the stream type is not supported by the remote.
 func TestStreamReaderDialDetectUnsupport(t *testing.T) {