فهرست منبع

add timeout for transportation layer

Ivan7702 12 سال پیش
والد
کامیت
51941fa613
3فایلهای تغییر یافته به همراه103 افزوده شده و 17 حذف شده
  1. 13 13
      raft_server.go
  2. 55 4
      transporter.go
  3. 35 0
      transporter_test.go

+ 13 - 13
raft_server.go

@@ -16,13 +16,13 @@ import (
 
 type raftServer struct {
 	*raft.Server
-	version   string
-	joinIndex uint64
-	name      string
-	url       string
+	version    string
+	joinIndex  uint64
+	name       string
+	url        string
 	listenHost string
-	tlsConf   *TLSConfig
-	tlsInfo   *TLSInfo
+	tlsConf    *TLSConfig
+	tlsInfo    *TLSInfo
 }
 
 var r *raftServer
@@ -30,7 +30,7 @@ var r *raftServer
 func newRaftServer(name string, url string, listenHost string, tlsConf *TLSConfig, tlsInfo *TLSInfo) *raftServer {
 
 	// Create transporter for raft
-	raftTransporter := newTransporter(tlsConf.Scheme, tlsConf.Client)
+	raftTransporter := newTransporter(tlsConf.Scheme, tlsConf.Client, raft.DefaultHeartbeatTimeout)
 
 	// Create raft server
 	server, err := raft.NewServer(name, dirPath, raftTransporter, etcdStore, nil)
@@ -38,13 +38,13 @@ func newRaftServer(name string, url string, listenHost string, tlsConf *TLSConfi
 	check(err)
 
 	return &raftServer{
-		Server:  server,
-		version: raftVersion,
-		name:    name,
-		url:     url,
+		Server:     server,
+		version:    raftVersion,
+		name:       name,
+		url:        url,
 		listenHost: listenHost,
-		tlsConf: tlsConf,
-		tlsInfo: tlsInfo,
+		tlsConf:    tlsConf,
+		tlsInfo:    tlsInfo,
 	}
 }
 

+ 55 - 4
transporter.go

@@ -9,17 +9,19 @@ import (
 	"io"
 	"net"
 	"net/http"
+	"time"
 )
 
 // Transporter layer for communication between raft nodes
 type transporter struct {
-	client *http.Client
+	client  *http.Client
+	timeout time.Duration
 }
 
 // Create transporter using by raft server
 // Create http or https transporter based on
 // whether the user give the server cert and key
-func newTransporter(scheme string, tlsConf tls.Config) transporter {
+func newTransporter(scheme string, tlsConf tls.Config, timeout time.Duration) transporter {
 	t := transporter{}
 
 	tr := &http.Transport{
@@ -32,6 +34,7 @@ func newTransporter(scheme string, tlsConf tls.Config) transporter {
 	}
 
 	t.client = &http.Client{Transport: tr}
+	t.timeout = timeout
 
 	return t
 }
@@ -151,10 +154,58 @@ func (t transporter) SendSnapshotRecoveryRequest(server *raft.Server, peer *raft
 
 // Send server side POST request
 func (t transporter) Post(path string, body io.Reader) (*http.Response, error) {
-	return t.client.Post(path, "application/json", body)
+
+	postChan := make(chan interface{}, 1)
+
+	go func() {
+		resp, err := t.client.Post(path, "application/json", body)
+		if err == nil {
+			postChan <- resp
+		} else {
+			postChan <- err
+		}
+	}()
+
+	return t.waitResponse(postChan)
+
 }
 
 // Send server side GET request
 func (t transporter) Get(path string) (*http.Response, error) {
-	return t.client.Get(path)
+
+	getChan := make(chan interface{}, 1)
+
+	go func() {
+		resp, err := t.client.Get(path)
+		if err == nil {
+			getChan <- resp
+		} else {
+			getChan <- err
+		}
+	}()
+
+	return t.waitResponse(getChan)
+}
+
+func (t transporter) waitResponse(responseChan chan interface{}) (*http.Response, error) {
+
+	timeoutChan := time.After(t.timeout)
+
+	select {
+	case <-timeoutChan:
+		return nil, fmt.Errorf("Wait Response Timeout: %v", t.timeout)
+
+	case r := <-responseChan:
+		switch r := r.(type) {
+		case error:
+			return nil, r
+
+		case *http.Response:
+			return r, nil
+
+		}
+	}
+
+	// for complier
+	return nil, nil
 }

+ 35 - 0
transporter_test.go

@@ -0,0 +1,35 @@
+package main
+
+import (
+	"crypto/tls"
+	"testing"
+	"time"
+)
+
+func TestTransporterTimeout(t *testing.T) {
+
+	conf := tls.Config{}
+
+	ts := newTransporter("http", conf, time.Second)
+
+	_, err := ts.Get("http://127.0.0.2:7000")
+	if err == nil || err.Error() != "Wait Response Timeout: 1s" {
+		t.Fatal("timeout error: ", err.Error())
+	}
+
+	_, err = ts.Post("http://127.0.0.2:7000", nil)
+	if err == nil || err.Error() != "Wait Response Timeout: 1s" {
+		t.Fatal("timeout error: ", err.Error())
+	}
+
+	_, err = ts.Get("http://www.google.com")
+	if err != nil {
+		t.Fatal("get error")
+	}
+
+	_, err = ts.Post("http://www.google.com", nil)
+	if err != nil {
+		t.Fatal("post error")
+	}
+
+}