Browse Source

Allow different key,cert,CA for client and server communication

Xiang Li 12 years ago
parent
commit
f67115b935
1 changed files with 46 additions and 18 deletions
  1. 46 18
      etcd.go

+ 46 - 18
etcd.go

@@ -36,9 +36,13 @@ var clientPort int
 var serverPort int
 var webPort int
 
-var certFile string
-var keyFile string
-var CAFile string
+var serverCertFile string
+var serverKeyFile string
+var serverCAFile string
+
+var clientCertFile string
+var clientKeyFile string
+var clientCAFile string
 
 var dirPath string
 
@@ -53,9 +57,13 @@ func init() {
 	flag.IntVar(&serverPort, "s", 7001, "the port of server")
 	flag.IntVar(&webPort, "w", -1, "the port of web interface")
 
-	flag.StringVar(&CAFile, "CAFile", "", "the path of the CAFile")
-	flag.StringVar(&certFile, "cert", "", "the cert file of the server")
-	flag.StringVar(&keyFile, "key", "", "the key file of the server")
+	flag.StringVar(&serverCAFile, "serverCAFile", "", "the path of the CAFile")
+	flag.StringVar(&serverCertFile, "serverCert", "", "the cert file of the server")
+	flag.StringVar(&serverKeyFile, "serverKey", "", "the key file of the server")
+
+	flag.StringVar(&clientCAFile, "clientCAFile", "", "the path of the CAFile")
+	flag.StringVar(&clientCertFile, "clientCert", "", "the cert file of the client")
+	flag.StringVar(&clientKeyFile, "clientKey", "", "the key file of the client")
 
 	flag.StringVar(&dirPath, "d", "./", "the directory to store log and snapshot")
 }
@@ -67,6 +75,11 @@ const (
 	HTTPSANDVERIFY
 )
 
+const (
+	SERVER = iota
+	CLIENT
+)
+
 const (
 	ELECTIONTIMTOUT  = 200 * time.Millisecond
 	HEARTBEATTIMEOUT = 50 * time.Millisecond
@@ -130,7 +143,7 @@ func main() {
 	fmt.Printf("ServerName: %s\n\n", name)
 
 	// secrity type
-	st := securityType()
+	st := securityType(SERVER)
 
 	if st == -1 {
 		panic("ERROR type")
@@ -196,7 +209,7 @@ func main() {
 	}
 
 	go startServTransport(info.ServerPort, st)
-	startClientTransport(info.ClientPort, st)
+	startClientTransport(info.ClientPort, securityType(CLIENT))
 
 }
 
@@ -216,7 +229,7 @@ func createTranHandler(st int) transHandler {
 	case HTTPS:
 		fallthrough
 	case HTTPSANDVERIFY:
-		tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile)
+		tlsCert, err := tls.LoadX509KeyPair(serverCertFile, serverKeyFile)
 
 		if err != nil {
 			panic(err)
@@ -251,14 +264,14 @@ func startServTransport(port int, st int) {
 	switch st {
 
 	case HTTP:
-		debug("%s listen on http", server.Name())
+		debug("raft server [%s] listen on http", server.Name())
 		log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), nil))
 
 	case HTTPS:
-		http.ListenAndServeTLS(fmt.Sprintf(":%d", port), certFile, keyFile, nil)
+		http.ListenAndServeTLS(fmt.Sprintf(":%d", port), serverCertFile, serverKeyFile, nil)
 
 	case HTTPSANDVERIFY:
-		pemByte, _ := ioutil.ReadFile(CAFile)
+		pemByte, _ := ioutil.ReadFile(serverCAFile)
 
 		block, pemByte := pem.Decode(pemByte)
 
@@ -279,7 +292,7 @@ func startServTransport(port int, st int) {
 			},
 			Addr: fmt.Sprintf(":%d", port),
 		}
-		err = server.ListenAndServeTLS(certFile, keyFile)
+		err = server.ListenAndServeTLS(serverCertFile, serverKeyFile)
 
 		if err != nil {
 			log.Fatal(err)
@@ -299,14 +312,14 @@ func startClientTransport(port int, st int) {
 	switch st {
 
 	case HTTP:
-		debug("%s listen on http", server.Name())
+		debug("etcd [%s] listen on http", server.Name())
 		log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), nil))
 
 	case HTTPS:
-		http.ListenAndServeTLS(fmt.Sprintf(":%d", port), certFile, keyFile, nil)
+		http.ListenAndServeTLS(fmt.Sprintf(":%d", port), clientCertFile, clientKeyFile, nil)
 
 	case HTTPSANDVERIFY:
-		pemByte, _ := ioutil.ReadFile(CAFile)
+		pemByte, _ := ioutil.ReadFile(clientCAFile)
 
 		block, pemByte := pem.Decode(pemByte)
 
@@ -327,7 +340,7 @@ func startClientTransport(port int, st int) {
 			},
 			Addr: fmt.Sprintf(":%d", port),
 		}
-		err = server.ListenAndServeTLS(certFile, keyFile)
+		err = server.ListenAndServeTLS(clientCertFile, clientKeyFile)
 
 		if err != nil {
 			log.Fatal(err)
@@ -340,7 +353,22 @@ func startClientTransport(port int, st int) {
 // Config
 //--------------------------------------
 
-func securityType() int {
+func securityType(source int) int {
+
+	var keyFile, certFile, CAFile string 
+
+	switch source {
+	case SERVER:
+		keyFile = serverKeyFile
+		certFile = serverCertFile
+		CAFile = serverCAFile
+
+	case CLIENT:
+		keyFile = clientKeyFile
+		certFile = clientCertFile
+		CAFile = clientCAFile
+	}
+
 	if keyFile == "" && certFile == "" && CAFile == "" {
 
 		return HTTP