Pārlūkot izejas kodu

Merge pull request #3757 from xiang90/race

rafthttp: fix data races detected by go race detector
Xiang Li 10 gadi atpakaļ
vecāks
revīzija
306dd7183b
3 mainītis faili ar 56 papildinājumiem un 37 dzēšanām
  1. 27 32
      rafthttp/pipeline.go
  2. 2 0
      rafthttp/pipeline_test.go
  3. 27 5
      rafthttp/stream_test.go

+ 27 - 32
rafthttp/pipeline.go

@@ -80,43 +80,45 @@ func newPipeline(tr http.RoundTripper, picker *urlPicker, from, to, cid types.ID
 }
 
 func (p *pipeline) stop() {
-	close(p.msgc)
 	close(p.stopc)
 	p.wg.Wait()
 }
 
 func (p *pipeline) handle() {
 	defer p.wg.Done()
-	for m := range p.msgc {
-		start := time.Now()
-		err := p.post(pbutil.MustMarshal(&m))
-		if err == errStopped {
-			return
-		}
-		end := time.Now()
 
-		if err != nil {
-			p.status.deactivate(failureType{source: pipelineMsg, action: "write"}, err.Error())
+	for {
+		select {
+		case m := <-p.msgc:
+			start := time.Now()
+			err := p.post(pbutil.MustMarshal(&m))
+			end := time.Now()
+
+			if err != nil {
+				p.status.deactivate(failureType{source: pipelineMsg, action: "write"}, err.Error())
+
+				reportSentFailure(pipelineMsg, m)
+				if m.Type == raftpb.MsgApp && p.fs != nil {
+					p.fs.Fail()
+				}
+				p.r.ReportUnreachable(m.To)
+				if isMsgSnap(m) {
+					p.r.ReportSnapshot(m.To, raft.SnapshotFailure)
+				}
+				continue
+			}
 
-			reportSentFailure(pipelineMsg, m)
+			p.status.activate()
 			if m.Type == raftpb.MsgApp && p.fs != nil {
-				p.fs.Fail()
+				p.fs.Succ(end.Sub(start))
 			}
-			p.r.ReportUnreachable(m.To)
 			if isMsgSnap(m) {
-				p.r.ReportSnapshot(m.To, raft.SnapshotFailure)
+				p.r.ReportSnapshot(m.To, raft.SnapshotFinish)
 			}
-			continue
-		}
-
-		p.status.activate()
-		if m.Type == raftpb.MsgApp && p.fs != nil {
-			p.fs.Succ(end.Sub(start))
-		}
-		if isMsgSnap(m) {
-			p.r.ReportSnapshot(m.To, raft.SnapshotFinish)
+			reportSentDuration(pipelineMsg, m, time.Since(start))
+		case <-p.stopc:
+			return
 		}
-		reportSentDuration(pipelineMsg, m, time.Since(start))
 	}
 }
 
@@ -126,13 +128,6 @@ func (p *pipeline) post(data []byte) (err error) {
 	u := p.picker.pick()
 	req := createPostRequest(u, RaftPrefix, bytes.NewBuffer(data), "application/protobuf", p.from, p.cid)
 
-	var stopped bool
-	defer func() {
-		if stopped {
-			// rewrite to errStopped so the caller goroutine can stop itself
-			err = errStopped
-		}
-	}()
 	done := make(chan struct{}, 1)
 	cancel := httputil.RequestCanceler(p.tr, req)
 	go func() {
@@ -140,7 +135,6 @@ func (p *pipeline) post(data []byte) (err error) {
 		case <-done:
 		case <-p.stopc:
 			waitSchedule()
-			stopped = true
 			cancel()
 		}
 	}()
@@ -165,6 +159,7 @@ func (p *pipeline) post(data []byte) (err error) {
 		reportCriticalError(err, p.errorc)
 		return nil
 	}
+
 	return err
 }
 

+ 2 - 0
rafthttp/pipeline_test.go

@@ -246,9 +246,11 @@ func newRoundTripperBlocker() *roundTripperBlocker {
 		cancel:   make(map[*http.Request]chan struct{}),
 	}
 }
+
 func (t *roundTripperBlocker) unblock() {
 	close(t.unblockc)
 }
+
 func (t *roundTripperBlocker) CancelRequest(req *http.Request) {
 	t.mu.Lock()
 	defer t.mu.Unlock()

+ 27 - 5
rafthttp/stream_test.go

@@ -20,6 +20,7 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"reflect"
+	"sync"
 	"testing"
 	"time"
 
@@ -49,8 +50,8 @@ func TestStreamWriterAttachOutgoingConn(t *testing.T) {
 		sw.attach(&outgoingConn{t: streamTypeMessage, Writer: wfc, Flusher: wfc, Closer: wfc})
 		testutil.WaitSchedule()
 		// previous attached connection should be closed
-		if prevwfc != nil && prevwfc.closed != true {
-			t.Errorf("#%d: close of previous connection = %v, want true", i, prevwfc.closed)
+		if prevwfc != nil && prevwfc.Closed() != true {
+			t.Errorf("#%d: close of previous connection = %v, want true", i, prevwfc.Closed())
 		}
 		// starts working
 		if _, ok := sw.writec(); ok != true {
@@ -63,7 +64,7 @@ func TestStreamWriterAttachOutgoingConn(t *testing.T) {
 		if _, ok := sw.writec(); ok != true {
 			t.Errorf("#%d: working status = %v, want true", i, ok)
 		}
-		if wfc.written == 0 {
+		if wfc.Written() == 0 {
 			t.Errorf("#%d: failed to write to the underlying connection", i)
 		}
 	}
@@ -73,7 +74,7 @@ func TestStreamWriterAttachOutgoingConn(t *testing.T) {
 	if _, ok := sw.writec(); ok != false {
 		t.Errorf("working status after stop = %v, want false", ok)
 	}
-	if wfc.closed != true {
+	if wfc.Closed() != true {
 		t.Errorf("failed to close the underlying connection")
 	}
 }
@@ -92,7 +93,7 @@ func TestStreamWriterAttachBadOutgoingConn(t *testing.T) {
 	if _, ok := sw.writec(); ok != false {
 		t.Errorf("working = %v, want false", ok)
 	}
-	if wfc.closed != true {
+	if wfc.Closed() != true {
 		t.Errorf("failed to close the underlying connection")
 	}
 }
@@ -297,21 +298,42 @@ func TestCheckStreamSupport(t *testing.T) {
 }
 
 type fakeWriteFlushCloser struct {
+	mu      sync.Mutex
 	err     error
 	written int
 	closed  bool
 }
 
 func (wfc *fakeWriteFlushCloser) Write(p []byte) (n int, err error) {
+	wfc.mu.Lock()
+	defer wfc.mu.Unlock()
+
 	wfc.written += len(p)
 	return len(p), wfc.err
 }
+
 func (wfc *fakeWriteFlushCloser) Flush() {}
+
 func (wfc *fakeWriteFlushCloser) Close() error {
+	wfc.mu.Lock()
+	defer wfc.mu.Unlock()
+
 	wfc.closed = true
 	return wfc.err
 }
 
+func (wfc *fakeWriteFlushCloser) Written() int {
+	wfc.mu.Lock()
+	defer wfc.mu.Unlock()
+	return wfc.written
+}
+
+func (wfc *fakeWriteFlushCloser) Closed() bool {
+	wfc.mu.Lock()
+	defer wfc.mu.Unlock()
+	return wfc.closed
+}
+
 type fakeStreamHandler struct {
 	t  streamType
 	sw *streamWriter