|
|
@@ -3,6 +3,7 @@ package main
|
|
|
import (
|
|
|
"bytes"
|
|
|
"encoding/json"
|
|
|
+ "encoding/pem"
|
|
|
"flag"
|
|
|
"fmt"
|
|
|
"github.com/benbjohnson/go-raft"
|
|
|
@@ -14,6 +15,8 @@ import (
|
|
|
"os"
|
|
|
"time"
|
|
|
"strconv"
|
|
|
+ "crypto/tls"
|
|
|
+ "crypto/x509"
|
|
|
"github.com/xiangli-cmu/raft-etcd/web"
|
|
|
"github.com/xiangli-cmu/raft-etcd/store"
|
|
|
)
|
|
|
@@ -28,8 +31,8 @@ var verbose bool
|
|
|
var leaderHost string
|
|
|
var address string
|
|
|
var webPort int
|
|
|
-var cert string
|
|
|
-var key string
|
|
|
+var certFile string
|
|
|
+var keyFile string
|
|
|
var CAFile string
|
|
|
|
|
|
func init() {
|
|
|
@@ -37,6 +40,9 @@ func init() {
|
|
|
flag.StringVar(&leaderHost, "c", "", "join to a existing cluster")
|
|
|
flag.StringVar(&address, "a", "", "the address of the local machine")
|
|
|
flag.IntVar(&webPort, "w", -1, "the port of web interface")
|
|
|
+ flag.StringVar(&CAFile, "CAFile", "", "the path of the CAFile")
|
|
|
+ flag.StringVar(&certFile, "cert", "", "the cert file of the server")
|
|
|
+ flag.StringVar(&keyFile, "key", "", "the key file of the server")
|
|
|
}
|
|
|
|
|
|
const (
|
|
|
@@ -68,6 +74,14 @@ var logger *log.Logger
|
|
|
|
|
|
var storeMsg chan string
|
|
|
|
|
|
+// CONSTANTS
|
|
|
+const (
|
|
|
+ HTTP = iota
|
|
|
+ HTTPS
|
|
|
+ HTTPSANDVERIFY
|
|
|
+)
|
|
|
+
|
|
|
+
|
|
|
//------------------------------------------------------------------------------
|
|
|
//
|
|
|
// Functions
|
|
|
@@ -107,11 +121,19 @@ func main() {
|
|
|
|
|
|
fmt.Printf("Name: %s\n\n", name)
|
|
|
|
|
|
- t := transHandler{}
|
|
|
+ // secrity type
|
|
|
+ st := securityType()
|
|
|
+
|
|
|
+ if st == -1 {
|
|
|
+ panic("ERROR type")
|
|
|
+ }
|
|
|
+
|
|
|
+ t := createTranHandler(st)
|
|
|
|
|
|
// Setup new raft server.
|
|
|
s := store.GetStore()
|
|
|
|
|
|
+ // create raft server
|
|
|
server, err = raft.NewServer(name, path, t, s, nil)
|
|
|
|
|
|
if err != nil {
|
|
|
@@ -144,7 +166,10 @@ func main() {
|
|
|
server.StartElectionTimeout()
|
|
|
server.StartFollower()
|
|
|
|
|
|
- Join(server, leaderHost)
|
|
|
+ err := Join(server, leaderHost)
|
|
|
+ if err != nil {
|
|
|
+ panic(err)
|
|
|
+ }
|
|
|
fmt.Println("success join")
|
|
|
}
|
|
|
|
|
|
@@ -157,9 +182,60 @@ func main() {
|
|
|
|
|
|
// open the snapshot
|
|
|
go server.Snapshot()
|
|
|
-
|
|
|
|
|
|
- // internal commands
|
|
|
+
|
|
|
+ if webPort != -1 {
|
|
|
+ // start web
|
|
|
+ s.SetMessager(&storeMsg)
|
|
|
+ go webHelper()
|
|
|
+ go web.Start(server, webPort)
|
|
|
+ }
|
|
|
+
|
|
|
+ startTransport(info.Port, st)
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+func usage() {
|
|
|
+ fatal("usage: raftd [PATH]")
|
|
|
+}
|
|
|
+
|
|
|
+func createTranHandler(st int) transHandler {
|
|
|
+ t := transHandler{}
|
|
|
+
|
|
|
+ switch st {
|
|
|
+ case HTTP:
|
|
|
+ t := transHandler{}
|
|
|
+ t.client = nil
|
|
|
+ return t
|
|
|
+
|
|
|
+ case HTTPS:
|
|
|
+ fallthrough
|
|
|
+ case HTTPSANDVERIFY:
|
|
|
+ tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ panic(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ tr := &http.Transport{
|
|
|
+ TLSClientConfig: &tls.Config{
|
|
|
+ Certificates: []tls.Certificate{tlsCert},
|
|
|
+ InsecureSkipVerify: true,
|
|
|
+ },
|
|
|
+ DisableCompression: true,
|
|
|
+ }
|
|
|
+
|
|
|
+ t.client = &http.Client{Transport: tr}
|
|
|
+ return t
|
|
|
+ }
|
|
|
+
|
|
|
+ // for complier
|
|
|
+ return transHandler{}
|
|
|
+}
|
|
|
+
|
|
|
+func startTransport(port int, st int) {
|
|
|
+
|
|
|
+ // internal commands
|
|
|
http.HandleFunc("/join", JoinHttpHandler)
|
|
|
http.HandleFunc("/vote", VoteHttpHandler)
|
|
|
http.HandleFunc("/log", GetLogHttpHandler)
|
|
|
@@ -172,26 +248,70 @@ func main() {
|
|
|
http.HandleFunc("/delete/", DeleteHttpHandler)
|
|
|
http.HandleFunc("/watch/", WatchHttpHandler)
|
|
|
|
|
|
+ switch st {
|
|
|
|
|
|
- if webPort != -1 {
|
|
|
- // start web
|
|
|
- s.SetMessager(&storeMsg)
|
|
|
- go webHelper()
|
|
|
- go web.Start(server, webPort)
|
|
|
- }
|
|
|
+ case HTTP:
|
|
|
+ log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), nil))
|
|
|
|
|
|
- // listen on http port
|
|
|
- log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", info.Port), nil))
|
|
|
-}
|
|
|
+ case HTTPS:
|
|
|
+ http.ListenAndServeTLS(fmt.Sprintf(":%d", port), certFile, keyFile, nil)
|
|
|
+
|
|
|
+ case HTTPSANDVERIFY:
|
|
|
+ pemByte, _ := ioutil.ReadFile(CAFile)
|
|
|
+
|
|
|
+ block, pemByte := pem.Decode(pemByte)
|
|
|
+
|
|
|
+
|
|
|
+ cert, err := x509.ParseCertificate(block.Bytes)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ certPool := x509.NewCertPool()
|
|
|
+
|
|
|
+ certPool.AddCert(cert)
|
|
|
+
|
|
|
+ server := &http.Server{
|
|
|
+ TLSConfig: &tls.Config{
|
|
|
+ ClientAuth: tls.RequireAndVerifyClientCert,
|
|
|
+ ClientCAs: certPool,
|
|
|
+ },
|
|
|
+ Addr:fmt.Sprintf(":%d", port),
|
|
|
+ }
|
|
|
+ err = server.ListenAndServeTLS(certFile, keyFile)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ log.Fatal(err)
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
-func usage() {
|
|
|
- fatal("usage: raftd [PATH]")
|
|
|
}
|
|
|
|
|
|
//--------------------------------------
|
|
|
// Config
|
|
|
//--------------------------------------
|
|
|
|
|
|
+func securityType() int{
|
|
|
+ if keyFile == "" && certFile == "" && CAFile == ""{
|
|
|
+
|
|
|
+ return HTTP
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ if keyFile != "" && certFile != "" {
|
|
|
+
|
|
|
+ if CAFile != "" {
|
|
|
+ return HTTPSANDVERIFY
|
|
|
+ }
|
|
|
+
|
|
|
+ return HTTPS
|
|
|
+ }
|
|
|
+
|
|
|
+ return -1
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
func getInfo(path string) *Info {
|
|
|
info := &Info{}
|
|
|
|
|
|
@@ -253,8 +373,21 @@ func Join(s *raft.Server, serverName string) error {
|
|
|
command.Name = s.Name()
|
|
|
|
|
|
json.NewEncoder(&b).Encode(command)
|
|
|
- debug("[send] POST http://%v/join", "localhost:4001")
|
|
|
- resp, err := http.Post(fmt.Sprintf("http://%s/join", serverName), "application/json", &b)
|
|
|
+
|
|
|
+
|
|
|
+ var resp *http.Response
|
|
|
+ var err error
|
|
|
+
|
|
|
+ // t must be ok
|
|
|
+ t,_ := server.Transporter().(transHandler)
|
|
|
+ if t.client != nil {
|
|
|
+ debug("[send] POST https://%v/join", "localhost:4001")
|
|
|
+ resp, err = t.client.Post(fmt.Sprintf("https://%s/join", serverName), "application/json", &b)
|
|
|
+ } else {
|
|
|
+ debug("[send] POST http://%v/join", "localhost:4001")
|
|
|
+ resp, err = http.Post(fmt.Sprintf("https://%s/join", serverName), "application/json", &b)
|
|
|
+ }
|
|
|
+
|
|
|
if resp != nil {
|
|
|
resp.Body.Close()
|
|
|
if resp.StatusCode == http.StatusOK {
|