Browse Source

Merge pull request #2906 from yichengq/fix-pipeline-stop

rafthttp: fix pipeline.stop may block
Yicheng Qin 10 years ago
parent
commit
9e8d589163
2 changed files with 78 additions and 5 deletions
  1. 36 2
      rafthttp/pipeline.go
  2. 42 3
      rafthttp/pipeline_test.go

+ 36 - 2
rafthttp/pipeline.go

@@ -16,6 +16,7 @@ package rafthttp
 
 import (
 	"bytes"
+	"errors"
 	"fmt"
 	"io/ioutil"
 	"log"
@@ -41,6 +42,12 @@ const (
 	pipelineBufSize = 64
 )
 
+var errStopped = errors.New("stopped")
+
+type canceler interface {
+	CancelRequest(*http.Request)
+}
+
 type pipeline struct {
 	from, to types.ID
 	cid      types.ID
@@ -53,7 +60,8 @@ type pipeline struct {
 
 	msgc chan raftpb.Message
 	// wait for the handling routines
-	wg sync.WaitGroup
+	wg    sync.WaitGroup
+	stopc chan struct{}
 	sync.Mutex
 	// if the last send was successful, the pipeline is active.
 	// Or it is inactive
@@ -71,6 +79,7 @@ func newPipeline(tr http.RoundTripper, picker *urlPicker, from, to, cid types.ID
 		fs:     fs,
 		r:      r,
 		errorc: errorc,
+		stopc:  make(chan struct{}),
 		msgc:   make(chan raftpb.Message, pipelineBufSize),
 		active: true,
 	}
@@ -83,6 +92,7 @@ func newPipeline(tr http.RoundTripper, picker *urlPicker, from, to, cid types.ID
 
 func (p *pipeline) stop() {
 	close(p.msgc)
+	close(p.stopc)
 	p.wg.Wait()
 }
 
@@ -91,6 +101,9 @@ func (p *pipeline) handle() {
 	for m := range p.msgc {
 		start := time.Now()
 		err := p.post(pbutil.MustMarshal(&m))
+		if err == errStopped {
+			return
+		}
 		end := time.Now()
 
 		p.Lock()
@@ -132,7 +145,7 @@ func (p *pipeline) handle() {
 
 // post POSTs a data payload to a url. Returns nil if the POST succeeds,
 // error on any failure.
-func (p *pipeline) post(data []byte) error {
+func (p *pipeline) post(data []byte) (err error) {
 	u := p.picker.pick()
 	uu := u
 	uu.Path = RaftPrefix
@@ -146,7 +159,28 @@ func (p *pipeline) post(data []byte) error {
 	req.Header.Set("X-Server-Version", version.Version)
 	req.Header.Set("X-Min-Cluster-Version", version.MinClusterVersion)
 	req.Header.Set("X-Etcd-Cluster-ID", p.cid.String())
+
+	var stopped bool
+	defer func() {
+		if stopped {
+			// rewrite to errStopped so the caller goroutine can stop itself
+			err = errStopped
+		}
+	}()
+	done := make(chan struct{}, 1)
+	go func() {
+		select {
+		case <-done:
+		case <-p.stopc:
+			stopped = true
+			if cancel, ok := p.tr.(canceler); ok {
+				cancel.CancelRequest(req)
+			}
+		}
+	}()
+
 	resp, err := p.tr.RoundTrip(req)
+	done <- struct{}{}
 	if err != nil {
 		p.picker.unreachable(u)
 		return err

+ 42 - 3
rafthttp/pipeline_test.go

@@ -21,6 +21,7 @@ import (
 	"net/http"
 	"sync"
 	"testing"
+	"time"
 
 	"github.com/coreos/etcd/etcdserver/stats"
 	"github.com/coreos/etcd/pkg/testutil"
@@ -38,6 +39,7 @@ func TestPipelineSend(t *testing.T) {
 	p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
 
 	p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
+	testutil.ForceGosched()
 	p.stop()
 
 	if tr.Request() == nil {
@@ -97,6 +99,7 @@ func TestPipelineSendFailed(t *testing.T) {
 	p := newPipeline(newRespRoundTripper(0, errors.New("blah")), picker, types.ID(2), types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
 
 	p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
+	testutil.ForceGosched()
 	p.stop()
 
 	fs.Lock()
@@ -188,20 +191,56 @@ func TestPipelinePostErrorc(t *testing.T) {
 	}
 }
 
+func TestStopBlockedPipeline(t *testing.T) {
+	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
+	p := newPipeline(newRoundTripperBlocker(), picker, types.ID(2), types.ID(1), types.ID(1), nil, &fakeRaft{}, nil)
+	// send many messages that most of them will be blocked in buffer
+	for i := 0; i < connPerPipeline*10; i++ {
+		p.msgc <- raftpb.Message{}
+	}
+
+	done := make(chan struct{})
+	go func() {
+		p.stop()
+		done <- struct{}{}
+	}()
+	select {
+	case <-done:
+	case <-time.After(time.Second):
+		t.Fatalf("failed to stop pipeline in 1s")
+	}
+}
+
 type roundTripperBlocker struct {
-	c chan struct{}
+	c         chan error
+	mu        sync.Mutex
+	unblocked bool
 }
 
 func newRoundTripperBlocker() *roundTripperBlocker {
-	return &roundTripperBlocker{c: make(chan struct{})}
+	return &roundTripperBlocker{c: make(chan error)}
 }
 func (t *roundTripperBlocker) RoundTrip(req *http.Request) (*http.Response, error) {
-	<-t.c
+	err := <-t.c
+	if err != nil {
+		return nil, err
+	}
 	return &http.Response{StatusCode: http.StatusNoContent, Body: &nopReadCloser{}}, nil
 }
 func (t *roundTripperBlocker) unblock() {
+	t.mu.Lock()
+	t.unblocked = true
+	t.mu.Unlock()
 	close(t.c)
 }
+func (t *roundTripperBlocker) CancelRequest(req *http.Request) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.unblocked {
+		return
+	}
+	t.c <- errors.New("request canceled")
+}
 
 type respRoundTripper struct {
 	code   int