Pārlūkot izejas kodu

Merge pull request #6223 from heyitsanthony/fix-rafthttp-badoutgoing

rafthttp: remove WaitSchedule() from tests
Anthony Romano 9 gadi atpakaļ
vecāks
revīzija
f4b6ed2469
3 mainītis faili ar 91 papildinājumiem un 77 dzēšanām
  1. 29 32
      rafthttp/pipeline_test.go
  2. 11 5
      rafthttp/stream.go
  3. 51 40
      rafthttp/stream_test.go

+ 29 - 32
rafthttp/pipeline_test.go

@@ -34,18 +34,14 @@ import (
 // TestPipelineSend tests that pipeline could send data using roundtripper
 // and increase success count in stats.
 func TestPipelineSend(t *testing.T) {
-	tr := &roundTripperRecorder{}
+	tr := &roundTripperRecorder{rec: testutil.NewRecorderStream()}
 	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
 	tp := &Transport{pipelineRt: tr}
 	p := startTestPipeline(tp, picker)
 
 	p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
-	testutil.WaitSchedule()
+	tr.rec.Wait(1)
 	p.stop()
-
-	if tr.Request() == nil {
-		t.Errorf("sender fails to post the data")
-	}
 	if p.followerStats.Counts.Success != 1 {
 		t.Errorf("success = %d, want 1", p.followerStats.Counts.Success)
 	}
@@ -79,15 +75,12 @@ func TestPipelineExceedMaximumServing(t *testing.T) {
 
 	// keep the sender busy and make the buffer full
 	// nothing can go out as we block the sender
-	testutil.WaitSchedule()
 	for i := 0; i < connPerPipeline+pipelineBufSize; i++ {
 		select {
 		case p.msgc <- raftpb.Message{}:
-		default:
+		case <-time.After(10 * time.Millisecond):
 			t.Errorf("failed to send out message")
 		}
-		// force the sender to grab data
-		testutil.WaitSchedule()
 	}
 
 	// try to send a data when we are sure the buffer is full
@@ -99,12 +92,11 @@ func TestPipelineExceedMaximumServing(t *testing.T) {
 
 	// unblock the senders and force them to send out the data
 	tr.unblock()
-	testutil.WaitSchedule()
 
 	// It could send new data after previous ones succeed
 	select {
 	case p.msgc <- raftpb.Message{}:
-	default:
+	case <-time.After(10 * time.Millisecond):
 		t.Errorf("failed to send out message")
 	}
 }
@@ -113,11 +105,16 @@ func TestPipelineExceedMaximumServing(t *testing.T) {
 // it increases fail count in stats.
 func TestPipelineSendFailed(t *testing.T) {
 	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
-	tp := &Transport{pipelineRt: newRespRoundTripper(0, errors.New("blah"))}
+	rt := newRespRoundTripper(0, errors.New("blah"))
+	rt.rec = testutil.NewRecorderStream()
+	tp := &Transport{pipelineRt: rt}
 	p := startTestPipeline(tp, picker)
 
 	p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
-	testutil.WaitSchedule()
+	if _, err := rt.rec.Wait(1); err != nil {
+		t.Fatal(err)
+	}
+
 	p.stop()
 
 	if p.followerStats.Counts.Fail != 1 {
@@ -126,34 +123,40 @@ func TestPipelineSendFailed(t *testing.T) {
 }
 
 func TestPipelinePost(t *testing.T) {
-	tr := &roundTripperRecorder{}
+	tr := &roundTripperRecorder{rec: &testutil.RecorderBuffered{}}
 	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
 	tp := &Transport{ClusterID: types.ID(1), pipelineRt: tr}
 	p := startTestPipeline(tp, picker)
 	if err := p.post([]byte("some data")); err != nil {
 		t.Fatalf("unexpected post error: %v", err)
 	}
+	act, err := tr.rec.Wait(1)
+	if err != nil {
+		t.Fatal(err)
+	}
 	p.stop()
 
-	if g := tr.Request().Method; g != "POST" {
+	req := act[0].Params[0].(*http.Request)
+
+	if g := req.Method; g != "POST" {
 		t.Errorf("method = %s, want %s", g, "POST")
 	}
-	if g := tr.Request().URL.String(); g != "http://localhost:2380/raft" {
+	if g := req.URL.String(); g != "http://localhost:2380/raft" {
 		t.Errorf("url = %s, want %s", g, "http://localhost:2380/raft")
 	}
-	if g := tr.Request().Header.Get("Content-Type"); g != "application/protobuf" {
+	if g := req.Header.Get("Content-Type"); g != "application/protobuf" {
 		t.Errorf("content type = %s, want %s", g, "application/protobuf")
 	}
-	if g := tr.Request().Header.Get("X-Server-Version"); g != version.Version {
+	if g := req.Header.Get("X-Server-Version"); g != version.Version {
 		t.Errorf("version = %s, want %s", g, version.Version)
 	}
-	if g := tr.Request().Header.Get("X-Min-Cluster-Version"); g != version.MinClusterVersion {
+	if g := req.Header.Get("X-Min-Cluster-Version"); g != version.MinClusterVersion {
 		t.Errorf("min version = %s, want %s", g, version.MinClusterVersion)
 	}
-	if g := tr.Request().Header.Get("X-Etcd-Cluster-ID"); g != "1" {
+	if g := req.Header.Get("X-Etcd-Cluster-ID"); g != "1" {
 		t.Errorf("cluster id = %s, want %s", g, "1")
 	}
-	b, err := ioutil.ReadAll(tr.Request().Body)
+	b, err := ioutil.ReadAll(req.Body)
 	if err != nil {
 		t.Fatalf("unexpected ReadAll error: %v", err)
 	}
@@ -278,21 +281,15 @@ func (t *respRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
 }
 
 type roundTripperRecorder struct {
-	req *http.Request
-	sync.Mutex
+	rec testutil.Recorder
 }
 
 func (t *roundTripperRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
-	t.Lock()
-	defer t.Unlock()
-	t.req = req
+	if t.rec != nil {
+		t.rec.Record(testutil.Action{Name: "req", Params: []interface{}{req}})
+	}
 	return &http.Response{StatusCode: http.StatusNoContent, Body: &nopReadCloser{}}, nil
 }
-func (t *roundTripperRecorder) Request() *http.Request {
-	t.Lock()
-	defer t.Unlock()
-	return t.req
-}
 
 type nopReadCloser struct{}
 

+ 11 - 5
rafthttp/stream.go

@@ -186,9 +186,8 @@ func (cw *streamWriter) run() {
 			cw.r.ReportUnreachable(m.To)
 
 		case conn := <-cw.connc:
-			if cw.close() {
-				plog.Warningf("closed an existing TCP streaming connection with peer %s (%s writer)", cw.peerID, t)
-			}
+			cw.mu.Lock()
+			closed := cw.closeUnlocked()
 			t = conn.t
 			switch conn.t {
 			case streamTypeMsgAppV2:
@@ -200,19 +199,22 @@ func (cw *streamWriter) run() {
 			}
 			flusher = conn.Flusher
 			unflushed = 0
-			cw.mu.Lock()
 			cw.status.activate()
 			cw.closer = conn.Closer
 			cw.working = true
 			cw.mu.Unlock()
+
+			if closed {
+				plog.Warningf("closed an existing TCP streaming connection with peer %s (%s writer)", cw.peerID, t)
+			}
 			plog.Infof("established a TCP streaming connection with peer %s (%s writer)", cw.peerID, t)
 			heartbeatc, msgc = tickc, cw.msgc
 		case <-cw.stopc:
 			if cw.close() {
 				plog.Infof("closed the TCP streaming connection with peer %s (%s writer)", cw.peerID, t)
 			}
-			close(cw.done)
 			plog.Infof("stopped streaming with peer %s (writer)", cw.peerID)
+			close(cw.done)
 			return
 		}
 	}
@@ -227,6 +229,10 @@ func (cw *streamWriter) writec() (chan<- raftpb.Message, bool) {
 func (cw *streamWriter) close() bool {
 	cw.mu.Lock()
 	defer cw.mu.Unlock()
+	return cw.closeUnlocked()
+}
+
+func (cw *streamWriter) closeUnlocked() bool {
 	if !cw.working {
 		return false
 	}

+ 51 - 40
rafthttp/stream_test.go

@@ -47,41 +47,34 @@ func TestStreamWriterAttachOutgoingConn(t *testing.T) {
 	var wfc *fakeWriteFlushCloser
 	for i := 0; i < 3; i++ {
 		prevwfc := wfc
-		wfc = &fakeWriteFlushCloser{}
+		wfc = newFakeWriteFlushCloser(nil)
 		sw.attach(&outgoingConn{t: streamTypeMessage, Writer: wfc, Flusher: wfc, Closer: wfc})
 
-		// sw.attach happens asynchronously. Waits for its result in a for loop to make the
-		// test more robust on slow CI.
-		for j := 0; j < 3; j++ {
-			testutil.WaitSchedule()
-			// previous attached connection should be closed
-			if prevwfc != nil && !prevwfc.Closed() {
-				continue
-			}
-			// write chan is available
-			if _, ok := sw.writec(); !ok {
-				continue
+		// previous attached connection should be closed
+		if prevwfc != nil {
+			select {
+			case <-prevwfc.closed:
+			case <-time.After(time.Second):
+				t.Errorf("#%d: close of previous connection timed out", i)
 			}
 		}
 
-		// previous attached connection should be closed
-		if prevwfc != nil && !prevwfc.Closed() {
-			t.Errorf("#%d: close of previous connection = %v, want true", i, prevwfc.Closed())
-		}
-		// write chan is available
-		if _, ok := sw.writec(); !ok {
+		// msgc has been swapped with a new one now that prevwfc is closed
+		msgc, ok := sw.writec()
+		if !ok {
 			t.Errorf("#%d: working status = %v, want true", i, ok)
 		}
+		msgc <- raftpb.Message{}
 
-		sw.msgc <- raftpb.Message{}
-		testutil.WaitSchedule()
-		// write chan is available
+		select {
+		case <-wfc.writec:
+		case <-time.After(time.Second):
+			t.Errorf("#%d: failed to write to the underlying connection", i)
+		}
+		// write chan is still available
 		if _, ok := sw.writec(); !ok {
 			t.Errorf("#%d: working status = %v, want true", i, ok)
 		}
-		if wfc.Written() == 0 {
-			t.Errorf("#%d: failed to write to the underlying connection", i)
-		}
 	}
 
 	sw.stop()
@@ -99,23 +92,24 @@ func TestStreamWriterAttachOutgoingConn(t *testing.T) {
 func TestStreamWriterAttachBadOutgoingConn(t *testing.T) {
 	sw := startStreamWriter(types.ID(1), newPeerStatus(types.ID(1)), &stats.FollowerStats{}, &fakeRaft{})
 	defer sw.stop()
-	wfc := &fakeWriteFlushCloser{err: errors.New("blah")}
+	wfc := newFakeWriteFlushCloser(errors.New("blah"))
 	sw.attach(&outgoingConn{t: streamTypeMessage, Writer: wfc, Flusher: wfc, Closer: wfc})
 
 	sw.msgc <- raftpb.Message{}
-	testutil.WaitSchedule()
+	select {
+	case <-wfc.closed:
+	case <-time.After(time.Second):
+		t.Errorf("failed to close the underlying connection in time")
+	}
 	// no longer working
 	if _, ok := sw.writec(); ok {
 		t.Errorf("working = %v, want false", ok)
 	}
-	if !wfc.Closed() {
-		t.Errorf("failed to close the underlying connection")
-	}
 }
 
 func TestStreamReaderDialRequest(t *testing.T) {
 	for i, tt := range []streamType{streamTypeMessage, streamTypeMsgAppV2} {
-		tr := &roundTripperRecorder{}
+		tr := &roundTripperRecorder{rec: &testutil.RecorderBuffered{}}
 		sr := &streamReader{
 			peerID: types.ID(2),
 			tr:     &Transport{streamRt: tr, ClusterID: types.ID(1), ID: types.ID(1)},
@@ -123,7 +117,12 @@ func TestStreamReaderDialRequest(t *testing.T) {
 		}
 		sr.dial(tt)
 
-		req := tr.Request()
+		act, err := tr.rec.Wait(1)
+		if err != nil {
+			t.Fatal(err)
+		}
+		req := act[0].Params[0].(*http.Request)
+
 		wurl := fmt.Sprintf("http://localhost:2380" + tt.endpoint() + "/1")
 		if req.URL.String() != wurl {
 			t.Errorf("#%d: url = %s, want %s", i, req.URL.String(), wurl)
@@ -377,13 +376,25 @@ type fakeWriteFlushCloser struct {
 	mu      sync.Mutex
 	err     error
 	written int
-	closed  bool
+	closed  chan struct{}
+	writec  chan struct{}
+}
+
+func newFakeWriteFlushCloser(err error) *fakeWriteFlushCloser {
+	return &fakeWriteFlushCloser{
+		err:    err,
+		closed: make(chan struct{}),
+		writec: make(chan struct{}, 1),
+	}
 }
 
 func (wfc *fakeWriteFlushCloser) Write(p []byte) (n int, err error) {
 	wfc.mu.Lock()
 	defer wfc.mu.Unlock()
-
+	select {
+	case wfc.writec <- struct{}{}:
+	default:
+	}
 	wfc.written += len(p)
 	return len(p), wfc.err
 }
@@ -391,10 +402,7 @@ func (wfc *fakeWriteFlushCloser) Write(p []byte) (n int, err error) {
 func (wfc *fakeWriteFlushCloser) Flush() {}
 
 func (wfc *fakeWriteFlushCloser) Close() error {
-	wfc.mu.Lock()
-	defer wfc.mu.Unlock()
-
-	wfc.closed = true
+	close(wfc.closed)
 	return wfc.err
 }
 
@@ -405,9 +413,12 @@ func (wfc *fakeWriteFlushCloser) Written() int {
 }
 
 func (wfc *fakeWriteFlushCloser) Closed() bool {
-	wfc.mu.Lock()
-	defer wfc.mu.Unlock()
-	return wfc.closed
+	select {
+	case <-wfc.closed:
+		return true
+	default:
+		return false
+	}
 }
 
 type fakeStreamHandler struct {