Browse Source

fix(transporter): CancelRequest doesn't work on HTTPS connections blocked

Currently this is a workaround. And it should be fixed in Go1.3.
Yicheng Qin 11 years ago
parent
commit
69adb78433

+ 12 - 6
server/transporter.go

@@ -9,8 +9,10 @@ import (
 	"net/http"
 	"time"
 
-	"github.com/coreos/etcd/log"
 	"github.com/coreos/etcd/third_party/github.com/coreos/raft"
+	httpclient "github.com/coreos/etcd/third_party/github.com/mreiferson/go-httpclient"
+
+	"github.com/coreos/etcd/log"
 )
 
 // Transporter layer for communication between raft nodes
@@ -21,7 +23,7 @@ type transporter struct {
 	registry	*Registry
 
 	client		*http.Client
-	transport	*http.Transport
+	transport	*httpclient.Transport
 }
 
 type dialer func(network, addr string) (net.Conn, error)
@@ -30,11 +32,15 @@ type dialer func(network, addr string) (net.Conn, error)
 // Create http or https transporter based on
 // whether the user give the server cert and key
 func NewTransporter(followersStats *raftFollowersStats, serverStats *raftServerStats, registry *Registry, dialTimeout, requestTimeout, responseHeaderTimeout time.Duration) *transporter {
-	tr := &http.Transport{
-		Dial: func(network, addr string) (net.Conn, error) {
-			return net.DialTimeout(network, addr, dialTimeout)
-		},
+	tr := &httpclient.Transport{
 		ResponseHeaderTimeout:	responseHeaderTimeout,
+		// This is a workaround for Transport.CancelRequest doesn't work on
+		// HTTPS connections blocked. The patch for it is in progress,
+		// and would be available in Go1.3
+		// More: https://codereview.appspot.com/69280043/
+		ConnectTimeout: dialTimeout,
+		RequestTimeout: dialTimeout + responseHeaderTimeout,
+		ReadWriteTimeout: responseHeaderTimeout,
 	}
 
 	t := transporter{

+ 69 - 0
tests/functional/multi_node_kill_all_and_recovery_test.go

@@ -76,3 +76,72 @@ func TestMultiNodeKillAllAndRecovery(t *testing.T) {
 		t.Fatalf("recovery failed! [%d/16]", result.Node.ModifiedIndex)
 	}
 }
+
+// TestTLSMultiNodeKillAllAndRecovery create a five nodes
+// then kill all the nodes and restart
+func TestTLSMultiNodeKillAllAndRecovery(t *testing.T) {
+	procAttr := new(os.ProcAttr)
+	procAttr.Files = []*os.File{nil, os.Stdout, os.Stderr}
+
+	stop := make(chan bool)
+	leaderChan := make(chan string, 1)
+	all := make(chan bool, 1)
+
+	clusterSize := 5
+	argGroup, etcds, err := CreateCluster(clusterSize, procAttr, true)
+	defer DestroyCluster(etcds)
+
+	if err != nil {
+		t.Fatal("cannot create cluster")
+	}
+
+	c := etcd.NewClient(nil)
+
+	go Monitor(clusterSize, clusterSize, leaderChan, all, stop)
+	<-all
+	<-leaderChan
+	stop <-true
+
+	c.SyncCluster()
+
+	// send 10 commands
+	for i := 0; i < 10; i++ {
+		// Test Set
+		_, err := c.Set("foo", "bar", 0)
+		if err != nil {
+			panic(err)
+		}
+	}
+
+	time.Sleep(time.Second)
+
+	// kill all
+	DestroyCluster(etcds)
+
+	time.Sleep(time.Second)
+
+	stop = make(chan bool)
+	leaderChan = make(chan string, 1)
+	all = make(chan bool, 1)
+
+	time.Sleep(time.Second)
+
+	for i := 0; i < clusterSize; i++ {
+		etcds[i], err = os.StartProcess(EtcdBinPath, argGroup[i], procAttr)
+	}
+
+	go Monitor(clusterSize, 1, leaderChan, all, stop)
+
+	<-all
+	<-leaderChan
+
+	result, err := c.Set("foo", "bar", 0)
+
+	if err != nil {
+		t.Fatalf("Recovery error: %s", err)
+	}
+
+	if result.Node.ModifiedIndex != 16 {
+		t.Fatalf("recovery failed! [%d/16]", result.Node.ModifiedIndex)
+	}
+}

+ 17 - 0
third_party/github.com/mreiferson/go-httpclient/LICENSE

@@ -0,0 +1,17 @@
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.

+ 39 - 0
third_party/github.com/mreiferson/go-httpclient/README.md

@@ -0,0 +1,39 @@
+## go-httpclient
+
+**requires Go 1.1+** as of `v0.4.0` the API has been completely re-written for Go 1.1 (for a Go
+1.0.x compatible release see [1adef50](https://github.com/mreiferson/go-httpclient/tree/1adef50))
+
+[![Build
+Status](https://secure.travis-ci.org/mreiferson/go-httpclient.png?branch=master)](http://travis-ci.org/mreiferson/go-httpclient)
+
+Provides an HTTP Transport that implements the `RoundTripper` interface and
+can be used as a built in replacement for the standard library's, providing:
+
+ * connection timeouts
+ * request timeouts
+
+This is a thin wrapper around `http.Transport` that sets dial timeouts and uses
+Go's internal timer scheduler to call the Go 1.1+ `CancelRequest()` API.
+
+### Example
+
+```go
+transport := &httpclient.Transport{
+    ConnectTimeout:        1*time.Second,
+    RequestTimeout:        10*time.Second,
+    ResponseHeaderTimeout: 5*time.Second,
+}
+defer transport.Close()
+
+client := &http.Client{Transport: transport}
+req, _ := http.NewRequest("GET", "http://127.0.0.1/test", nil)
+resp, err := client.Do(req)
+if err != nil {
+    return err
+}
+defer resp.Body.Close()
+```
+
+### Reference Docs
+
+For API docs see [godoc](http://godoc.org/github.com/mreiferson/go-httpclient).

+ 210 - 0
third_party/github.com/mreiferson/go-httpclient/httpclient.go

@@ -0,0 +1,210 @@
+/*
+Provides an HTTP Transport that implements the `RoundTripper` interface and
+can be used as a built in replacement for the standard library's, providing:
+
+	* connection timeouts
+	* request timeouts
+
+This is a thin wrapper around `http.Transport` that sets dial timeouts and uses
+Go's internal timer scheduler to call the Go 1.1+ `CancelRequest()` API.
+*/
+package httpclient
+
+import (
+	"crypto/tls"
+	"io"
+	"net"
+	"net/http"
+	"net/url"
+	"sync"
+	"time"
+)
+
+// returns the current version of the package
+func Version() string {
+	return "0.4.1"
+}
+
+// Transport implements the RoundTripper interface and can be used as a replacement
+// for Go's built in http.Transport implementing end-to-end request timeouts.
+//
+// 	transport := &httpclient.Transport{
+// 	    ConnectTimeout: 1*time.Second,
+// 	    ResponseHeaderTimeout: 5*time.Second,
+// 	    RequestTimeout: 10*time.Second,
+// 	}
+// 	defer transport.Close()
+//
+// 	client := &http.Client{Transport: transport}
+// 	req, _ := http.NewRequest("GET", "http://127.0.0.1/test", nil)
+// 	resp, err := client.Do(req)
+// 	if err != nil {
+// 	    return err
+// 	}
+// 	defer resp.Body.Close()
+//
+type Transport struct {
+	// Proxy specifies a function to return a proxy for a given
+	// *http.Request. If the function returns a non-nil error, the
+	// request is aborted with the provided error.
+	// If Proxy is nil or returns a nil *url.URL, no proxy is used.
+	Proxy func(*http.Request) (*url.URL, error)
+
+	// Dial specifies the dial function for creating TCP
+	// connections. This will override the Transport's ConnectTimeout and
+	// ReadWriteTimeout settings.
+	// If Dial is nil, a dialer is generated on demand matching the Transport's
+	// options.
+	Dial func(network, addr string) (net.Conn, error)
+
+	// TLSClientConfig specifies the TLS configuration to use with
+	// tls.Client. If nil, the default configuration is used.
+	TLSClientConfig *tls.Config
+
+	// DisableKeepAlives, if true, prevents re-use of TCP connections
+	// between different HTTP requests.
+	DisableKeepAlives bool
+
+	// DisableCompression, if true, prevents the Transport from
+	// requesting compression with an "Accept-Encoding: gzip"
+	// request header when the Request contains no existing
+	// Accept-Encoding value. If the Transport requests gzip on
+	// its own and gets a gzipped response, it's transparently
+	// decoded in the Response.Body. However, if the user
+	// explicitly requested gzip it is not automatically
+	// uncompressed.
+	DisableCompression bool
+
+	// MaxIdleConnsPerHost, if non-zero, controls the maximum idle
+	// (keep-alive) to keep per-host.  If zero,
+	// http.DefaultMaxIdleConnsPerHost is used.
+	MaxIdleConnsPerHost int
+
+	// ConnectTimeout, if non-zero, is the maximum amount of time a dial will wait for
+	// a connect to complete.
+	ConnectTimeout time.Duration
+
+	// ResponseHeaderTimeout, if non-zero, specifies the amount of
+	// time to wait for a server's response headers after fully
+	// writing the request (including its body, if any). This
+	// time does not include the time to read the response body.
+	ResponseHeaderTimeout time.Duration
+
+	// RequestTimeout, if non-zero, specifies the amount of time for the entire
+	// request to complete (including all of the above timeouts + entire response body).
+	// This should never be less than the sum total of the above two timeouts.
+	RequestTimeout time.Duration
+
+	// ReadWriteTimeout, if non-zero, will set a deadline for every Read and
+	// Write operation on the request connection.
+	ReadWriteTimeout time.Duration
+
+	starter   sync.Once
+	transport *http.Transport
+}
+
+// Close cleans up the Transport, currently a no-op
+func (t *Transport) Close() error {
+	return nil
+}
+
+func (t *Transport) lazyStart() {
+	if t.Dial == nil {
+		t.Dial = func(netw, addr string) (net.Conn, error) {
+			c, err := net.DialTimeout(netw, addr, t.ConnectTimeout)
+			if err != nil {
+				return nil, err
+			}
+
+			if t.ReadWriteTimeout > 0 {
+				timeoutConn := &rwTimeoutConn{
+					TCPConn:   c.(*net.TCPConn),
+					rwTimeout: t.ReadWriteTimeout,
+				}
+				return timeoutConn, nil
+			}
+			return c, nil
+		}
+	}
+
+	t.transport = &http.Transport{
+		Dial:                  t.Dial,
+		Proxy:                 t.Proxy,
+		TLSClientConfig:       t.TLSClientConfig,
+		DisableKeepAlives:     t.DisableKeepAlives,
+		DisableCompression:    t.DisableCompression,
+		MaxIdleConnsPerHost:   t.MaxIdleConnsPerHost,
+		ResponseHeaderTimeout: t.ResponseHeaderTimeout,
+	}
+}
+
+func (t *Transport) CancelRequest(req *http.Request) {
+	t.starter.Do(t.lazyStart)
+
+	t.transport.CancelRequest(req)
+}
+
+func (t *Transport) CloseIdleConnections() {
+	t.starter.Do(t.lazyStart)
+
+	t.transport.CloseIdleConnections()
+}
+
+func (t *Transport) RegisterProtocol(scheme string, rt http.RoundTripper) {
+	t.starter.Do(t.lazyStart)
+
+	t.transport.RegisterProtocol(scheme, rt)
+}
+
+func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
+	t.starter.Do(t.lazyStart)
+
+	if t.RequestTimeout > 0 {
+		timer := time.AfterFunc(t.RequestTimeout, func() {
+			t.transport.CancelRequest(req)
+		})
+
+		resp, err = t.transport.RoundTrip(req)
+		if err != nil {
+			timer.Stop()
+		} else {
+			resp.Body = &bodyCloseInterceptor{ReadCloser: resp.Body, timer: timer}
+		}
+	} else {
+		resp, err = t.transport.RoundTrip(req)
+	}
+
+	return
+}
+
+type bodyCloseInterceptor struct {
+	io.ReadCloser
+	timer *time.Timer
+}
+
+func (bci *bodyCloseInterceptor) Close() error {
+	bci.timer.Stop()
+	return bci.ReadCloser.Close()
+}
+
+// A net.Conn that sets a deadline for every Read or Write operation
+type rwTimeoutConn struct {
+	*net.TCPConn
+	rwTimeout time.Duration
+}
+
+func (c *rwTimeoutConn) Read(b []byte) (int, error) {
+	err := c.TCPConn.SetReadDeadline(time.Now().Add(c.rwTimeout))
+	if err != nil {
+		return 0, err
+	}
+	return c.TCPConn.Read(b)
+}
+
+func (c *rwTimeoutConn) Write(b []byte) (int, error) {
+	err := c.TCPConn.SetWriteDeadline(time.Now().Add(c.rwTimeout))
+	if err != nil {
+		return 0, err
+	}
+	return c.TCPConn.Write(b)
+}

+ 233 - 0
third_party/github.com/mreiferson/go-httpclient/httpclient_test.go

@@ -0,0 +1,233 @@
+package httpclient
+
+import (
+	"crypto/tls"
+	"io"
+	"io/ioutil"
+	"net"
+	"net/http"
+	"sync"
+	"testing"
+	"time"
+)
+
+var starter sync.Once
+var addr net.Addr
+
+func testHandler(w http.ResponseWriter, req *http.Request) {
+	time.Sleep(200 * time.Millisecond)
+	io.WriteString(w, "hello, world!\n")
+}
+
+func postHandler(w http.ResponseWriter, req *http.Request) {
+	ioutil.ReadAll(req.Body)
+	w.Header().Set("Content-Length", "2")
+	io.WriteString(w, "OK")
+}
+
+func closeHandler(w http.ResponseWriter, req *http.Request) {
+	hj, _ := w.(http.Hijacker)
+	conn, bufrw, _ := hj.Hijack()
+	defer conn.Close()
+	bufrw.WriteString("HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n")
+	bufrw.Flush()
+}
+
+func redirectHandler(w http.ResponseWriter, req *http.Request) {
+	ioutil.ReadAll(req.Body)
+	http.Redirect(w, req, "/post", 302)
+}
+
+func redirect2Handler(w http.ResponseWriter, req *http.Request) {
+	ioutil.ReadAll(req.Body)
+	http.Redirect(w, req, "/redirect", 302)
+}
+
+func slowHandler(w http.ResponseWriter, r *http.Request) {
+	w.WriteHeader(200)
+	io.WriteString(w, "START\n")
+	f := w.(http.Flusher)
+	f.Flush()
+	time.Sleep(200 * time.Millisecond)
+	io.WriteString(w, "WORKING\n")
+	f.Flush()
+	time.Sleep(200 * time.Millisecond)
+	io.WriteString(w, "DONE\n")
+	return
+}
+
+func setupMockServer(t *testing.T) {
+	http.HandleFunc("/test", testHandler)
+	http.HandleFunc("/post", postHandler)
+	http.HandleFunc("/redirect", redirectHandler)
+	http.HandleFunc("/redirect2", redirect2Handler)
+	http.HandleFunc("/close", closeHandler)
+	http.HandleFunc("/slow", slowHandler)
+	ln, err := net.Listen("tcp", ":0")
+	if err != nil {
+		t.Fatalf("failed to listen - %s", err.Error())
+	}
+	go func() {
+		err = http.Serve(ln, nil)
+		if err != nil {
+			t.Fatalf("failed to start HTTP server - %s", err.Error())
+		}
+	}()
+	addr = ln.Addr()
+}
+
+func TestHttpsConnection(t *testing.T) {
+	transport := &Transport{
+		ConnectTimeout: 1 * time.Second,
+		RequestTimeout: 2 * time.Second,
+		TLSClientConfig: &tls.Config{
+			InsecureSkipVerify: true,
+		},
+	}
+	defer transport.Close()
+	client := &http.Client{Transport: transport}
+
+	req, _ := http.NewRequest("GET", "https://httpbin.org/ip", nil)
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatalf("1st request failed - %s", err.Error())
+	}
+	_, err = ioutil.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("1st failed to read body - %s", err.Error())
+	}
+	resp.Body.Close()
+
+	req2, _ := http.NewRequest("GET", "https://httpbin.org/delay/5", nil)
+	_, err = client.Do(req2)
+	if err == nil {
+		t.Fatalf("HTTPS request should have timed out")
+	}
+}
+
+func TestHttpClient(t *testing.T) {
+	starter.Do(func() { setupMockServer(t) })
+
+	transport := &Transport{
+		ConnectTimeout:   1 * time.Second,
+		RequestTimeout:   5 * time.Second,
+		ReadWriteTimeout: 3 * time.Second,
+	}
+	client := &http.Client{Transport: transport}
+
+	req, _ := http.NewRequest("GET", "http://"+addr.String()+"/test", nil)
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatalf("1st request failed - %s", err.Error())
+	}
+	_, err = ioutil.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("1st failed to read body - %s", err.Error())
+	}
+	resp.Body.Close()
+	transport.Close()
+
+	transport = &Transport{
+		ConnectTimeout:   25 * time.Millisecond,
+		RequestTimeout:   50 * time.Millisecond,
+		ReadWriteTimeout: 50 * time.Millisecond,
+	}
+	client = &http.Client{Transport: transport}
+
+	req2, _ := http.NewRequest("GET", "http://"+addr.String()+"/test", nil)
+	resp, err = client.Do(req2)
+	if err == nil {
+		t.Fatal("2nd request should have timed out")
+	}
+	transport.Close()
+
+	transport = &Transport{
+		ConnectTimeout:   25 * time.Millisecond,
+		RequestTimeout:   250 * time.Millisecond,
+		ReadWriteTimeout: 250 * time.Millisecond,
+	}
+	client = &http.Client{Transport: transport}
+
+	req3, _ := http.NewRequest("GET", "http://"+addr.String()+"/test", nil)
+	resp, err = client.Do(req3)
+	if err != nil {
+		t.Fatal("3rd request should not have timed out")
+	}
+	resp.Body.Close()
+	transport.Close()
+}
+
+func TestSlowServer(t *testing.T) {
+	starter.Do(func() { setupMockServer(t) })
+
+	transport := &Transport{
+		ConnectTimeout:   25 * time.Millisecond,
+		RequestTimeout:   500 * time.Millisecond,
+		ReadWriteTimeout: 250 * time.Millisecond,
+	}
+
+	client := &http.Client{Transport: transport}
+
+	req, _ := http.NewRequest("GET", "http://"+addr.String()+"/slow", nil)
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatalf("1st request failed - %s", err)
+	}
+	_, err = ioutil.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("1st failed to read body - %s", err)
+	}
+	resp.Body.Close()
+	transport.Close()
+
+	transport = &Transport{
+		ConnectTimeout:   25 * time.Millisecond,
+		RequestTimeout:   500 * time.Millisecond,
+		ReadWriteTimeout: 100 * time.Millisecond,
+	}
+	client = &http.Client{Transport: transport}
+
+	req, _ = http.NewRequest("GET", "http://"+addr.String()+"/slow", nil)
+	resp, err = client.Do(req)
+	if err != nil {
+		t.Fatalf("2nd request failed - %s", err)
+	}
+	_, err = ioutil.ReadAll(resp.Body)
+	netErr, ok := err.(net.Error)
+	if !ok {
+		t.Fatalf("2nd request dind't return a net.Error - %s", netErr)
+	}
+
+	if !netErr.Timeout() {
+		t.Fatalf("2nd request should have timed out - %s", netErr)
+	}
+
+	resp.Body.Close()
+	transport.Close()
+}
+
+func TestMultipleRequests(t *testing.T) {
+	starter.Do(func() { setupMockServer(t) })
+
+	transport := &Transport{
+		ConnectTimeout:        1 * time.Second,
+		RequestTimeout:        5 * time.Second,
+		ReadWriteTimeout:      3 * time.Second,
+		ResponseHeaderTimeout: 400 * time.Millisecond,
+	}
+	client := &http.Client{Transport: transport}
+
+	req, _ := http.NewRequest("GET", "http://"+addr.String()+"/test", nil)
+	for i := 0; i < 10; i++ {
+		resp, err := client.Do(req)
+		if err != nil {
+			t.Fatalf("%d request failed - %s", i, err.Error())
+		}
+		_, err = ioutil.ReadAll(resp.Body)
+		if err != nil {
+			t.Fatalf("%d failed to read body - %s", i, err.Error())
+		}
+		resp.Body.Close()
+	}
+	transport.Close()
+}