Browse Source

simplify createTrans

Xiang Li 12 years ago
parent
commit
06fab60dd6
1 changed files with 64 additions and 90 deletions
  1. 64 90
      etcd.go

+ 64 - 90
etcd.go

@@ -89,14 +89,8 @@ func init() {
 
 // CONSTANTS
 const (
-	HTTP = iota
-	HTTPS
-	HTTPSANDVERIFY
-)
-
-const (
-	SERVER = iota
-	CLIENT
+	RaftServer = iota
+	EtcdServer
 )
 
 const (
@@ -200,19 +194,20 @@ func main() {
 
 	info = getInfo(dirPath)
 
-	// security type
-	st := securityType(SERVER)
-
-	clientSt := securityType(CLIENT)
+	raftTlsConfs, ok := tlsConf(RaftServer)
+	if !ok {
+		fatal("Please specify cert and key file or cert and key file and CAFile or none of the three")
+	}
 
-	if st == -1 || clientSt == -1 {
+	etcdTlsConfs, ok := tlsConf(EtcdServer)
+	if !ok {
 		fatal("Please specify cert and key file or cert and key file and CAFile or none of the three")
 	}
 
 	// Create etcd key-value store
 	etcdStore = store.CreateStore(maxSize)
 
-	startRaft(st)
+	startRaft(raftTlsConfs)
 
 	if argInfo.WebPort != -1 {
 		// start web
@@ -221,18 +216,18 @@ func main() {
 		go web.Start(raftServer, argInfo.WebPort)
 	}
 
-	startClientTransport(*info, clientSt)
+	startEtcdTransport(*info, etcdTlsConfs[0])
 
 }
 
 // Start the raft server
-func startRaft(securityType int) {
+func startRaft(tlsConfs []*tls.Config) {
 	var err error
 
 	raftName := fmt.Sprintf("%s:%d", info.Hostname, info.RaftPort)
 
 	// Create transporter for raft
-	raftTransporter = createTransporter(securityType)
+	raftTransporter = newTransporter(tlsConfs[1])
 
 	// Create raft server
 	raftServer, err = raft.NewServer(raftName, dirPath, raftTransporter, etcdStore, nil)
@@ -328,44 +323,30 @@ func startRaft(securityType int) {
 	}
 
 	// start to response to raft requests
-	go startRaftTransport(*info, securityType)
+	go startRaftTransport(*info, tlsConfs[0])
 
 }
 
 // Create transporter using by raft server
 // Create http or https transporter based on
 // whether the user give the server cert and key
-func createTransporter(st int) transporter {
+func newTransporter(tlsConf *tls.Config) transporter {
 	t := transporter{}
 
-	switch st {
-	case HTTP:
+	if tlsConf == nil {
 		t.scheme = "http://"
 
-		tr := &http.Transport{
-			Dial: dialTimeout,
-		}
-
 		t.client = &http.Client{
-			Transport: tr,
+			Transport: &http.Transport{
+				Dial: dialTimeout,
+			},
 		}
 
-	case HTTPS:
-		fallthrough
-	case HTTPSANDVERIFY:
+	} else {
 		t.scheme = "https://"
 
-		tlsCert, err := tls.LoadX509KeyPair(argInfo.ServerCertFile, argInfo.ServerKeyFile)
-
-		if err != nil {
-			fatal(err)
-		}
-
 		tr := &http.Transport{
-			TLSClientConfig: &tls.Config{
-				Certificates:       []tls.Certificate{tlsCert},
-				InsecureSkipVerify: true,
-			},
+			TLSClientConfig:    tlsConf,
 			Dial:               dialTimeout,
 			DisableCompression: true,
 		}
@@ -382,7 +363,7 @@ func dialTimeout(network, addr string) (net.Conn, error) {
 }
 
 // Start to listen and response raft command
-func startRaftTransport(info Info, st int) {
+func startRaftTransport(info Info, tlsConf *tls.Config) {
 
 	// internal commands
 	http.HandleFunc("/join", JoinHttpHandler)
@@ -393,24 +374,14 @@ func startRaftTransport(info Info, st int) {
 	http.HandleFunc("/snapshotRecovery", SnapshotRecoveryHttpHandler)
 	http.HandleFunc("/client", ClientHttpHandler)
 
-	switch st {
-
-	case HTTP:
+	if tlsConf == nil {
 		fmt.Printf("raft server [%s] listen on http port %v\n", info.Hostname, info.RaftPort)
 		fatal(http.ListenAndServe(fmt.Sprintf(":%d", info.RaftPort), nil))
 
-	case HTTPS:
-		fmt.Printf("raft server [%s] listen on https port %v\n", info.Hostname, info.RaftPort)
-		fatal(http.ListenAndServeTLS(fmt.Sprintf(":%d", info.RaftPort), info.ServerCertFile, argInfo.ServerKeyFile, nil))
-
-	case HTTPSANDVERIFY:
-
+	} else {
 		server := &http.Server{
-			TLSConfig: &tls.Config{
-				ClientAuth: tls.RequireAndVerifyClientCert,
-				ClientCAs:  createCertPool(info.ServerCAFile),
-			},
-			Addr: fmt.Sprintf(":%d", info.RaftPort),
+			TLSConfig: tlsConf,
+			Addr:      fmt.Sprintf(":%d", info.RaftPort),
 		}
 		fmt.Printf("raft server [%s] listen on https port %v\n", info.Hostname, info.RaftPort)
 		fatal(server.ListenAndServeTLS(info.ServerCertFile, argInfo.ServerKeyFile))
@@ -419,7 +390,7 @@ func startRaftTransport(info Info, st int) {
 }
 
 // Start to listen and response client command
-func startClientTransport(info Info, st int) {
+func startEtcdTransport(info Info, tlsConf *tls.Config) {
 	// external commands
 	http.HandleFunc("/"+version+"/keys/", Multiplexer)
 	http.HandleFunc("/"+version+"/watch/", WatchHttpHandler)
@@ -429,24 +400,13 @@ func startClientTransport(info Info, st int) {
 	http.HandleFunc("/stats", StatsHttpHandler)
 	http.HandleFunc("/test/", TestHttpHandler)
 
-	switch st {
-
-	case HTTP:
+	if tlsConf == nil {
 		fmt.Printf("etcd [%s] listen on http port %v\n", info.Hostname, info.ClientPort)
 		fatal(http.ListenAndServe(fmt.Sprintf(":%d", info.ClientPort), nil))
-
-	case HTTPS:
-		fmt.Printf("etcd [%s] listen on https port %v\n", info.Hostname, info.ClientPort)
-		http.ListenAndServeTLS(fmt.Sprintf(":%d", info.ClientPort), info.ClientCertFile, info.ClientKeyFile, nil)
-
-	case HTTPSANDVERIFY:
-
+	} else {
 		server := &http.Server{
-			TLSConfig: &tls.Config{
-				ClientAuth: tls.RequireAndVerifyClientCert,
-				ClientCAs:  createCertPool(info.ClientCAFile),
-			},
-			Addr: fmt.Sprintf(":%d", info.ClientPort),
+			TLSConfig: tlsConf,
+			Addr:      fmt.Sprintf(":%d", info.ClientPort),
 		}
 		fmt.Printf("etcd [%s] listen on https port %v\n", info.Hostname, info.ClientPort)
 		fatal(server.ListenAndServeTLS(info.ClientCertFile, info.ClientKeyFile))
@@ -456,20 +416,28 @@ func startClientTransport(info Info, st int) {
 //--------------------------------------
 // Config
 //--------------------------------------
-
-// Get the security type
-func securityType(source int) int {
-
+func tlsConf(source int) ([]*tls.Config, bool) {
 	var keyFile, certFile, CAFile string
+	var tlsCert tls.Certificate
+	var isAuth bool
+	var err error
 
 	switch source {
 
-	case SERVER:
+	case RaftServer:
 		keyFile = info.ServerKeyFile
 		certFile = info.ServerCertFile
 		CAFile = info.ServerCAFile
 
-	case CLIENT:
+		if keyFile != "" && certFile != "" {
+			tlsCert, err = tls.LoadX509KeyPair(certFile, keyFile)
+			if err == nil {
+				fatal(err)
+			}
+			isAuth = true
+		}
+
+	case EtcdServer:
 		keyFile = info.ClientKeyFile
 		certFile = info.ClientCertFile
 		CAFile = info.ClientCAFile
@@ -478,25 +446,28 @@ func securityType(source int) int {
 	// If the user do not specify key file, cert file and
 	// CA file, the type will be HTTP
 	if keyFile == "" && certFile == "" && CAFile == "" {
-
-		return HTTP
-
+		return []*tls.Config{nil, nil}, true
 	}
 
 	if keyFile != "" && certFile != "" {
-		if CAFile != "" {
-			// If the user specify all the three file, the type
-			// will be HTTPS with client cert auth
-			return HTTPSANDVERIFY
+		serverConf := &tls.Config{}
+		serverConf.ClientAuth, serverConf.ClientCAs = newCertPool(CAFile)
+
+		if isAuth {
+			raftTransConf := &tls.Config{
+				Certificates:       []tls.Certificate{tlsCert},
+				InsecureSkipVerify: true,
+			}
+			return []*tls.Config{serverConf, raftTransConf}, true
 		}
-		// If the user specify key file and cert file but not
-		// CA file, the type will be HTTPS without client cert
-		// auth
-		return HTTPS
+
+		return []*tls.Config{serverConf, nil}, true
+
 	}
 
 	// bad specification
-	return -1
+	return nil, false
+
 }
 
 func parseInfo(path string) *Info {
@@ -569,7 +540,10 @@ func getInfo(path string) *Info {
 }
 
 // Create client auth certpool
-func createCertPool(CAFile string) *x509.CertPool {
+func newCertPool(CAFile string) (tls.ClientAuthType, *x509.CertPool) {
+	if CAFile == "" {
+		return tls.NoClientCert, nil
+	}
 	pemByte, _ := ioutil.ReadFile(CAFile)
 
 	block, pemByte := pem.Decode(pemByte)
@@ -584,7 +558,7 @@ func createCertPool(CAFile string) *x509.CertPool {
 
 	certPool.AddCert(cert)
 
-	return certPool
+	return tls.RequireAndVerifyClientCert, certPool
 }
 
 // Send join requests to the leader.