Browse Source

add tls fetures

Xiang Li 12 years ago
parent
commit
30da72623f
4 changed files with 192 additions and 30 deletions
  1. 4 1
      handlers.go
  2. 152 19
      raftd.go
  3. 0 4
      store/store.go
  4. 36 6
      trans_handler.go

+ 4 - 1
handlers.go

@@ -215,7 +215,10 @@ func Dispatch(server *raft.Server, command Command, w http.ResponseWriter) {
 
 
 			reader := bytes.NewReader([]byte(command.GetValue()))
 			reader := bytes.NewReader([]byte(command.GetValue()))
 
 
-			reps, _ := http.Post(fmt.Sprintf("http://%v/%s", 
+			// t must be ok
+			t,_ := server.Transporter().(transHandler)
+
+			reps, _ := t.client.Post(fmt.Sprintf("http://%v/%s", 
 				leaderName, command.GeneratePath()), "application/json", reader)
 				leaderName, command.GeneratePath()), "application/json", reader)
 
 
 			if reps == nil {
 			if reps == nil {

+ 152 - 19
raftd.go

@@ -3,6 +3,7 @@ package main
 import (
 import (
 	"bytes"
 	"bytes"
 	"encoding/json"
 	"encoding/json"
+	"encoding/pem"
 	"flag"
 	"flag"
 	"fmt"
 	"fmt"
 	"github.com/benbjohnson/go-raft"
 	"github.com/benbjohnson/go-raft"
@@ -14,6 +15,8 @@ import (
 	"os"
 	"os"
 	"time"
 	"time"
 	"strconv"
 	"strconv"
+	"crypto/tls"
+	"crypto/x509"
 	"github.com/xiangli-cmu/raft-etcd/web"
 	"github.com/xiangli-cmu/raft-etcd/web"
 	"github.com/xiangli-cmu/raft-etcd/store"
 	"github.com/xiangli-cmu/raft-etcd/store"
 )
 )
@@ -28,8 +31,8 @@ var verbose bool
 var leaderHost string
 var leaderHost string
 var address string
 var address string
 var webPort int
 var webPort int
-var cert string
-var key string
+var certFile string
+var keyFile string
 var CAFile string
 var CAFile string
 
 
 func init() {
 func init() {
@@ -37,6 +40,9 @@ func init() {
 	flag.StringVar(&leaderHost, "c", "", "join to a existing cluster")
 	flag.StringVar(&leaderHost, "c", "", "join to a existing cluster")
 	flag.StringVar(&address, "a", "", "the address of the local machine")
 	flag.StringVar(&address, "a", "", "the address of the local machine")
 	flag.IntVar(&webPort, "w", -1, "the port of web interface")
 	flag.IntVar(&webPort, "w", -1, "the port of web interface")
+	flag.StringVar(&CAFile, "CAFile", "", "the path of the CAFile")
+	flag.StringVar(&certFile, "cert", "", "the cert file of the server")
+	flag.StringVar(&keyFile, "key", "", "the key file of the server")
 }
 }
 
 
 const (
 const (
@@ -68,6 +74,14 @@ var logger *log.Logger
 
 
 var storeMsg chan string
 var storeMsg chan string
 
 
+// CONSTANTS
+const (	
+	HTTP = iota
+	HTTPS
+	HTTPSANDVERIFY
+)
+
+
 //------------------------------------------------------------------------------
 //------------------------------------------------------------------------------
 //
 //
 // Functions
 // Functions
@@ -107,11 +121,19 @@ func main() {
 
 
 	fmt.Printf("Name: %s\n\n", name)
 	fmt.Printf("Name: %s\n\n", name)
 	
 	
-	t := transHandler{}
+	// secrity type
+	st := securityType()
+
+	if st == -1 {
+		panic("ERROR type")
+	}
+
+    t := createTranHandler(st)
 
 
 	// Setup new raft server.
 	// Setup new raft server.
 	s := store.GetStore()
 	s := store.GetStore()
 
 
+	// create raft server
 	server, err = raft.NewServer(name, path, t, s, nil)
 	server, err = raft.NewServer(name, path, t, s, nil)
 
 
 	if err != nil {
 	if err != nil {
@@ -144,7 +166,10 @@ func main() {
 			server.StartElectionTimeout()
 			server.StartElectionTimeout()
 			server.StartFollower()
 			server.StartFollower()
 
 
-			Join(server, leaderHost)
+			err := Join(server, leaderHost)
+			if err != nil {
+				panic(err)
+			}
 			fmt.Println("success join")
 			fmt.Println("success join")
 		}
 		}
 
 
@@ -157,9 +182,60 @@ func main() {
 
 
 	// open the snapshot
 	// open the snapshot
 	go server.Snapshot()
 	go server.Snapshot()
-	
 
 
-    // internal commands
+
+    if webPort != -1 {
+    	// start web
+    	s.SetMessager(&storeMsg)
+    	go webHelper()
+    	go web.Start(server, webPort)
+    } 
+
+    startTransport(info.Port, st)
+
+}
+
+func usage() {
+	fatal("usage: raftd [PATH]")
+}
+
+func createTranHandler(st int) transHandler {
+	t := transHandler{}
+
+	switch st {
+	case HTTP:
+		t := transHandler{}
+		t.client = nil
+		return t
+
+	case HTTPS:
+		fallthrough
+	case HTTPSANDVERIFY:
+		tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile)
+
+		if err != nil {
+			panic(err)
+		}
+
+		tr := &http.Transport{
+			TLSClientConfig:   &tls.Config{
+				Certificates: []tls.Certificate{tlsCert},
+				InsecureSkipVerify: true,
+				},
+				DisableCompression: true,
+			}
+
+		t.client = &http.Client{Transport: tr}
+		return t
+	}
+
+	// for complier
+	return transHandler{}
+}
+
+func startTransport(port int, st int) {	
+
+	// internal commands
     http.HandleFunc("/join", JoinHttpHandler)
     http.HandleFunc("/join", JoinHttpHandler)
     http.HandleFunc("/vote", VoteHttpHandler)
     http.HandleFunc("/vote", VoteHttpHandler)
     http.HandleFunc("/log", GetLogHttpHandler)
     http.HandleFunc("/log", GetLogHttpHandler)
@@ -172,26 +248,70 @@ func main() {
     http.HandleFunc("/delete/", DeleteHttpHandler)
     http.HandleFunc("/delete/", DeleteHttpHandler)
     http.HandleFunc("/watch/", WatchHttpHandler)
     http.HandleFunc("/watch/", WatchHttpHandler)
 
 
+    switch st {
 
 
-    if webPort != -1 {
-    	// start web
-    	s.SetMessager(&storeMsg)
-    	go webHelper()
-    	go web.Start(server, webPort)
-    } 
+    case HTTP:
+    	log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), nil))
 
 
-    // listen on http port
-	log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", info.Port), nil))
-}
+    case HTTPS:
+    	http.ListenAndServeTLS(fmt.Sprintf(":%d", port), certFile, keyFile, nil)
+
+    case HTTPSANDVERIFY:
+    	pemByte, _ := ioutil.ReadFile(CAFile)
+
+		block, pemByte := pem.Decode(pemByte)
+
+
+		cert, err := x509.ParseCertificate(block.Bytes)
+
+		if err != nil {
+			fmt.Println(err)
+		}
+
+		certPool := x509.NewCertPool()
+
+		certPool.AddCert(cert)
+
+		server := &http.Server{
+			TLSConfig: &tls.Config{
+				ClientAuth: tls.RequireAndVerifyClientCert,
+				ClientCAs: certPool,
+				},
+			Addr:fmt.Sprintf(":%d", port),
+		}
+		err = server.ListenAndServeTLS(certFile, keyFile)
+
+		if err != nil {
+			log.Fatal(err)
+		}
+    }
 
 
-func usage() {
-	fatal("usage: raftd [PATH]")
 }
 }
 
 
 //--------------------------------------
 //--------------------------------------
 // Config
 // Config
 //--------------------------------------
 //--------------------------------------
 
 
+func securityType() int{
+	if keyFile == "" && certFile == "" && CAFile == ""{
+
+		return HTTP
+
+	}
+
+	if keyFile != "" && certFile != "" {
+
+		if CAFile != "" {
+			return HTTPSANDVERIFY
+		}
+
+		return HTTPS
+	}
+
+	return -1
+}
+
+
 func getInfo(path string) *Info {
 func getInfo(path string) *Info {
 	info := &Info{}
 	info := &Info{}
 
 
@@ -253,8 +373,21 @@ func Join(s *raft.Server, serverName string) error {
 	command.Name = s.Name()
 	command.Name = s.Name()
 
 
 	json.NewEncoder(&b).Encode(command)
 	json.NewEncoder(&b).Encode(command)
-	debug("[send] POST http://%v/join", "localhost:4001")
-	resp, err := http.Post(fmt.Sprintf("http://%s/join", serverName), "application/json", &b)
+	
+
+	var resp *http.Response
+	var err error
+
+	// t must be ok
+	t,_ := server.Transporter().(transHandler)
+	if t.client != nil {
+		debug("[send] POST https://%v/join", "localhost:4001")
+		resp, err = t.client.Post(fmt.Sprintf("https://%s/join", serverName), "application/json", &b)
+	} else {
+		debug("[send] POST http://%v/join", "localhost:4001")
+		resp, err = http.Post(fmt.Sprintf("https://%s/join", serverName), "application/json", &b)
+	}
+
 	if resp != nil {
 	if resp != nil {
 		resp.Body.Close()
 		resp.Body.Close()
 		if resp.StatusCode == http.StatusOK {
 		if resp.StatusCode == http.StatusOK {

+ 0 - 4
store/store.go

@@ -78,10 +78,6 @@ func Set(key string, value string, expireTime time.Time) ([]byte, error) {
 	node, ok := s.Nodes[key]
 	node, ok := s.Nodes[key]
 
 
 	if ok {
 	if ok {
-		//update := make(chan time.Time)
-		//s.Nodes[key] = Node{value, expireTime, update}
-
-		
 		
 		
 		// if node is not permanent before 
 		// if node is not permanent before 
 		// update its expireTime
 		// update its expireTime

+ 36 - 6
trans_handler.go

@@ -11,6 +11,7 @@ import(
 
 
 type transHandler struct {
 type transHandler struct {
 	name string
 	name string
+	client *http.Client
 }
 }
 
 
 // Sends AppendEntries RPCs to a peer when the server is the leader.
 // Sends AppendEntries RPCs to a peer when the server is the leader.
@@ -18,8 +19,18 @@ func (t transHandler) SendAppendEntriesRequest(server *raft.Server, peer *raft.P
 	var aersp *raft.AppendEntriesResponse
 	var aersp *raft.AppendEntriesResponse
 	var b bytes.Buffer
 	var b bytes.Buffer
 	json.NewEncoder(&b).Encode(req)
 	json.NewEncoder(&b).Encode(req)
-	debug("[send] POST http://%s/log/append [%d]", peer.Name(), len(req.Entries))
-	resp, err := http.Post(fmt.Sprintf("http://%s/log/append", peer.Name()), "application/json", &b)
+	
+
+	var resp *http.Response
+	var err error
+
+	if t.client != nil {
+		debug("[send] POST https://%s/log/append [%d]", peer.Name(), len(req.Entries))
+		resp, err = http.Post(fmt.Sprintf("https://%s/log/append", peer.Name()), "application/json", &b)
+	} else {
+		debug("[send] POST http://%s/log/append [%d]", peer.Name(), len(req.Entries))
+		resp, err = t.client.Post(fmt.Sprintf("http://%s/log/append", peer.Name()), "application/json", &b)
+	}
 	if resp != nil {
 	if resp != nil {
 		defer resp.Body.Close()
 		defer resp.Body.Close()
 		aersp = &raft.AppendEntriesResponse{}
 		aersp = &raft.AppendEntriesResponse{}
@@ -36,8 +47,18 @@ func (t transHandler) SendVoteRequest(server *raft.Server, peer *raft.Peer, req
 	var rvrsp *raft.RequestVoteResponse
 	var rvrsp *raft.RequestVoteResponse
 	var b bytes.Buffer
 	var b bytes.Buffer
 	json.NewEncoder(&b).Encode(req)
 	json.NewEncoder(&b).Encode(req)
-	debug("[send] POST http://%s/vote", peer.Name())
-	resp, err := http.Post(fmt.Sprintf("http://%s/vote", peer.Name()), "application/json", &b)
+
+	var resp *http.Response
+	var err error
+
+	if t.client != nil {
+		debug("[send] POST https://%s/vote", peer.Name())
+		resp, err = t.client.Post(fmt.Sprintf("https://%s/vote", peer.Name()), "application/json", &b)
+	} else {
+		debug("[send] POST http://%s/vote", peer.Name())
+		resp, err = http.Post(fmt.Sprintf("http://%s/vote", peer.Name()), "application/json", &b)
+	}
+
 	if resp != nil {
 	if resp != nil {
 		defer resp.Body.Close()
 		defer resp.Body.Close()
 		rvrsp := &raft.RequestVoteResponse{}
 		rvrsp := &raft.RequestVoteResponse{}
@@ -54,8 +75,17 @@ func (t transHandler) SendSnapshotRequest(server *raft.Server, peer *raft.Peer,
 	var aersp *raft.SnapshotResponse
 	var aersp *raft.SnapshotResponse
 	var b bytes.Buffer
 	var b bytes.Buffer
 	json.NewEncoder(&b).Encode(req)
 	json.NewEncoder(&b).Encode(req)
-	debug("[send] POST http://%s/snapshot [%d %d]", peer.Name(), req.LastTerm, req.LastIndex)
-	resp, err := http.Post(fmt.Sprintf("http://%s/snapshot", peer.Name()), "application/json", &b)
+
+	var resp *http.Response
+	var err error
+
+	if t.client != nil {
+		debug("[send] POST https://%s/snapshot [%d %d]", peer.Name(), req.LastTerm, req.LastIndex)
+		resp, err = t.client.Post(fmt.Sprintf("https://%s/snapshot", peer.Name()), "application/json", &b)
+	} else {
+		debug("[send] POST http://%s/snapshot [%d %d]", peer.Name(), req.LastTerm, req.LastIndex)
+		resp, err = http.Post(fmt.Sprintf("http://%s/snapshot", peer.Name()), "application/json", &b)
+	}
 	if resp != nil {
 	if resp != nil {
 		defer resp.Body.Close()
 		defer resp.Body.Close()
 		aersp = &raft.SnapshotResponse{}
 		aersp = &raft.SnapshotResponse{}