Browse Source

Merge pull request #192 from xiangli-cmu/master

fix timeout
Xiang Li 12 years ago
parent
commit
c565ac23a7
4 changed files with 94 additions and 67 deletions
  1. 1 6
      etcd.go
  2. 10 4
      raft_server.go
  3. 48 47
      transporter.go
  4. 35 10
      transporter_test.go

+ 1 - 6
etcd.go

@@ -90,12 +90,7 @@ func init() {
 const (
 	ElectionTimeout  = 200 * time.Millisecond
 	HeartbeatTimeout = 50 * time.Millisecond
-
-	// Timeout for internal raft http connection
-	// The original timeout for http is 45 seconds
-	// which is too long for our usage.
-	HTTPTimeout   = 10 * time.Second
-	RetryInterval = 10
+	RetryInterval    = 10
 )
 
 //------------------------------------------------------------------------------

+ 10 - 4
raft_server.go

@@ -33,7 +33,7 @@ var r *raftServer
 func newRaftServer(name string, url string, listenHost string, tlsConf *TLSConfig, tlsInfo *TLSInfo) *raftServer {
 
 	// Create transporter for raft
-	raftTransporter := newTransporter(tlsConf.Scheme, tlsConf.Client, ElectionTimeout)
+	raftTransporter := newTransporter(tlsConf.Scheme, tlsConf.Client)
 
 	// Create raft server
 	server, err := raft.NewServer(name, dirPath, raftTransporter, etcdStore, nil, "")
@@ -185,13 +185,16 @@ func (r *raftServer) startTransport(scheme string, tlsConf tls.Config) {
 // will need to do something more sophisticated later when we allow mixed
 // version clusters.
 func getVersion(t *transporter, versionURL url.URL) (string, error) {
-	resp, err := t.Get(versionURL.String())
+	resp, req, err := t.Get(versionURL.String())
 
 	if err != nil {
 		return "", err
 	}
 
 	defer resp.Body.Close()
+
+	t.CancelWhenTimeout(req)
+
 	body, err := ioutil.ReadAll(resp.Body)
 
 	return string(body), nil
@@ -246,7 +249,7 @@ func joinByMachine(s *raft.Server, machine string, scheme string) error {
 
 	debugf("Send Join Request to %s", joinURL.String())
 
-	resp, err := t.Post(joinURL.String(), &b)
+	resp, req, err := t.Post(joinURL.String(), &b)
 
 	for {
 		if err != nil {
@@ -254,6 +257,9 @@ func joinByMachine(s *raft.Server, machine string, scheme string) error {
 		}
 		if resp != nil {
 			defer resp.Body.Close()
+
+			t.CancelWhenTimeout(req)
+
 			if resp.StatusCode == http.StatusOK {
 				b, _ := ioutil.ReadAll(resp.Body)
 				r.joinIndex, _ = binary.Uvarint(b)
@@ -266,7 +272,7 @@ func joinByMachine(s *raft.Server, machine string, scheme string) error {
 
 				json.NewEncoder(&b).Encode(newJoinCommand())
 
-				resp, err = t.Post(address, &b)
+				resp, req, err = t.Post(address, &b)
 
 			} else if resp.StatusCode == http.StatusBadRequest {
 				debug("Reach max number machines in the cluster")

+ 48 - 47
transporter.go

@@ -13,26 +13,33 @@ import (
 	"github.com/coreos/go-raft"
 )
 
+// Timeout for setup internal raft http connection
+// This should not exceed 3 * RTT
+var dailTimeout = 3 * HeartbeatTimeout
+
+// Timeout for setup internal raft http connection + receive response header
+// This should not exceed 3 * RTT + RTT
+var responseHeaderTimeout = 4 * HeartbeatTimeout
+
+// Timeout for receiving the response body from the server
+// This should not exceed election timeout
+var tranTimeout = ElectionTimeout
+
 // Transporter layer for communication between raft nodes
 type transporter struct {
-	client  *http.Client
-	timeout time.Duration
-}
-
-// response struct
-type transporterResponse struct {
-	resp *http.Response
-	err  error
+	client    *http.Client
+	transport *http.Transport
 }
 
 // Create transporter using by raft server
 // Create http or https transporter based on
 // whether the user give the server cert and key
-func newTransporter(scheme string, tlsConf tls.Config, timeout time.Duration) *transporter {
+func newTransporter(scheme string, tlsConf tls.Config) *transporter {
 	t := transporter{}
 
 	tr := &http.Transport{
-		Dial: dialTimeout,
+		Dial: dialWithTimeout,
+		ResponseHeaderTimeout: responseHeaderTimeout,
 	}
 
 	if scheme == "https" {
@@ -41,14 +48,14 @@ func newTransporter(scheme string, tlsConf tls.Config, timeout time.Duration) *t
 	}
 
 	t.client = &http.Client{Transport: tr}
-	t.timeout = timeout
+	t.transport = tr
 
 	return &t
 }
 
 // Dial with timeout
-func dialTimeout(network, addr string) (net.Conn, error) {
-	return net.DialTimeout(network, addr, HTTPTimeout)
+func dialWithTimeout(network, addr string) (net.Conn, error) {
+	return net.DialTimeout(network, addr, dailTimeout)
 }
 
 // Sends AppendEntries RPCs to a peer when the server is the leader.
@@ -76,7 +83,7 @@ func (t *transporter) SendAppendEntriesRequest(server *raft.Server, peer *raft.P
 
 	start := time.Now()
 
-	resp, err := t.Post(fmt.Sprintf("%s/log/append", u), &b)
+	resp, httpRequest, err := t.Post(fmt.Sprintf("%s/log/append", u), &b)
 
 	end := time.Now()
 
@@ -93,6 +100,9 @@ func (t *transporter) SendAppendEntriesRequest(server *raft.Server, peer *raft.P
 
 	if resp != nil {
 		defer resp.Body.Close()
+
+		t.CancelWhenTimeout(httpRequest)
+
 		aersp = &raft.AppendEntriesResponse{}
 		if err := json.NewDecoder(resp.Body).Decode(&aersp); err == nil || err == io.EOF {
 			return aersp
@@ -112,7 +122,7 @@ func (t *transporter) SendVoteRequest(server *raft.Server, peer *raft.Peer, req
 	u, _ := nameToRaftURL(peer.Name)
 	debugf("Send Vote to %s", u)
 
-	resp, err := t.Post(fmt.Sprintf("%s/vote", u), &b)
+	resp, httpRequest, err := t.Post(fmt.Sprintf("%s/vote", u), &b)
 
 	if err != nil {
 		debugf("Cannot send VoteRequest to %s : %s", u, err)
@@ -120,6 +130,9 @@ func (t *transporter) SendVoteRequest(server *raft.Server, peer *raft.Peer, req
 
 	if resp != nil {
 		defer resp.Body.Close()
+
+		t.CancelWhenTimeout(httpRequest)
+
 		rvrsp := &raft.RequestVoteResponse{}
 		if err := json.NewDecoder(resp.Body).Decode(&rvrsp); err == nil || err == io.EOF {
 			return rvrsp
@@ -139,7 +152,7 @@ func (t *transporter) SendSnapshotRequest(server *raft.Server, peer *raft.Peer,
 	debugf("Send Snapshot to %s [Last Term: %d, LastIndex %d]", u,
 		req.LastTerm, req.LastIndex)
 
-	resp, err := t.Post(fmt.Sprintf("%s/snapshot", u), &b)
+	resp, httpRequest, err := t.Post(fmt.Sprintf("%s/snapshot", u), &b)
 
 	if err != nil {
 		debugf("Cannot send SendSnapshotRequest to %s : %s", u, err)
@@ -147,6 +160,9 @@ func (t *transporter) SendSnapshotRequest(server *raft.Server, peer *raft.Peer,
 
 	if resp != nil {
 		defer resp.Body.Close()
+
+		t.CancelWhenTimeout(httpRequest)
+
 		aersp = &raft.SnapshotResponse{}
 		if err = json.NewDecoder(resp.Body).Decode(&aersp); err == nil || err == io.EOF {
 
@@ -167,7 +183,7 @@ func (t *transporter) SendSnapshotRecoveryRequest(server *raft.Server, peer *raf
 	debugf("Send SnapshotRecovery to %s [Last Term: %d, LastIndex %d]", u,
 		req.LastTerm, req.LastIndex)
 
-	resp, err := t.Post(fmt.Sprintf("%s/snapshotRecovery", u), &b)
+	resp, _, err := t.Post(fmt.Sprintf("%s/snapshotRecovery", u), &b)
 
 	if err != nil {
 		debugf("Cannot send SendSnapshotRecoveryRequest to %s : %s", u, err)
@@ -176,6 +192,7 @@ func (t *transporter) SendSnapshotRecoveryRequest(server *raft.Server, peer *raf
 	if resp != nil {
 		defer resp.Body.Close()
 		aersp = &raft.SnapshotRecoveryResponse{}
+
 		if err = json.NewDecoder(resp.Body).Decode(&aersp); err == nil || err == io.EOF {
 			return aersp
 		}
@@ -185,46 +202,30 @@ func (t *transporter) SendSnapshotRecoveryRequest(server *raft.Server, peer *raf
 }
 
 // Send server side POST request
-func (t *transporter) Post(path string, body io.Reader) (*http.Response, error) {
+func (t *transporter) Post(urlStr string, body io.Reader) (*http.Response, *http.Request, error) {
 
-	c := make(chan *transporterResponse, 1)
+	req, _ := http.NewRequest("POST", urlStr, body)
 
-	go func() {
-		tr := new(transporterResponse)
-		tr.resp, tr.err = t.client.Post(path, "application/json", body)
-		c <- tr
-	}()
+	resp, err := t.client.Do(req)
 
-	return t.waitResponse(c)
+	return resp, req, err
 
 }
 
 // Send server side GET request
-func (t *transporter) Get(path string) (*http.Response, error) {
+func (t *transporter) Get(urlStr string) (*http.Response, *http.Request, error) {
 
-	c := make(chan *transporterResponse, 1)
+	req, _ := http.NewRequest("GET", urlStr, nil)
 
-	go func() {
-		tr := new(transporterResponse)
-		tr.resp, tr.err = t.client.Get(path)
-		c <- tr
-	}()
+	resp, err := t.client.Do(req)
 
-	return t.waitResponse(c)
+	return resp, req, err
 }
 
-func (t *transporter) waitResponse(responseChan chan *transporterResponse) (*http.Response, error) {
-
-	timeoutChan := time.After(t.timeout * 10)
-
-	select {
-	case <-timeoutChan:
-		return nil, fmt.Errorf("Wait Response Timeout: %v", t.timeout)
-
-	case r := <-responseChan:
-		return r.resp, r.err
-	}
-
-	// for complier
-	return nil, nil
+// Cancel the on fly HTTP transaction when timeout happens
+func (t *transporter) CancelWhenTimeout(req *http.Request) {
+	go func() {
+		time.Sleep(ElectionTimeout)
+		t.transport.CancelRequest(req)
+	}()
 }

+ 35 - 10
transporter_test.go

@@ -2,33 +2,58 @@ package main
 
 import (
 	"crypto/tls"
+	"fmt"
+	"io/ioutil"
+	"net/http"
 	"testing"
 	"time"
 )
 
 func TestTransporterTimeout(t *testing.T) {
 
+	http.HandleFunc("/timeout", func(w http.ResponseWriter, r *http.Request) {
+		fmt.Fprintf(w, "timeout")
+		w.(http.Flusher).Flush() // send headers and some body
+		time.Sleep(time.Second * 100)
+	})
+
+	go http.ListenAndServe(":8080", nil)
+
 	conf := tls.Config{}
 
-	ts := newTransporter("http", conf, time.Second)
+	ts := newTransporter("http", conf)
 
 	ts.Get("http://google.com")
-	_, err := ts.Get("http://google.com:9999") // it doesn't exisit
-	if err == nil || err.Error() != "Wait Response Timeout: 1s" {
-		t.Fatal("timeout error: ", err.Error())
+	_, _, err := ts.Get("http://google.com:9999")
+	if err == nil {
+		t.Fatal("timeout error")
+	}
+
+	res, req, err := ts.Get("http://localhost:8080/timeout")
+
+	if err != nil {
+		t.Fatal("should not timeout")
+	}
+
+	ts.CancelWhenTimeout(req)
+
+	body, err := ioutil.ReadAll(res.Body)
+	if err == nil {
+		fmt.Println(string(body))
+		t.Fatal("expected an error reading the body")
 	}
 
-	_, err = ts.Post("http://google.com:9999", nil) // it doesn't exisit
-	if err == nil || err.Error() != "Wait Response Timeout: 1s" {
-		t.Fatal("timeout error: ", err.Error())
+	_, _, err = ts.Post("http://google.com:9999", nil)
+	if err == nil {
+		t.Fatal("timeout error")
 	}
 
-	_, err = ts.Get("http://www.google.com")
+	_, _, err = ts.Get("http://www.google.com")
 	if err != nil {
-		t.Fatal("get error")
+		t.Fatal("get error: ", err.Error())
 	}
 
-	_, err = ts.Post("http://www.google.com", nil)
+	_, _, err = ts.Post("http://www.google.com", nil)
 	if err != nil {
 		t.Fatal("post error")
 	}