Преглед на файлове

rafthttp: fix pipeline.stop may block

This PR makes pipeline.stop stop quickly. It cancels inflight requests,
and stops sending messages in the buffer.
Yicheng Qin преди 10 години
родител
ревизия
7f8925e172
променени са 2 файла, в които са добавени 78 реда и са изтрити 5 реда
  1. 36 2
      rafthttp/pipeline.go
  2. 42 3
      rafthttp/pipeline_test.go

+ 36 - 2
rafthttp/pipeline.go

@@ -16,6 +16,7 @@ package rafthttp
 
 import (
 	"bytes"
+	"errors"
 	"fmt"
 	"io/ioutil"
 	"log"
@@ -41,6 +42,12 @@ const (
 	pipelineBufSize = 64
 )
 
+var errStopped = errors.New("stopped")
+
+type canceler interface {
+	CancelRequest(*http.Request)
+}
+
 type pipeline struct {
 	from, to types.ID
 	cid      types.ID
@@ -53,7 +60,8 @@ type pipeline struct {
 
 	msgc chan raftpb.Message
 	// wait for the handling routines
-	wg sync.WaitGroup
+	wg    sync.WaitGroup
+	stopc chan struct{}
 	sync.Mutex
 	// if the last send was successful, the pipeline is active.
 	// Or it is inactive
@@ -71,6 +79,7 @@ func newPipeline(tr http.RoundTripper, picker *urlPicker, from, to, cid types.ID
 		fs:     fs,
 		r:      r,
 		errorc: errorc,
+		stopc:  make(chan struct{}),
 		msgc:   make(chan raftpb.Message, pipelineBufSize),
 		active: true,
 	}
@@ -83,6 +92,7 @@ func newPipeline(tr http.RoundTripper, picker *urlPicker, from, to, cid types.ID
 
 func (p *pipeline) stop() {
 	close(p.msgc)
+	close(p.stopc)
 	p.wg.Wait()
 }
 
@@ -91,6 +101,9 @@ func (p *pipeline) handle() {
 	for m := range p.msgc {
 		start := time.Now()
 		err := p.post(pbutil.MustMarshal(&m))
+		if err == errStopped {
+			return
+		}
 		end := time.Now()
 
 		p.Lock()
@@ -132,7 +145,7 @@ func (p *pipeline) handle() {
 
 // post POSTs a data payload to a url. Returns nil if the POST succeeds,
 // error on any failure.
-func (p *pipeline) post(data []byte) error {
+func (p *pipeline) post(data []byte) (err error) {
 	u := p.picker.pick()
 	uu := u
 	uu.Path = RaftPrefix
@@ -146,7 +159,28 @@ func (p *pipeline) post(data []byte) error {
 	req.Header.Set("X-Server-Version", version.Version)
 	req.Header.Set("X-Min-Cluster-Version", version.MinClusterVersion)
 	req.Header.Set("X-Etcd-Cluster-ID", p.cid.String())
+
+	var stopped bool
+	defer func() {
+		if stopped {
+			// rewrite to errStopped so the caller goroutine can stop itself
+			err = errStopped
+		}
+	}()
+	done := make(chan struct{}, 1)
+	go func() {
+		select {
+		case <-done:
+		case <-p.stopc:
+			stopped = true
+			if cancel, ok := p.tr.(canceler); ok {
+				cancel.CancelRequest(req)
+			}
+		}
+	}()
+
 	resp, err := p.tr.RoundTrip(req)
+	done <- struct{}{}
 	if err != nil {
 		p.picker.unreachable(u)
 		return err

+ 42 - 3
rafthttp/pipeline_test.go

@@ -21,6 +21,7 @@ import (
 	"net/http"
 	"sync"
 	"testing"
+	"time"
 
 	"github.com/coreos/etcd/etcdserver/stats"
 	"github.com/coreos/etcd/pkg/testutil"
@@ -38,6 +39,7 @@ func TestPipelineSend(t *testing.T) {
 	p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
 
 	p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
+	testutil.ForceGosched()
 	p.stop()
 
 	if tr.Request() == nil {
@@ -97,6 +99,7 @@ func TestPipelineSendFailed(t *testing.T) {
 	p := newPipeline(newRespRoundTripper(0, errors.New("blah")), picker, types.ID(2), types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
 
 	p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
+	testutil.ForceGosched()
 	p.stop()
 
 	fs.Lock()
@@ -188,20 +191,56 @@ func TestPipelinePostErrorc(t *testing.T) {
 	}
 }
 
+func TestStopBlockedPipeline(t *testing.T) {
+	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
+	p := newPipeline(newRoundTripperBlocker(), picker, types.ID(2), types.ID(1), types.ID(1), nil, &fakeRaft{}, nil)
+	// send many messages that most of them will be blocked in buffer
+	for i := 0; i < connPerPipeline*10; i++ {
+		p.msgc <- raftpb.Message{}
+	}
+
+	done := make(chan struct{})
+	go func() {
+		p.stop()
+		done <- struct{}{}
+	}()
+	select {
+	case <-done:
+	case <-time.After(time.Second):
+		t.Fatalf("failed to stop pipeline in 1s")
+	}
+}
+
 type roundTripperBlocker struct {
-	c chan struct{}
+	c         chan error
+	mu        sync.Mutex
+	unblocked bool
 }
 
 func newRoundTripperBlocker() *roundTripperBlocker {
-	return &roundTripperBlocker{c: make(chan struct{})}
+	return &roundTripperBlocker{c: make(chan error)}
 }
 func (t *roundTripperBlocker) RoundTrip(req *http.Request) (*http.Response, error) {
-	<-t.c
+	err := <-t.c
+	if err != nil {
+		return nil, err
+	}
 	return &http.Response{StatusCode: http.StatusNoContent, Body: &nopReadCloser{}}, nil
 }
 func (t *roundTripperBlocker) unblock() {
+	t.mu.Lock()
+	t.unblocked = true
+	t.mu.Unlock()
 	close(t.c)
 }
+func (t *roundTripperBlocker) CancelRequest(req *http.Request) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.unblocked {
+		return
+	}
+	t.c <- errors.New("request canceled")
+}
 
 type respRoundTripper struct {
 	code   int