Browse Source

simplify createTrans

Xiang Li 12 years ago
parent
commit
06fab60dd6
1 changed files with 64 additions and 90 deletions
  1. 64 90
      etcd.go

+ 64 - 90
etcd.go

@@ -89,14 +89,8 @@ func init() {
 
 
 // CONSTANTS
 // CONSTANTS
 const (
 const (
-	HTTP = iota
-	HTTPS
-	HTTPSANDVERIFY
-)
-
-const (
-	SERVER = iota
-	CLIENT
+	RaftServer = iota
+	EtcdServer
 )
 )
 
 
 const (
 const (
@@ -200,19 +194,20 @@ func main() {
 
 
 	info = getInfo(dirPath)
 	info = getInfo(dirPath)
 
 
-	// security type
-	st := securityType(SERVER)
-
-	clientSt := securityType(CLIENT)
+	raftTlsConfs, ok := tlsConf(RaftServer)
+	if !ok {
+		fatal("Please specify cert and key file or cert and key file and CAFile or none of the three")
+	}
 
 
-	if st == -1 || clientSt == -1 {
+	etcdTlsConfs, ok := tlsConf(EtcdServer)
+	if !ok {
 		fatal("Please specify cert and key file or cert and key file and CAFile or none of the three")
 		fatal("Please specify cert and key file or cert and key file and CAFile or none of the three")
 	}
 	}
 
 
 	// Create etcd key-value store
 	// Create etcd key-value store
 	etcdStore = store.CreateStore(maxSize)
 	etcdStore = store.CreateStore(maxSize)
 
 
-	startRaft(st)
+	startRaft(raftTlsConfs)
 
 
 	if argInfo.WebPort != -1 {
 	if argInfo.WebPort != -1 {
 		// start web
 		// start web
@@ -221,18 +216,18 @@ func main() {
 		go web.Start(raftServer, argInfo.WebPort)
 		go web.Start(raftServer, argInfo.WebPort)
 	}
 	}
 
 
-	startClientTransport(*info, clientSt)
+	startEtcdTransport(*info, etcdTlsConfs[0])
 
 
 }
 }
 
 
 // Start the raft server
 // Start the raft server
-func startRaft(securityType int) {
+func startRaft(tlsConfs []*tls.Config) {
 	var err error
 	var err error
 
 
 	raftName := fmt.Sprintf("%s:%d", info.Hostname, info.RaftPort)
 	raftName := fmt.Sprintf("%s:%d", info.Hostname, info.RaftPort)
 
 
 	// Create transporter for raft
 	// Create transporter for raft
-	raftTransporter = createTransporter(securityType)
+	raftTransporter = newTransporter(tlsConfs[1])
 
 
 	// Create raft server
 	// Create raft server
 	raftServer, err = raft.NewServer(raftName, dirPath, raftTransporter, etcdStore, nil)
 	raftServer, err = raft.NewServer(raftName, dirPath, raftTransporter, etcdStore, nil)
@@ -328,44 +323,30 @@ func startRaft(securityType int) {
 	}
 	}
 
 
 	// start to response to raft requests
 	// start to response to raft requests
-	go startRaftTransport(*info, securityType)
+	go startRaftTransport(*info, tlsConfs[0])
 
 
 }
 }
 
 
 // Create transporter using by raft server
 // Create transporter using by raft server
 // Create http or https transporter based on
 // Create http or https transporter based on
 // whether the user give the server cert and key
 // whether the user give the server cert and key
-func createTransporter(st int) transporter {
+func newTransporter(tlsConf *tls.Config) transporter {
 	t := transporter{}
 	t := transporter{}
 
 
-	switch st {
-	case HTTP:
+	if tlsConf == nil {
 		t.scheme = "http://"
 		t.scheme = "http://"
 
 
-		tr := &http.Transport{
-			Dial: dialTimeout,
-		}
-
 		t.client = &http.Client{
 		t.client = &http.Client{
-			Transport: tr,
+			Transport: &http.Transport{
+				Dial: dialTimeout,
+			},
 		}
 		}
 
 
-	case HTTPS:
-		fallthrough
-	case HTTPSANDVERIFY:
+	} else {
 		t.scheme = "https://"
 		t.scheme = "https://"
 
 
-		tlsCert, err := tls.LoadX509KeyPair(argInfo.ServerCertFile, argInfo.ServerKeyFile)
-
-		if err != nil {
-			fatal(err)
-		}
-
 		tr := &http.Transport{
 		tr := &http.Transport{
-			TLSClientConfig: &tls.Config{
-				Certificates:       []tls.Certificate{tlsCert},
-				InsecureSkipVerify: true,
-			},
+			TLSClientConfig:    tlsConf,
 			Dial:               dialTimeout,
 			Dial:               dialTimeout,
 			DisableCompression: true,
 			DisableCompression: true,
 		}
 		}
@@ -382,7 +363,7 @@ func dialTimeout(network, addr string) (net.Conn, error) {
 }
 }
 
 
 // Start to listen and response raft command
 // Start to listen and response raft command
-func startRaftTransport(info Info, st int) {
+func startRaftTransport(info Info, tlsConf *tls.Config) {
 
 
 	// internal commands
 	// internal commands
 	http.HandleFunc("/join", JoinHttpHandler)
 	http.HandleFunc("/join", JoinHttpHandler)
@@ -393,24 +374,14 @@ func startRaftTransport(info Info, st int) {
 	http.HandleFunc("/snapshotRecovery", SnapshotRecoveryHttpHandler)
 	http.HandleFunc("/snapshotRecovery", SnapshotRecoveryHttpHandler)
 	http.HandleFunc("/client", ClientHttpHandler)
 	http.HandleFunc("/client", ClientHttpHandler)
 
 
-	switch st {
-
-	case HTTP:
+	if tlsConf == nil {
 		fmt.Printf("raft server [%s] listen on http port %v\n", info.Hostname, info.RaftPort)
 		fmt.Printf("raft server [%s] listen on http port %v\n", info.Hostname, info.RaftPort)
 		fatal(http.ListenAndServe(fmt.Sprintf(":%d", info.RaftPort), nil))
 		fatal(http.ListenAndServe(fmt.Sprintf(":%d", info.RaftPort), nil))
 
 
-	case HTTPS:
-		fmt.Printf("raft server [%s] listen on https port %v\n", info.Hostname, info.RaftPort)
-		fatal(http.ListenAndServeTLS(fmt.Sprintf(":%d", info.RaftPort), info.ServerCertFile, argInfo.ServerKeyFile, nil))
-
-	case HTTPSANDVERIFY:
-
+	} else {
 		server := &http.Server{
 		server := &http.Server{
-			TLSConfig: &tls.Config{
-				ClientAuth: tls.RequireAndVerifyClientCert,
-				ClientCAs:  createCertPool(info.ServerCAFile),
-			},
-			Addr: fmt.Sprintf(":%d", info.RaftPort),
+			TLSConfig: tlsConf,
+			Addr:      fmt.Sprintf(":%d", info.RaftPort),
 		}
 		}
 		fmt.Printf("raft server [%s] listen on https port %v\n", info.Hostname, info.RaftPort)
 		fmt.Printf("raft server [%s] listen on https port %v\n", info.Hostname, info.RaftPort)
 		fatal(server.ListenAndServeTLS(info.ServerCertFile, argInfo.ServerKeyFile))
 		fatal(server.ListenAndServeTLS(info.ServerCertFile, argInfo.ServerKeyFile))
@@ -419,7 +390,7 @@ func startRaftTransport(info Info, st int) {
 }
 }
 
 
 // Start to listen and response client command
 // Start to listen and response client command
-func startClientTransport(info Info, st int) {
+func startEtcdTransport(info Info, tlsConf *tls.Config) {
 	// external commands
 	// external commands
 	http.HandleFunc("/"+version+"/keys/", Multiplexer)
 	http.HandleFunc("/"+version+"/keys/", Multiplexer)
 	http.HandleFunc("/"+version+"/watch/", WatchHttpHandler)
 	http.HandleFunc("/"+version+"/watch/", WatchHttpHandler)
@@ -429,24 +400,13 @@ func startClientTransport(info Info, st int) {
 	http.HandleFunc("/stats", StatsHttpHandler)
 	http.HandleFunc("/stats", StatsHttpHandler)
 	http.HandleFunc("/test/", TestHttpHandler)
 	http.HandleFunc("/test/", TestHttpHandler)
 
 
-	switch st {
-
-	case HTTP:
+	if tlsConf == nil {
 		fmt.Printf("etcd [%s] listen on http port %v\n", info.Hostname, info.ClientPort)
 		fmt.Printf("etcd [%s] listen on http port %v\n", info.Hostname, info.ClientPort)
 		fatal(http.ListenAndServe(fmt.Sprintf(":%d", info.ClientPort), nil))
 		fatal(http.ListenAndServe(fmt.Sprintf(":%d", info.ClientPort), nil))
-
-	case HTTPS:
-		fmt.Printf("etcd [%s] listen on https port %v\n", info.Hostname, info.ClientPort)
-		http.ListenAndServeTLS(fmt.Sprintf(":%d", info.ClientPort), info.ClientCertFile, info.ClientKeyFile, nil)
-
-	case HTTPSANDVERIFY:
-
+	} else {
 		server := &http.Server{
 		server := &http.Server{
-			TLSConfig: &tls.Config{
-				ClientAuth: tls.RequireAndVerifyClientCert,
-				ClientCAs:  createCertPool(info.ClientCAFile),
-			},
-			Addr: fmt.Sprintf(":%d", info.ClientPort),
+			TLSConfig: tlsConf,
+			Addr:      fmt.Sprintf(":%d", info.ClientPort),
 		}
 		}
 		fmt.Printf("etcd [%s] listen on https port %v\n", info.Hostname, info.ClientPort)
 		fmt.Printf("etcd [%s] listen on https port %v\n", info.Hostname, info.ClientPort)
 		fatal(server.ListenAndServeTLS(info.ClientCertFile, info.ClientKeyFile))
 		fatal(server.ListenAndServeTLS(info.ClientCertFile, info.ClientKeyFile))
@@ -456,20 +416,28 @@ func startClientTransport(info Info, st int) {
 //--------------------------------------
 //--------------------------------------
 // Config
 // Config
 //--------------------------------------
 //--------------------------------------
-
-// Get the security type
-func securityType(source int) int {
-
+func tlsConf(source int) ([]*tls.Config, bool) {
 	var keyFile, certFile, CAFile string
 	var keyFile, certFile, CAFile string
+	var tlsCert tls.Certificate
+	var isAuth bool
+	var err error
 
 
 	switch source {
 	switch source {
 
 
-	case SERVER:
+	case RaftServer:
 		keyFile = info.ServerKeyFile
 		keyFile = info.ServerKeyFile
 		certFile = info.ServerCertFile
 		certFile = info.ServerCertFile
 		CAFile = info.ServerCAFile
 		CAFile = info.ServerCAFile
 
 
-	case CLIENT:
+		if keyFile != "" && certFile != "" {
+			tlsCert, err = tls.LoadX509KeyPair(certFile, keyFile)
+			if err == nil {
+				fatal(err)
+			}
+			isAuth = true
+		}
+
+	case EtcdServer:
 		keyFile = info.ClientKeyFile
 		keyFile = info.ClientKeyFile
 		certFile = info.ClientCertFile
 		certFile = info.ClientCertFile
 		CAFile = info.ClientCAFile
 		CAFile = info.ClientCAFile
@@ -478,25 +446,28 @@ func securityType(source int) int {
 	// If the user do not specify key file, cert file and
 	// If the user do not specify key file, cert file and
 	// CA file, the type will be HTTP
 	// CA file, the type will be HTTP
 	if keyFile == "" && certFile == "" && CAFile == "" {
 	if keyFile == "" && certFile == "" && CAFile == "" {
-
-		return HTTP
-
+		return []*tls.Config{nil, nil}, true
 	}
 	}
 
 
 	if keyFile != "" && certFile != "" {
 	if keyFile != "" && certFile != "" {
-		if CAFile != "" {
-			// If the user specify all the three file, the type
-			// will be HTTPS with client cert auth
-			return HTTPSANDVERIFY
+		serverConf := &tls.Config{}
+		serverConf.ClientAuth, serverConf.ClientCAs = newCertPool(CAFile)
+
+		if isAuth {
+			raftTransConf := &tls.Config{
+				Certificates:       []tls.Certificate{tlsCert},
+				InsecureSkipVerify: true,
+			}
+			return []*tls.Config{serverConf, raftTransConf}, true
 		}
 		}
-		// If the user specify key file and cert file but not
-		// CA file, the type will be HTTPS without client cert
-		// auth
-		return HTTPS
+
+		return []*tls.Config{serverConf, nil}, true
+
 	}
 	}
 
 
 	// bad specification
 	// bad specification
-	return -1
+	return nil, false
+
 }
 }
 
 
 func parseInfo(path string) *Info {
 func parseInfo(path string) *Info {
@@ -569,7 +540,10 @@ func getInfo(path string) *Info {
 }
 }
 
 
 // Create client auth certpool
 // Create client auth certpool
-func createCertPool(CAFile string) *x509.CertPool {
+func newCertPool(CAFile string) (tls.ClientAuthType, *x509.CertPool) {
+	if CAFile == "" {
+		return tls.NoClientCert, nil
+	}
 	pemByte, _ := ioutil.ReadFile(CAFile)
 	pemByte, _ := ioutil.ReadFile(CAFile)
 
 
 	block, pemByte := pem.Decode(pemByte)
 	block, pemByte := pem.Decode(pemByte)
@@ -584,7 +558,7 @@ func createCertPool(CAFile string) *x509.CertPool {
 
 
 	certPool.AddCert(cert)
 	certPool.AddCert(cert)
 
 
-	return certPool
+	return tls.RequireAndVerifyClientCert, certPool
 }
 }
 
 
 // Send join requests to the leader.
 // Send join requests to the leader.