Procházet zdrojové kódy

Merge pull request #6084 from heyitsanthony/srv-servername

etcdctl: set TLS servername on discovery
Anthony Romano před 9 roky
rodič
revize
eb36d0dbba

+ 2 - 0
Documentation/op-guide/clustering.md

@@ -357,6 +357,8 @@ To help clients discover the etcd cluster, the following DNS SRV records are loo
 
 If `_etcd-client-ssl._tcp.example.com` is found, clients will attempt to communicate with the etcd cluster over SSL/TLS.
 
+If etcd is using TLS without a custom certificate authority, the discovery domain (e.g., example.com) must match the SRV record domain (e.g., infra1.example.com). This is to mitigate attacks that forge SRV records to point to a different domain; the domain would have a valid certificate under PKI but be controlled by an unknown third party.
+
 #### Create DNS SRV records
 
 ```

+ 3 - 0
embed/config.go

@@ -281,6 +281,9 @@ func (cfg *Config) PeerURLsMapAndToken(which string) (urlsmap types.URLsMap, tok
 		if err != nil {
 			return nil, "", err
 		}
+		if strings.Contains(clusterStr, "https://") && cfg.PeerTLSInfo.CAFile == "" {
+			cfg.PeerTLSInfo.ServerName = cfg.DNSCluster
+		}
 		urlsmap, err = types.NewURLsMap(clusterStr)
 		// only etcd member must belong to the discovered cluster.
 		// proxy does not need to belong to the discovered cluster.

+ 32 - 11
etcdctl/ctlv2/command/util.go

@@ -85,13 +85,7 @@ func getPeersFlagValue(c *cli.Context) []string {
 }
 
 func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) {
-	domainstr := c.GlobalString("discovery-srv")
-
-	// Use an environment variable if nothing was supplied on the
-	// command line
-	if domainstr == "" {
-		domainstr = os.Getenv("ETCDCTL_DISCOVERY_SRV")
-	}
+	domainstr, insecure := getDiscoveryDomain(c)
 
 	// If we still don't have domain discovery, return nothing
 	if domainstr == "" {
@@ -103,8 +97,30 @@ func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) {
 	if err != nil {
 		return nil, err
 	}
+	if insecure {
+		return eps, err
+	}
+	// strip insecure connections
+	ret := []string{}
+	for _, ep := range eps {
+		if strings.HasPrefix("http://", ep) {
+			fmt.Fprintf(os.Stderr, "ignoring discovered insecure endpoint %q\n", ep)
+			continue
+		}
+		ret = append(ret, ep)
+	}
+	return ret, err
+}
 
-	return eps, err
+func getDiscoveryDomain(c *cli.Context) (domainstr string, insecure bool) {
+	domainstr = c.GlobalString("discovery-srv")
+	// Use an environment variable if nothing was supplied on the
+	// command line
+	if domainstr == "" {
+		domainstr = os.Getenv("ETCDCTL_DISCOVERY_SRV")
+	}
+	insecure = c.GlobalBool("insecure-discovery") || (os.Getenv("ETCDCTL_INSECURE_DISCOVERY") != "")
+	return domainstr, insecure
 }
 
 func getEndpoints(c *cli.Context) ([]string, error) {
@@ -151,10 +167,15 @@ func getTransport(c *cli.Context) (*http.Transport, error) {
 		keyfile = os.Getenv("ETCDCTL_KEY_FILE")
 	}
 
+	discoveryDomain, insecure := getDiscoveryDomain(c)
+	if insecure {
+		discoveryDomain = ""
+	}
 	tls := transport.TLSInfo{
-		CAFile:   cafile,
-		CertFile: certfile,
-		KeyFile:  keyfile,
+		CAFile:     cafile,
+		CertFile:   certfile,
+		KeyFile:    keyfile,
+		ServerName: discoveryDomain,
 	}
 
 	dialTimeout := defaultDialTimeout

+ 1 - 0
etcdctl/ctlv2/ctl.go

@@ -39,6 +39,7 @@ func Start() {
 		cli.BoolFlag{Name: "no-sync", Usage: "don't synchronize cluster information before sending request"},
 		cli.StringFlag{Name: "output, o", Value: "simple", Usage: "output response in the given format (`simple`, `extended` or `json`)"},
 		cli.StringFlag{Name: "discovery-srv, D", Usage: "domain name to query for SRV records describing cluster endpoints"},
+		cli.BoolFlag{Name: "insecure-discovery", Usage: "accept insecure SRV records describing cluster endpoints"},
 		cli.StringFlag{Name: "peers, C", Value: "", Usage: "DEPRECATED - \"--endpoints\" should be used instead"},
 		cli.StringFlag{Name: "endpoint", Value: "", Usage: "DEPRECATED - \"--endpoints\" should be used instead"},
 		cli.StringFlag{Name: "endpoints", Value: "", Usage: "a comma-delimited list of machine addresses in the cluster (default: \"http://127.0.0.1:2379,http://127.0.0.1:4001\")"},

+ 26 - 4
etcdmain/gateway.go

@@ -21,15 +21,18 @@ import (
 	"time"
 
 	"github.com/coreos/etcd/client"
+	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/proxy/tcpproxy"
 	"github.com/spf13/cobra"
 )
 
 var (
-	gatewayListenAddr string
-	gatewayEndpoints  []string
-	gatewayDNSCluster string
-	getewayRetryDelay time.Duration
+	gatewayListenAddr        string
+	gatewayEndpoints         []string
+	gatewayDNSCluster        string
+	gatewayInsecureDiscovery bool
+	getewayRetryDelay        time.Duration
+	gatewayCA                string
 )
 
 var (
@@ -64,6 +67,8 @@ func newGatewayStartCommand() *cobra.Command {
 
 	cmd.Flags().StringVar(&gatewayListenAddr, "listen-addr", "127.0.0.1:23790", "listen address")
 	cmd.Flags().StringVar(&gatewayDNSCluster, "discovery-srv", "", "DNS domain used to bootstrap initial cluster")
+	cmd.Flags().BoolVar(&gatewayInsecureDiscovery, "insecure-discovery", false, "accept insecure SRV records")
+	cmd.Flags().StringVar(&gatewayCA, "trusted-ca-file", "", "path to the client server TLS CA file.")
 
 	cmd.Flags().StringSliceVar(&gatewayEndpoints, "endpoints", []string{"127.0.0.1:2379"}, "comma separated etcd cluster endpoints")
 
@@ -81,6 +86,23 @@ func startGateway(cmd *cobra.Command, args []string) {
 			os.Exit(1)
 		}
 		plog.Infof("discovered the cluster %s from %s", eps, gatewayDNSCluster)
+		// confirm TLS connections are good
+		if !gatewayInsecureDiscovery {
+			tlsInfo := transport.TLSInfo{
+				TrustedCAFile: gatewayCA,
+				ServerName:    gatewayDNSCluster,
+			}
+			plog.Infof("validating discovered endpoints %v", eps)
+			endpoints, err = transport.ValidateSecureEndpoints(tlsInfo, eps)
+			if err != nil {
+				plog.Warningf("%v", err)
+			}
+			plog.Infof("using discovered endpoints %v", endpoints)
+		}
+	}
+
+	if len(endpoints) == 0 {
+		plog.Fatalf("no endpoints found")
 	}
 
 	l, err := net.Listen("tcp", gatewayListenAddr)

+ 7 - 1
pkg/transport/listener.go

@@ -67,6 +67,9 @@ type TLSInfo struct {
 	TrustedCAFile  string
 	ClientCertAuth bool
 
+	// ServerName ensures the cert matches the given host in case of discovery / virtual hosting
+	ServerName string
+
 	selfCert bool
 
 	// parseFunc exists to simplify testing. Typically, parseFunc
@@ -167,6 +170,7 @@ func (info TLSInfo) baseConfig() (*tls.Config, error) {
 	cfg := &tls.Config{
 		Certificates: []tls.Certificate{*tlsCert},
 		MinVersion:   tls.VersionTLS12,
+		ServerName:   info.ServerName,
 	}
 	return cfg, nil
 }
@@ -218,7 +222,7 @@ func (info TLSInfo) ClientConfig() (*tls.Config, error) {
 			return nil, err
 		}
 	} else {
-		cfg = &tls.Config{}
+		cfg = &tls.Config{ServerName: info.ServerName}
 	}
 
 	CAFiles := info.cafiles()
@@ -227,6 +231,8 @@ func (info TLSInfo) ClientConfig() (*tls.Config, error) {
 		if err != nil {
 			return nil, err
 		}
+		// if given a CA, trust any host with a cert signed by the CA
+		cfg.ServerName = ""
 	}
 
 	if info.selfCert {

+ 49 - 0
pkg/transport/tls.go

@@ -0,0 +1,49 @@
+// Copyright 2016 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+	"fmt"
+	"strings"
+	"time"
+)
+
+// ValidateSecureEndpoints scans the given endpoints against tls info, returning only those
+// endpoints that could be validated as secure.
+func ValidateSecureEndpoints(tlsInfo TLSInfo, eps []string) ([]string, error) {
+	t, err := NewTransport(tlsInfo, 5*time.Second)
+	if err != nil {
+		return nil, err
+	}
+	var errs []string
+	var endpoints []string
+	for _, ep := range eps {
+		if !strings.HasPrefix(ep, "https://") {
+			errs = append(errs, fmt.Sprintf("%q is insecure", ep))
+			continue
+		}
+		conn, cerr := t.Dial("tcp", ep[len("https://"):])
+		if cerr != nil {
+			errs = append(errs, fmt.Sprintf("%q failed to dial (%v)", ep, cerr))
+			continue
+		}
+		conn.Close()
+		endpoints = append(endpoints, ep)
+	}
+	if len(errs) != 0 {
+		err = fmt.Errorf("%s", strings.Join(errs, ","))
+	}
+	return endpoints, err
+}