Browse Source

Merge pull request #4542 from xiang90/t

rafthttp: refactoring
Xiang Li 9 years ago
parent
commit
a4105b5cce
2 changed files with 33 additions and 24 deletions
  1. 31 24
      rafthttp/stream.go
  2. 2 0
      rafthttp/transport.go

+ 31 - 24
rafthttp/stream.go

@@ -95,8 +95,7 @@ type outgoingConn struct {
 	io.Closer
 }
 
-// streamWriter is a long-running go-routine that writes messages into the
-// attached outgoingConn.
+// streamWriter writes messages to the attached outgoingConn.
 type streamWriter struct {
 	id     types.ID
 	status *peerStatus
@@ -113,6 +112,8 @@ type streamWriter struct {
 	done  chan struct{}
 }
 
+// startStreamWriter creates a streamWrite and starts a long running go-routine that accepts
+// messages and writes to the attached outgoing connection.
 func startStreamWriter(id types.ID, status *peerStatus, fs *stats.FollowerStats, r Raft) *streamWriter {
 	w := &streamWriter{
 		id:     id,
@@ -129,40 +130,46 @@ func startStreamWriter(id types.ID, status *peerStatus, fs *stats.FollowerStats,
 }
 
 func (cw *streamWriter) run() {
-	var msgc chan raftpb.Message
-	var heartbeatc <-chan time.Time
-	var t streamType
-	var enc encoder
-	var flusher http.Flusher
+	var (
+		msgc       chan raftpb.Message
+		heartbeatc <-chan time.Time
+		t          streamType
+		enc        encoder
+		flusher    http.Flusher
+	)
 	tickc := time.Tick(ConnReadTimeout / 3)
 
 	for {
 		select {
 		case <-heartbeatc:
 			start := time.Now()
-			if err := enc.encode(linkHeartbeatMessage); err != nil {
-				reportSentFailure(string(t), linkHeartbeatMessage)
-
-				cw.status.deactivate(failureType{source: t.String(), action: "heartbeat"}, err.Error())
-				cw.close()
-				heartbeatc, msgc = nil, nil
+			err := enc.encode(linkHeartbeatMessage)
+			if err == nil {
+				flusher.Flush()
+				reportSentDuration(string(t), linkHeartbeatMessage, time.Since(start))
 				continue
 			}
-			flusher.Flush()
-			reportSentDuration(string(t), linkHeartbeatMessage, time.Since(start))
+
+			reportSentFailure(string(t), linkHeartbeatMessage)
+			cw.status.deactivate(failureType{source: t.String(), action: "heartbeat"}, err.Error())
+			cw.close()
+			heartbeatc, msgc = nil, nil
+
 		case m := <-msgc:
 			start := time.Now()
-			if err := enc.encode(m); err != nil {
-				reportSentFailure(string(t), m)
-
-				cw.status.deactivate(failureType{source: t.String(), action: "write"}, err.Error())
-				cw.close()
-				heartbeatc, msgc = nil, nil
-				cw.r.ReportUnreachable(m.To)
+			err := enc.encode(m)
+			if err == nil {
+				flusher.Flush()
+				reportSentDuration(string(t), m, time.Since(start))
 				continue
 			}
-			flusher.Flush()
-			reportSentDuration(string(t), m, time.Since(start))
+
+			reportSentFailure(string(t), m)
+			cw.status.deactivate(failureType{source: t.String(), action: "write"}, err.Error())
+			cw.close()
+			heartbeatc, msgc = nil, nil
+			cw.r.ReportUnreachable(m.To)
+
 		case conn := <-cw.connc:
 			cw.close()
 			t = conn.t

+ 2 - 0
rafthttp/transport.go

@@ -289,6 +289,8 @@ func (t *Transport) ActiveSince(id types.ID) time.Time {
 }
 
 func (t *Transport) SendSnapshot(m snap.Message) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
 	p := t.peers[types.ID(m.To)]
 	if p == nil {
 		m.CloseWithError(errMemberNotFound)