Browse Source

Merge pull request #5506 from xiang90/r_rafthttp

rafthttp: simplify streamReader initilization
Xiang Li 9 years ago
parent
commit
2d4c7d6886
3 changed files with 69 additions and 52 deletions
  1. 20 2
      rafthttp/peer.go
  2. 28 35
      rafthttp/stream.go
  3. 21 15
      rafthttp/stream_test.go

+ 20 - 2
rafthttp/peer.go

@@ -177,8 +177,26 @@ func startPeer(transport *Transport, urls types.URLs, local, to, cid types.ID, r
 		}
 	}()
 
-	p.msgAppV2Reader = startStreamReader(transport, picker, streamTypeMsgAppV2, local, to, cid, status, p.recvc, p.propc, errorc)
-	p.msgAppReader = startStreamReader(transport, picker, streamTypeMessage, local, to, cid, status, p.recvc, p.propc, errorc)
+	p.msgAppV2Reader = &streamReader{
+		typ:    streamTypeMsgAppV2,
+		tr:     transport,
+		picker: picker,
+		to:     to,
+		status: status,
+		recvc:  p.recvc,
+		propc:  p.propc,
+	}
+	p.msgAppReader = &streamReader{
+		typ:    streamTypeMessage,
+		tr:     transport,
+		picker: picker,
+		to:     to,
+		status: status,
+		recvc:  p.recvc,
+		propc:  p.propc,
+	}
+	p.msgAppV2Reader.start()
+	p.msgAppReader.start()
 
 	return p
 }

+ 28 - 35
rafthttp/stream.go

@@ -244,46 +244,39 @@ func (cw *streamWriter) stop() {
 // streamReader is a long-running go-routine that dials to the remote stream
 // endpoint and reads messages from the response body returned.
 type streamReader struct {
-	tr            *Transport
-	picker        *urlPicker
-	t             streamType
-	local, remote types.ID
-	cid           types.ID
-	status        *peerStatus
-	recvc         chan<- raftpb.Message
-	propc         chan<- raftpb.Message
-	errorc        chan<- error
+	typ streamType
+
+	tr     *Transport
+	picker *urlPicker
+	to     types.ID
+	status *peerStatus
+	recvc  chan<- raftpb.Message
+	propc  chan<- raftpb.Message
+
+	errorc chan<- error
 
 	mu     sync.Mutex
 	paused bool
 	cancel func()
 	closer io.Closer
-	stopc  chan struct{}
-	done   chan struct{}
+
+	stopc chan struct{}
+	done  chan struct{}
 }
 
-func startStreamReader(tr *Transport, picker *urlPicker, t streamType, local, remote, cid types.ID, status *peerStatus, recvc chan<- raftpb.Message, propc chan<- raftpb.Message, errorc chan<- error) *streamReader {
-	r := &streamReader{
-		tr:     tr,
-		picker: picker,
-		t:      t,
-		local:  local,
-		remote: remote,
-		cid:    cid,
-		status: status,
-		recvc:  recvc,
-		propc:  propc,
-		errorc: errorc,
-		stopc:  make(chan struct{}),
-		done:   make(chan struct{}),
+func (r *streamReader) start() {
+	r.stopc = make(chan struct{})
+	r.done = make(chan struct{})
+	if r.errorc != nil {
+		r.errorc = r.tr.ErrorC
 	}
+
 	go r.run()
-	return r
 }
 
 func (cr *streamReader) run() {
 	for {
-		t := cr.t
+		t := cr.typ
 		rc, err := cr.dial(t)
 		if err != nil {
 			if err != errUnsupportedStreamType {
@@ -317,7 +310,7 @@ func (cr *streamReader) decodeLoop(rc io.ReadCloser, t streamType) error {
 	cr.mu.Lock()
 	switch t {
 	case streamTypeMsgAppV2:
-		dec = newMsgAppV2Decoder(rc, cr.local, cr.remote)
+		dec = newMsgAppV2Decoder(rc, cr.tr.ID, cr.to)
 	case streamTypeMessage:
 		dec = &messageDecoder{r: rc}
 	default:
@@ -382,18 +375,18 @@ func (cr *streamReader) stop() {
 func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) {
 	u := cr.picker.pick()
 	uu := u
-	uu.Path = path.Join(t.endpoint(), cr.local.String())
+	uu.Path = path.Join(t.endpoint(), cr.tr.ID.String())
 
 	req, err := http.NewRequest("GET", uu.String(), nil)
 	if err != nil {
 		cr.picker.unreachable(u)
 		return nil, fmt.Errorf("failed to make http request to %v (%v)", u, err)
 	}
-	req.Header.Set("X-Server-From", cr.local.String())
+	req.Header.Set("X-Server-From", cr.tr.ID.String())
 	req.Header.Set("X-Server-Version", version.Version)
 	req.Header.Set("X-Min-Cluster-Version", version.MinClusterVersion)
-	req.Header.Set("X-Etcd-Cluster-ID", cr.cid.String())
-	req.Header.Set("X-Raft-To", cr.remote.String())
+	req.Header.Set("X-Etcd-Cluster-ID", cr.tr.ClusterID.String())
+	req.Header.Set("X-Raft-To", cr.to.String())
 
 	setPeerURLsHeader(req, cr.tr.URLs)
 
@@ -436,7 +429,7 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) {
 	case http.StatusNotFound:
 		httputil.GracefulClose(resp)
 		cr.picker.unreachable(u)
-		return nil, fmt.Errorf("remote member %s could not recognize local member", cr.remote)
+		return nil, fmt.Errorf("remote member %s could not recognize local member", cr.to)
 	case http.StatusPreconditionFailed:
 		b, err := ioutil.ReadAll(resp.Body)
 		if err != nil {
@@ -448,11 +441,11 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) {
 
 		switch strings.TrimSuffix(string(b), "\n") {
 		case errIncompatibleVersion.Error():
-			plog.Errorf("request sent was ignored by peer %s (server version incompatible)", cr.remote)
+			plog.Errorf("request sent was ignored by peer %s (server version incompatible)", cr.to)
 			return nil, errIncompatibleVersion
 		case errClusterIDMismatch.Error():
 			plog.Errorf("request sent was ignored (cluster ID mismatch: remote[%s]=%s, local=%s)",
-				cr.remote, resp.Header.Get("X-Etcd-Cluster-ID"), cr.cid)
+				cr.to, resp.Header.Get("X-Etcd-Cluster-ID"), cr.tr.ClusterID)
 			return nil, errClusterIDMismatch
 		default:
 			return nil, fmt.Errorf("unhandled error %q when precondition failed", string(b))

+ 21 - 15
rafthttp/stream_test.go

@@ -116,11 +116,9 @@ func TestStreamReaderDialRequest(t *testing.T) {
 	for i, tt := range []streamType{streamTypeMessage, streamTypeMsgAppV2} {
 		tr := &roundTripperRecorder{}
 		sr := &streamReader{
-			tr:     &Transport{streamRt: tr},
+			tr:     &Transport{streamRt: tr, ClusterID: types.ID(1), ID: types.ID(1)},
 			picker: mustNewURLPicker(t, []string{"http://localhost:2380"}),
-			local:  types.ID(1),
-			remote: types.ID(2),
-			cid:    types.ID(1),
+			to:     types.ID(2),
 		}
 		sr.dial(tt)
 
@@ -166,11 +164,9 @@ func TestStreamReaderDialResult(t *testing.T) {
 			err:    tt.err,
 		}
 		sr := &streamReader{
-			tr:     &Transport{streamRt: tr},
+			tr:     &Transport{streamRt: tr, ClusterID: types.ID(1)},
 			picker: mustNewURLPicker(t, []string{"http://localhost:2380"}),
-			local:  types.ID(1),
-			remote: types.ID(2),
-			cid:    types.ID(1),
+			to:     types.ID(2),
 			errorc: make(chan error, 1),
 		}
 
@@ -194,11 +190,9 @@ func TestStreamReaderDialDetectUnsupport(t *testing.T) {
 			header: http.Header{},
 		}
 		sr := &streamReader{
-			tr:     &Transport{streamRt: tr},
+			tr:     &Transport{streamRt: tr, ClusterID: types.ID(1)},
 			picker: mustNewURLPicker(t, []string{"http://localhost:2380"}),
-			local:  types.ID(1),
-			remote: types.ID(2),
-			cid:    types.ID(1),
+			to:     types.ID(2),
 		}
 
 		_, err := sr.dial(typ)
@@ -254,9 +248,19 @@ func TestStream(t *testing.T) {
 		h.sw = sw
 
 		picker := mustNewURLPicker(t, []string{srv.URL})
-		tr := &Transport{streamRt: &http.Transport{}}
-		sr := startStreamReader(tr, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), newPeerStatus(types.ID(1)), recvc, propc, nil)
-		defer sr.stop()
+		tr := &Transport{streamRt: &http.Transport{}, ClusterID: types.ID(1)}
+
+		sr := &streamReader{
+			typ:    tt.t,
+			tr:     tr,
+			picker: picker,
+			to:     types.ID(2),
+			status: newPeerStatus(types.ID(1)),
+			recvc:  recvc,
+			propc:  propc,
+		}
+		sr.start()
+
 		// wait for stream to work
 		var writec chan<- raftpb.Message
 		for {
@@ -277,6 +281,8 @@ func TestStream(t *testing.T) {
 		if !reflect.DeepEqual(m, tt.m) {
 			t.Fatalf("#%d: message = %+v, want %+v", i, m, tt.m)
 		}
+
+		sr.stop()
 	}
 }