Browse Source

Merge pull request #152 from evan-gu/tranWithTimeout

Transporter with timeout
Xiang Li 12 years ago
parent
commit
40dcde42aa
3 changed files with 107 additions and 26 deletions
  1. 16 15
      raft_server.go
  2. 55 11
      transporter.go
  3. 36 0
      transporter_test.go

+ 16 - 15
raft_server.go

@@ -16,13 +16,13 @@ import (
 
 
 type raftServer struct {
 type raftServer struct {
 	*raft.Server
 	*raft.Server
-	version   string
-	joinIndex uint64
-	name      string
-	url       string
+	version    string
+	joinIndex  uint64
+	name       string
+	url        string
 	listenHost string
 	listenHost string
-	tlsConf   *TLSConfig
-	tlsInfo   *TLSInfo
+	tlsConf    *TLSConfig
+	tlsInfo    *TLSInfo
 }
 }
 
 
 var r *raftServer
 var r *raftServer
@@ -30,7 +30,7 @@ var r *raftServer
 func newRaftServer(name string, url string, listenHost string, tlsConf *TLSConfig, tlsInfo *TLSInfo) *raftServer {
 func newRaftServer(name string, url string, listenHost string, tlsConf *TLSConfig, tlsInfo *TLSInfo) *raftServer {
 
 
 	// Create transporter for raft
 	// Create transporter for raft
-	raftTransporter := newTransporter(tlsConf.Scheme, tlsConf.Client)
+	raftTransporter := newTransporter(tlsConf.Scheme, tlsConf.Client, ElectionTimeout)
 
 
 	// Create raft server
 	// Create raft server
 	server, err := raft.NewServer(name, dirPath, raftTransporter, etcdStore, nil)
 	server, err := raft.NewServer(name, dirPath, raftTransporter, etcdStore, nil)
@@ -38,13 +38,13 @@ func newRaftServer(name string, url string, listenHost string, tlsConf *TLSConfi
 	check(err)
 	check(err)
 
 
 	return &raftServer{
 	return &raftServer{
-		Server:  server,
-		version: raftVersion,
-		name:    name,
-		url:     url,
+		Server:     server,
+		version:    raftVersion,
+		name:       name,
+		url:        url,
 		listenHost: listenHost,
 		listenHost: listenHost,
-		tlsConf: tlsConf,
-		tlsInfo: tlsInfo,
+		tlsConf:    tlsConf,
+		tlsInfo:    tlsInfo,
 	}
 	}
 }
 }
 
 
@@ -169,7 +169,7 @@ func (r *raftServer) startTransport(scheme string, tlsConf tls.Config) {
 // getVersion fetches the raft version of a peer. This works for now but we
 // getVersion fetches the raft version of a peer. This works for now but we
 // will need to do something more sophisticated later when we allow mixed
 // will need to do something more sophisticated later when we allow mixed
 // version clusters.
 // version clusters.
-func getVersion(t transporter, versionURL url.URL) (string, error) {
+func getVersion(t *transporter, versionURL url.URL) (string, error) {
 	resp, err := t.Get(versionURL.String())
 	resp, err := t.Get(versionURL.String())
 
 
 	if err != nil {
 	if err != nil {
@@ -198,6 +198,7 @@ func joinCluster(cluster []string) bool {
 			if _, ok := err.(etcdErr.Error); ok {
 			if _, ok := err.(etcdErr.Error); ok {
 				fatal(err)
 				fatal(err)
 			}
 			}
+
 			debugf("cannot join to cluster via machine %s %s", machine, err)
 			debugf("cannot join to cluster via machine %s %s", machine, err)
 		}
 		}
 	}
 	}
@@ -209,7 +210,7 @@ func joinByMachine(s *raft.Server, machine string, scheme string) error {
 	var b bytes.Buffer
 	var b bytes.Buffer
 
 
 	// t must be ok
 	// t must be ok
-	t, _ := r.Transporter().(transporter)
+	t, _ := r.Transporter().(*transporter)
 
 
 	// Our version must match the leaders version
 	// Our version must match the leaders version
 	versionURL := url.URL{Host: machine, Scheme: scheme, Path: "/version"}
 	versionURL := url.URL{Host: machine, Scheme: scheme, Path: "/version"}

+ 55 - 11
transporter.go

@@ -9,17 +9,25 @@ import (
 	"io"
 	"io"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
+	"time"
 )
 )
 
 
 // Transporter layer for communication between raft nodes
 // Transporter layer for communication between raft nodes
 type transporter struct {
 type transporter struct {
-	client *http.Client
+	client  *http.Client
+	timeout time.Duration
+}
+
+// response struct
+type transporterResponse struct {
+	resp *http.Response
+	err  error
 }
 }
 
 
 // Create transporter using by raft server
 // Create transporter using by raft server
 // Create http or https transporter based on
 // Create http or https transporter based on
 // whether the user give the server cert and key
 // whether the user give the server cert and key
-func newTransporter(scheme string, tlsConf tls.Config) transporter {
+func newTransporter(scheme string, tlsConf tls.Config, timeout time.Duration) *transporter {
 	t := transporter{}
 	t := transporter{}
 
 
 	tr := &http.Transport{
 	tr := &http.Transport{
@@ -32,8 +40,9 @@ func newTransporter(scheme string, tlsConf tls.Config) transporter {
 	}
 	}
 
 
 	t.client = &http.Client{Transport: tr}
 	t.client = &http.Client{Transport: tr}
+	t.timeout = timeout
 
 
-	return t
+	return &t
 }
 }
 
 
 // Dial with timeout
 // Dial with timeout
@@ -42,7 +51,7 @@ func dialTimeout(network, addr string) (net.Conn, error) {
 }
 }
 
 
 // Sends AppendEntries RPCs to a peer when the server is the leader.
 // Sends AppendEntries RPCs to a peer when the server is the leader.
-func (t transporter) SendAppendEntriesRequest(server *raft.Server, peer *raft.Peer, req *raft.AppendEntriesRequest) *raft.AppendEntriesResponse {
+func (t *transporter) SendAppendEntriesRequest(server *raft.Server, peer *raft.Peer, req *raft.AppendEntriesRequest) *raft.AppendEntriesResponse {
 	var aersp *raft.AppendEntriesResponse
 	var aersp *raft.AppendEntriesResponse
 	var b bytes.Buffer
 	var b bytes.Buffer
 	json.NewEncoder(&b).Encode(req)
 	json.NewEncoder(&b).Encode(req)
@@ -69,7 +78,7 @@ func (t transporter) SendAppendEntriesRequest(server *raft.Server, peer *raft.Pe
 }
 }
 
 
 // Sends RequestVote RPCs to a peer when the server is the candidate.
 // Sends RequestVote RPCs to a peer when the server is the candidate.
-func (t transporter) SendVoteRequest(server *raft.Server, peer *raft.Peer, req *raft.RequestVoteRequest) *raft.RequestVoteResponse {
+func (t *transporter) SendVoteRequest(server *raft.Server, peer *raft.Peer, req *raft.RequestVoteRequest) *raft.RequestVoteResponse {
 	var rvrsp *raft.RequestVoteResponse
 	var rvrsp *raft.RequestVoteResponse
 	var b bytes.Buffer
 	var b bytes.Buffer
 	json.NewEncoder(&b).Encode(req)
 	json.NewEncoder(&b).Encode(req)
@@ -95,7 +104,7 @@ func (t transporter) SendVoteRequest(server *raft.Server, peer *raft.Peer, req *
 }
 }
 
 
 // Sends SnapshotRequest RPCs to a peer when the server is the candidate.
 // Sends SnapshotRequest RPCs to a peer when the server is the candidate.
-func (t transporter) SendSnapshotRequest(server *raft.Server, peer *raft.Peer, req *raft.SnapshotRequest) *raft.SnapshotResponse {
+func (t *transporter) SendSnapshotRequest(server *raft.Server, peer *raft.Peer, req *raft.SnapshotRequest) *raft.SnapshotResponse {
 	var aersp *raft.SnapshotResponse
 	var aersp *raft.SnapshotResponse
 	var b bytes.Buffer
 	var b bytes.Buffer
 	json.NewEncoder(&b).Encode(req)
 	json.NewEncoder(&b).Encode(req)
@@ -123,7 +132,7 @@ func (t transporter) SendSnapshotRequest(server *raft.Server, peer *raft.Peer, r
 }
 }
 
 
 // Sends SnapshotRecoveryRequest RPCs to a peer when the server is the candidate.
 // Sends SnapshotRecoveryRequest RPCs to a peer when the server is the candidate.
-func (t transporter) SendSnapshotRecoveryRequest(server *raft.Server, peer *raft.Peer, req *raft.SnapshotRecoveryRequest) *raft.SnapshotRecoveryResponse {
+func (t *transporter) SendSnapshotRecoveryRequest(server *raft.Server, peer *raft.Peer, req *raft.SnapshotRecoveryRequest) *raft.SnapshotRecoveryResponse {
 	var aersp *raft.SnapshotRecoveryResponse
 	var aersp *raft.SnapshotRecoveryResponse
 	var b bytes.Buffer
 	var b bytes.Buffer
 	json.NewEncoder(&b).Encode(req)
 	json.NewEncoder(&b).Encode(req)
@@ -150,11 +159,46 @@ func (t transporter) SendSnapshotRecoveryRequest(server *raft.Server, peer *raft
 }
 }
 
 
 // Send server side POST request
 // Send server side POST request
-func (t transporter) Post(path string, body io.Reader) (*http.Response, error) {
-	return t.client.Post(path, "application/json", body)
+func (t *transporter) Post(path string, body io.Reader) (*http.Response, error) {
+
+	c := make(chan *transporterResponse, 1)
+
+	go func() {
+		tr := new(transporterResponse)
+		tr.resp, tr.err = t.client.Post(path, "application/json", body)
+		c <- tr
+	}()
+
+	return t.waitResponse(c)
+
 }
 }
 
 
 // Send server side GET request
 // Send server side GET request
-func (t transporter) Get(path string) (*http.Response, error) {
-	return t.client.Get(path)
+func (t *transporter) Get(path string) (*http.Response, error) {
+
+	c := make(chan *transporterResponse, 1)
+
+	go func() {
+		tr := new(transporterResponse)
+		tr.resp, tr.err = t.client.Get(path)
+		c <- tr
+	}()
+
+	return t.waitResponse(c)
+}
+
+func (t *transporter) waitResponse(responseChan chan *transporterResponse) (*http.Response, error) {
+
+	timeoutChan := time.After(t.timeout)
+
+	select {
+	case <-timeoutChan:
+		return nil, fmt.Errorf("Wait Response Timeout: %v", t.timeout)
+
+	case r := <-responseChan:
+		return r.resp, r.err
+	}
+
+	// for complier
+	return nil, nil
 }
 }

+ 36 - 0
transporter_test.go

@@ -0,0 +1,36 @@
+package main
+
+import (
+	"crypto/tls"
+	"testing"
+	"time"
+)
+
+func TestTransporterTimeout(t *testing.T) {
+
+	conf := tls.Config{}
+
+	ts := newTransporter("http", conf, time.Second)
+
+	ts.Get("http://google.com")
+	_, err := ts.Get("http://google.com:9999") // it doesn't exisit
+	if err == nil || err.Error() != "Wait Response Timeout: 1s" {
+		t.Fatal("timeout error: ", err.Error())
+	}
+
+	_, err = ts.Post("http://google.com:9999", nil) // it doesn't exisit
+	if err == nil || err.Error() != "Wait Response Timeout: 1s" {
+		t.Fatal("timeout error: ", err.Error())
+	}
+
+	_, err = ts.Get("http://www.google.com")
+	if err != nil {
+		t.Fatal("get error")
+	}
+
+	_, err = ts.Post("http://www.google.com", nil)
+	if err != nil {
+		t.Fatal("post error")
+	}
+
+}