Browse Source

Merge pull request #7255 from sinsharat/use_requestWithContext_for_cancel

rafthttp: use http.Request.WithContext instead of Cancel
Xiang Li 9 years ago
parent
commit
42e7d4d09d
4 changed files with 15 additions and 5 deletions
  1. 3 0
      rafthttp/fake_roundtripper_test.go
  2. 3 2
      rafthttp/pipeline.go
  3. 4 2
      rafthttp/snapshot_sender.go
  4. 5 1
      rafthttp/stream.go

+ 3 - 0
rafthttp/fake_roundtripper_test.go

@@ -24,11 +24,14 @@ func (t *roundTripperBlocker) RoundTrip(req *http.Request) (*http.Response, erro
 	t.mu.Lock()
 	t.cancel[req] = c
 	t.mu.Unlock()
+	ctx := req.Context()
 	select {
 	case <-t.unblockc:
 		return &http.Response{StatusCode: http.StatusNoContent, Body: &nopReadCloser{}}, nil
 	case <-req.Cancel:
 		return nil, errors.New("request canceled")
+	case <-ctx.Done():
+		return nil, errors.New("request canceled")
 	case <-c:
 		return nil, errors.New("request canceled")
 	}

+ 3 - 2
rafthttp/pipeline.go

@@ -16,13 +16,13 @@ package rafthttp
 
 import (
 	"bytes"
+	"context"
 	"errors"
 	"io/ioutil"
 	"sync"
 	"time"
 
 	"github.com/coreos/etcd/etcdserver/stats"
-	"github.com/coreos/etcd/pkg/httputil"
 	"github.com/coreos/etcd/pkg/pbutil"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/raft"
@@ -118,7 +118,8 @@ func (p *pipeline) post(data []byte) (err error) {
 	req := createPostRequest(u, RaftPrefix, bytes.NewBuffer(data), "application/protobuf", p.tr.URLs, p.tr.ID, p.tr.ClusterID)
 
 	done := make(chan struct{}, 1)
-	cancel := httputil.RequestCanceler(req)
+	ctx, cancel := context.WithCancel(context.Background())
+	req = req.WithContext(ctx)
 	go func() {
 		select {
 		case <-done:

+ 4 - 2
rafthttp/snapshot_sender.go

@@ -16,6 +16,7 @@ package rafthttp
 
 import (
 	"bytes"
+	"context"
 	"io"
 	"io/ioutil"
 	"net/http"
@@ -104,7 +105,9 @@ func (s *snapshotSender) send(merged snap.Message) {
 // post posts the given request.
 // It returns nil when request is sent out and processed successfully.
 func (s *snapshotSender) post(req *http.Request) (err error) {
-	cancel := httputil.RequestCanceler(req)
+	ctx, cancel := context.WithCancel(context.Background())
+	req = req.WithContext(ctx)
+	defer cancel()
 
 	type responseAndError struct {
 		resp *http.Response
@@ -130,7 +133,6 @@ func (s *snapshotSender) post(req *http.Request) (err error) {
 
 	select {
 	case <-s.stopc:
-		cancel()
 		return errStopped
 	case r := <-result:
 		if r.err != nil {

+ 5 - 1
rafthttp/stream.go

@@ -15,6 +15,7 @@
 package rafthttp
 
 import (
+	"context"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -427,14 +428,17 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) {
 
 	setPeerURLsHeader(req, cr.tr.URLs)
 
+	ctx, cancel := context.WithCancel(context.Background())
+	req = req.WithContext(ctx)
+
 	cr.mu.Lock()
+	cr.cancel = cancel
 	select {
 	case <-cr.stopc:
 		cr.mu.Unlock()
 		return nil, fmt.Errorf("stream reader is stopped")
 	default:
 	}
-	cr.cancel = httputil.RequestCanceler(req)
 	cr.mu.Unlock()
 
 	resp, err := cr.tr.streamRt.RoundTrip(req)