Преглед изворни кода

chore(etcd): cleanup TLS configuration

the TLS configuration was getting rather complex with slices of
tls.Config's being passed around and pointer nil checking for schema
types.

Introduce a new TLSInfo type that is in charge of holding the various
TLS key/cert/CA filenames the user passes in.

Then create a new TlsConfig type that has a Scheme and the Client and
Server tls.Config objects inside of it. This is used by the two
transport start methods which had been using a slice of tls.Config
objects and guessing at the scheme based on the non-nil value of the
Config.
Brandon Philips пре 12 година
родитељ
комит
8c09f98882
1 измењених фајлова са 68 додато и 104 уклоњено
  1. 68 104
      etcd.go

+ 68 - 104
etcd.go

@@ -63,13 +63,13 @@ func init() {
 	flag.StringVar(&argInfo.RaftURL, "s", "127.0.0.1:7001", "the hostname:port for raft server communication")
 	flag.StringVar(&argInfo.WebURL, "w", "", "the hostname:port of web interface")
 
-	flag.StringVar(&argInfo.ServerCAFile, "serverCAFile", "", "the path of the CAFile")
-	flag.StringVar(&argInfo.ServerCertFile, "serverCert", "", "the cert file of the server")
-	flag.StringVar(&argInfo.ServerKeyFile, "serverKey", "", "the key file of the server")
+	flag.StringVar(&argInfo.RaftTLS.CAFile, "serverCAFile", "", "the path of the CAFile")
+	flag.StringVar(&argInfo.RaftTLS.CertFile, "serverCert", "", "the cert file of the server")
+	flag.StringVar(&argInfo.RaftTLS.KeyFile, "serverKey", "", "the key file of the server")
 
-	flag.StringVar(&argInfo.ClientCAFile, "clientCAFile", "", "the path of the client CAFile")
-	flag.StringVar(&argInfo.ClientCertFile, "clientCert", "", "the cert file of the client")
-	flag.StringVar(&argInfo.ClientKeyFile, "clientKey", "", "the key file of the client")
+	flag.StringVar(&argInfo.EtcdTLS.CAFile, "clientCAFile", "", "the path of the client CAFile")
+	flag.StringVar(&argInfo.EtcdTLS.CertFile, "clientCert", "", "the cert file of the client")
+	flag.StringVar(&argInfo.EtcdTLS.KeyFile, "clientKey", "", "the key file of the client")
 
 	flag.StringVar(&dirPath, "d", ".", "the directory to store log and snapshot")
 
@@ -86,12 +86,6 @@ func init() {
 	flag.StringVar(&cpuprofile, "cpuprofile", "", "write cpu profile to file")
 }
 
-// CONSTANTS
-const (
-	RaftServer = iota
-	EtcdServer
-)
-
 const (
 	ELECTIONTIMEOUT  = 200 * time.Millisecond
 	HEARTBEATTIMEOUT = 50 * time.Millisecond
@@ -109,6 +103,12 @@ const (
 //
 //------------------------------------------------------------------------------
 
+type TLSInfo struct {
+	CertFile string `json:"serverCertFile"`
+	KeyFile  string `json:"serverKeyFile"`
+	CAFile   string `json:"serverCAFile"`
+}
+
 type Info struct {
 	Name string `json:"name"`
 
@@ -116,13 +116,8 @@ type Info struct {
 	EtcdURL string `json:"etcdURL"`
 	WebURL  string `json:"webURL"`
 
-	ServerCertFile string `json:"serverCertFile"`
-	ServerKeyFile  string `json:"serverKeyFile"`
-	ServerCAFile   string `json:"serverCAFile"`
-
-	ClientCertFile string `json:"clientCertFile"`
-	ClientKeyFile  string `json:"clientKeyFile"`
-	ClientCAFile   string `json:"clientCAFile"`
+	RaftTLS TLSInfo `json:"raftTLS"`
+	EtcdTLS TLSInfo `json:"raftTLS"`
 }
 
 //------------------------------------------------------------------------------
@@ -208,35 +203,23 @@ func main() {
 		cluster = strings.Split(string(b), ",")
 	}
 
-	raftTlsConfs, ok := tlsConf(RaftServer)
+	raftTLSConfig, ok := tlsConfigFromInfo(argInfo.RaftTLS)
 	if !ok {
 		fatal("Please specify cert and key file or cert and key file and CAFile or none of the three")
 	}
 
-	raftDefaultScheme := "http"
-	if raftTlsConfs[0] != nil {
-		raftDefaultScheme = "https"
-	}
-
-	etcdTlsConfs, ok := tlsConf(EtcdServer)
+	etcdTLSConfig, ok := tlsConfigFromInfo(argInfo.EtcdTLS)
 	if !ok {
 		fatal("Please specify cert and key file or cert and key file and CAFile or none of the three")
 	}
 
-	etcdDefaultScheme := "http"
-	if etcdTlsConfs[0] != nil {
-		raftDefaultScheme = "https"
-	}
-
-	// Otherwise ask user for info and write it to file.
 	argInfo.Name = strings.TrimSpace(argInfo.Name)
-
 	if argInfo.Name == "" {
 		fatal("ERROR: server name required. e.g. '-n=server_name'")
 	}
 
-	argInfo.RaftURL = sanitizeURL(argInfo.RaftURL, raftTlsConfig.Scheme)
-	argInfo.EtcdURL = sanitizeURL(argInfo.EtcdURL, etcdTlsConfig.Scheme)
+	argInfo.RaftURL = sanitizeURL(argInfo.RaftURL, raftTLSConfig.Scheme)
+	argInfo.EtcdURL = sanitizeURL(argInfo.EtcdURL, etcdTLSConfig.Scheme)
 	argInfo.WebURL = sanitizeURL(argInfo.WebURL, "http")
 
 	// Setup commands.
@@ -252,27 +235,27 @@ func main() {
 	// Create etcd key-value store
 	etcdStore = store.CreateStore(maxSize)
 
-	startRaft(raftTlsConfs)
+	startRaft(raftTLSConfig)
 
 	if argInfo.WebURL != "" {
 		// start web
-		argInfo.WebURL = checkURL(argInfo.WebURL, "http")
+		argInfo.WebURL = sanitizeURL(argInfo.WebURL, "http")
 		go webHelper()
 		go web.Start(raftServer, argInfo.WebURL)
 	}
 
-	startEtcdTransport(*info, etcdTlsConfs[0])
+	startEtcdTransport(*info, etcdTLSConfig.Scheme, etcdTLSConfig.Server)
 
 }
 
 // Start the raft server
-func startRaft(tlsConfs []*tls.Config) {
+func startRaft(tlsConfig TLSConfig) {
 	var err error
 
 	raftName := info.Name
 
 	// Create transporter for raft
-	raftTransporter = newTransporter(tlsConfs[1])
+	raftTransporter = newTransporter(tlsConfig.Scheme, tlsConfig.Client)
 
 	// Create raft server
 	raftServer, err = raft.NewServer(raftName, dirPath, raftTransporter, etcdStore, nil)
@@ -367,37 +350,29 @@ func startRaft(tlsConfs []*tls.Config) {
 	}
 
 	// start to response to raft requests
-	go startRaftTransport(*info, tlsConfs[0])
+	go startRaftTransport(*info, tlsConfig.Scheme, tlsConfig.Server)
 
 }
 
 // Create transporter using by raft server
 // Create http or https transporter based on
 // whether the user give the server cert and key
-func newTransporter(tlsConf *tls.Config) transporter {
+func newTransporter(scheme string, tlsConf tls.Config) transporter {
 	t := transporter{}
 
-	if tlsConf == nil {
-		t.scheme = "http://"
-
-		t.client = &http.Client{
-			Transport: &http.Transport{
-				Dial: dialTimeout,
-			},
-		}
-
-	} else {
-		t.scheme = "https://"
+	t.scheme = scheme
 
-		tr := &http.Transport{
-			TLSClientConfig:    tlsConf,
-			Dial:               dialTimeout,
-			DisableCompression: true,
-		}
+	tr := &http.Transport{
+		Dial:               dialTimeout,
+	}
 
-		t.client = &http.Client{Transport: tr}
+	if scheme == "https" {
+		tr.TLSClientConfig = &tlsConf
+		tr.DisableCompression = true
 	}
 
+	t.client = &http.Client{Transport: tr}
+
 	return t
 }
 
@@ -407,7 +382,7 @@ func dialTimeout(network, addr string) (net.Conn, error) {
 }
 
 // Start to listen and response raft command
-func startRaftTransport(info Info, tlsConf *tls.Config) {
+func startRaftTransport(info Info, scheme string, tlsConf tls.Config) {
 	u, _ := url.Parse(info.RaftURL)
 	fmt.Printf("raft server [%s] listening on %s\n", info.Name, u)
 
@@ -415,7 +390,7 @@ func startRaftTransport(info Info, tlsConf *tls.Config) {
 
 	server := &http.Server{
 		Handler:   raftMux,
-		TLSConfig: tlsConf,
+		TLSConfig: &tlsConf,
 		Addr:      u.Host,
 	}
 
@@ -429,16 +404,16 @@ func startRaftTransport(info Info, tlsConf *tls.Config) {
 	raftMux.HandleFunc("/snapshotRecovery", SnapshotRecoveryHttpHandler)
 	raftMux.HandleFunc("/etcdURL", EtcdURLHttpHandler)
 
-	if tlsConf == nil {
+	if scheme == "http" {
 		fatal(server.ListenAndServe())
 	} else {
-		fatal(server.ListenAndServeTLS(info.ServerCertFile, argInfo.ServerKeyFile))
+		fatal(server.ListenAndServeTLS(info.RaftTLS.CertFile, info.RaftTLS.KeyFile))
 	}
 
 }
 
 // Start to listen and response client command
-func startEtcdTransport(info Info, tlsConf *tls.Config) {
+func startEtcdTransport(info Info, scheme string, tlsConf tls.Config) {
 	u, _ := url.Parse(info.EtcdURL)
 	fmt.Printf("etcd server [%s] listening on %s\n", info.Name, u)
 
@@ -446,7 +421,7 @@ func startEtcdTransport(info Info, tlsConf *tls.Config) {
 
 	server := &http.Server{
 		Handler:   etcdMux,
-		TLSConfig: tlsConf,
+		TLSConfig: &tlsConf,
 		Addr:      u.Host,
 	}
 
@@ -459,68 +434,57 @@ func startEtcdTransport(info Info, tlsConf *tls.Config) {
 	etcdMux.HandleFunc("/stats", StatsHttpHandler)
 	etcdMux.HandleFunc("/test/", TestHttpHandler)
 
-	if tlsConf == nil {
+	if scheme == "http" {
 		fatal(server.ListenAndServe())
 	} else {
-		fatal(server.ListenAndServeTLS(info.ClientCertFile, info.ClientKeyFile))
+		fatal(server.ListenAndServeTLS(info.EtcdTLS.CertFile, info.EtcdTLS.KeyFile))
 	}
 }
 
 //--------------------------------------
 // Config
 //--------------------------------------
-func tlsConf(source int) ([]*tls.Config, bool) {
+
+type TLSConfig struct {
+	Scheme string
+	Server tls.Config
+	Client tls.Config
+}
+
+func tlsConfigFromInfo(info TLSInfo) (t TLSConfig, ok bool) {
 	var keyFile, certFile, CAFile string
 	var tlsCert tls.Certificate
-	var isAuth bool
 	var err error
 
-	switch source {
-
-	case RaftServer:
-		keyFile = info.ServerKeyFile
-		certFile = info.ServerCertFile
-		CAFile = info.ServerCAFile
-
-		if keyFile != "" && certFile != "" {
-			tlsCert, err = tls.LoadX509KeyPair(certFile, keyFile)
-			if err == nil {
-				fatal(err)
-			}
-			isAuth = true
-		}
+	t.Scheme = "http"
 
-	case EtcdServer:
-		keyFile = info.ClientKeyFile
-		certFile = info.ClientCertFile
-		CAFile = info.ClientCAFile
-	}
+	keyFile = info.KeyFile
+	certFile = info.CertFile
+	CAFile = info.CAFile
 
 	// If the user do not specify key file, cert file and
 	// CA file, the type will be HTTP
 	if keyFile == "" && certFile == "" && CAFile == "" {
-		return []*tls.Config{nil, nil}, true
+		return t, true
 	}
 
-	if keyFile != "" && certFile != "" {
-		serverConf := &tls.Config{}
-		serverConf.ClientAuth, serverConf.ClientCAs = newCertPool(CAFile)
-
-		if isAuth {
-			raftTransConf := &tls.Config{
-				Certificates:       []tls.Certificate{tlsCert},
-				InsecureSkipVerify: true,
-			}
-			return []*tls.Config{serverConf, raftTransConf}, true
-		}
-
-		return []*tls.Config{serverConf, nil}, true
+	// both the key and cert must be present
+	if keyFile == "" || certFile == "" {
+		return t, false
+	}
 
+	tlsCert, err = tls.LoadX509KeyPair(certFile, keyFile)
+	if err == nil {
+		fatal(err)
 	}
 
-	// bad specification
-	return nil, false
+	t.Scheme = "https"
+	t.Server.Certificates = []tls.Certificate{tlsCert}
+	t.Server.InsecureSkipVerify = true
+
+	t.Client.ClientAuth, t.Client.ClientCAs = newCertPool(CAFile)
 
+	return t, true
 }
 
 func parseInfo(path string) *Info {