Pārlūkot izejas kodu

*: add client support for discovery-srv-name

Signed-off-by: Sam Batschelet <sbatsche@redhat.com>
Sam Batschelet 7 gadi atpakaļ
vecāks
revīzija
fa35126ef8

+ 3 - 3
client/discover.go

@@ -21,7 +21,7 @@ import (
 // Discoverer is an interface that wraps the Discover method.
 type Discoverer interface {
 	// Discover looks up the etcd servers for the domain.
-	Discover(domain string) ([]string, error)
+	Discover(domain string, serviceName string) ([]string, error)
 }
 
 type srvDiscover struct{}
@@ -31,8 +31,8 @@ func NewSRVDiscover() Discoverer {
 	return &srvDiscover{}
 }
 
-func (d *srvDiscover) Discover(domain string) ([]string, error) {
-	srvs, err := srv.GetClient("etcd-client", domain)
+func (d *srvDiscover) Discover(domain string, serviceName string) ([]string, error) {
+	srvs, err := srv.GetClient("etcd-client", domain, serviceName)
 	if err != nil {
 		return nil, err
 	}

+ 9 - 5
etcdctl/ctlv2/command/util.go

@@ -86,7 +86,7 @@ func getPeersFlagValue(c *cli.Context) []string {
 }
 
 func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) {
-	domainstr, insecure := getDiscoveryDomain(c)
+	domainstr, insecure, serviceName := getDiscoveryDomain(c)
 
 	// If we still don't have domain discovery, return nothing
 	if domainstr == "" {
@@ -94,7 +94,7 @@ func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) {
 	}
 
 	discoverer := client.NewSRVDiscover()
-	eps, err := discoverer.Discover(domainstr)
+	eps, err := discoverer.Discover(domainstr, serviceName)
 	if err != nil {
 		return nil, err
 	}
@@ -113,7 +113,7 @@ func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) {
 	return ret, err
 }
 
-func getDiscoveryDomain(c *cli.Context) (domainstr string, insecure bool) {
+func getDiscoveryDomain(c *cli.Context) (domainstr string, insecure bool, serviceName string) {
 	domainstr = c.GlobalString("discovery-srv")
 	// Use an environment variable if nothing was supplied on the
 	// command line
@@ -121,7 +121,11 @@ func getDiscoveryDomain(c *cli.Context) (domainstr string, insecure bool) {
 		domainstr = os.Getenv("ETCDCTL_DISCOVERY_SRV")
 	}
 	insecure = c.GlobalBool("insecure-discovery") || (os.Getenv("ETCDCTL_INSECURE_DISCOVERY") != "")
-	return domainstr, insecure
+	serviceName = c.GlobalString("discovery-srv-name")
+	if serviceName == "" {
+		serviceName = os.Getenv("ETCDCTL_DISCOVERY_SRV_NAME")
+	}
+	return domainstr, insecure, serviceName
 }
 
 func getEndpoints(c *cli.Context) ([]string, error) {
@@ -168,7 +172,7 @@ func getTransport(c *cli.Context) (*http.Transport, error) {
 		keyfile = os.Getenv("ETCDCTL_KEY_FILE")
 	}
 
-	discoveryDomain, insecure := getDiscoveryDomain(c)
+	discoveryDomain, insecure, _ := getDiscoveryDomain(c)
 	if insecure {
 		discoveryDomain = ""
 	}

+ 24 - 13
etcdctl/ctlv3/command/global.go

@@ -39,14 +39,15 @@ import (
 // GlobalFlags are flags that defined globally
 // and are inherited to all sub-commands.
 type GlobalFlags struct {
-	Insecure           bool
-	InsecureSkipVerify bool
-	InsecureDiscovery  bool
-	Endpoints          []string
-	DialTimeout        time.Duration
-	CommandTimeOut     time.Duration
-	KeepAliveTime      time.Duration
-	KeepAliveTimeout   time.Duration
+	Insecure              bool
+	InsecureSkipVerify    bool
+	InsecureDiscovery     bool
+	Endpoints             []string
+	DialTimeout           time.Duration
+	CommandTimeOut        time.Duration
+	KeepAliveTime         time.Duration
+	KeepAliveTimeout      time.Duration
+	DNSClusterServiceName string
 
 	TLS transport.TLSInfo
 
@@ -75,8 +76,9 @@ type authCfg struct {
 }
 
 type discoveryCfg struct {
-	domain   string
-	insecure bool
+	domain      string
+	insecure    bool
+	serviceName string
 }
 
 var display printer = &simplePrinter{}
@@ -390,10 +392,19 @@ func discoverySrvFromCmd(cmd *cobra.Command) string {
 	return domainStr
 }
 
+func discoveryDNSClusterServiceNameFromCmd(cmd *cobra.Command) string {
+	serviceNameStr, err := cmd.Flags().GetString("discovery-srv-name")
+	if err != nil {
+		ExitWithError(ExitBadArgs, err)
+	}
+	return serviceNameStr
+}
+
 func discoveryCfgFromCmd(cmd *cobra.Command) *discoveryCfg {
 	return &discoveryCfg{
-		domain:   discoverySrvFromCmd(cmd),
-		insecure: insecureDiscoveryFromCmd(cmd),
+		domain:      discoverySrvFromCmd(cmd),
+		insecure:    insecureDiscoveryFromCmd(cmd),
+		serviceName: discoveryDNSClusterServiceNameFromCmd(cmd),
 	}
 }
 
@@ -422,7 +433,7 @@ func endpointsFromFlagValue(cmd *cobra.Command) ([]string, error) {
 		return []string{}, nil
 	}
 
-	srvs, err := srv.GetClient("etcd-client", discoveryCfg.domain)
+	srvs, err := srv.GetClient("etcd-client", discoveryCfg.domain, discoveryCfg.serviceName)
 	if err != nil {
 		return nil, err
 	}

+ 1 - 0
etcdctl/ctlv3/ctl.go

@@ -67,6 +67,7 @@ func init() {
 	rootCmd.PersistentFlags().StringVar(&globalFlags.User, "user", "", "username[:password] for authentication (prompt if password is not supplied)")
 	rootCmd.PersistentFlags().StringVar(&globalFlags.Password, "password", "", "password for authentication (if this option is used, --user option shouldn't include password)")
 	rootCmd.PersistentFlags().StringVarP(&globalFlags.TLS.ServerName, "discovery-srv", "d", "", "domain name to query for SRV records describing cluster endpoints")
+	rootCmd.PersistentFlags().StringVarP(&globalFlags.DNSClusterServiceName, "discovery-srv-name", "", "", "service name to query when using DNS discovery")
 
 	rootCmd.AddCommand(
 		command.NewGetCommand(),

+ 9 - 7
etcdmain/gateway.go

@@ -28,12 +28,13 @@ import (
 )
 
 var (
-	gatewayListenAddr        string
-	gatewayEndpoints         []string
-	gatewayDNSCluster        string
-	gatewayInsecureDiscovery bool
-	getewayRetryDelay        time.Duration
-	gatewayCA                string
+	gatewayListenAddr            string
+	gatewayEndpoints             []string
+	gatewayDNSCluster            string
+	gatewayDNSClusterServiceName string
+	gatewayInsecureDiscovery     bool
+	getewayRetryDelay            time.Duration
+	gatewayCA                    string
 )
 
 var (
@@ -68,6 +69,7 @@ func newGatewayStartCommand() *cobra.Command {
 
 	cmd.Flags().StringVar(&gatewayListenAddr, "listen-addr", "127.0.0.1:23790", "listen address")
 	cmd.Flags().StringVar(&gatewayDNSCluster, "discovery-srv", "", "DNS domain used to bootstrap initial cluster")
+	cmd.Flags().StringVar(&gatewayDNSClusterServiceName, "discovery-srv-name", "", "service name to query when using DNS discovery")
 	cmd.Flags().BoolVar(&gatewayInsecureDiscovery, "insecure-discovery", false, "accept insecure SRV records")
 	cmd.Flags().StringVar(&gatewayCA, "trusted-ca-file", "", "path to the client server TLS CA file.")
 
@@ -97,7 +99,7 @@ func startGateway(cmd *cobra.Command, args []string) {
 		os.Exit(1)
 	}
 
-	srvs := discoverEndpoints(lg, gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery)
+	srvs := discoverEndpoints(lg, gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery, gatewayDNSClusterServiceName)
 	if len(srvs.Endpoints) == 0 {
 		// no endpoints discovered, fall back to provided endpoints
 		srvs.Endpoints = gatewayEndpoints

+ 12 - 10
etcdmain/grpc_proxy.go

@@ -49,14 +49,15 @@ import (
 )
 
 var (
-	grpcProxyListenAddr        string
-	grpcProxyMetricsListenAddr string
-	grpcProxyEndpoints         []string
-	grpcProxyDNSCluster        string
-	grpcProxyInsecureDiscovery bool
-	grpcProxyDataDir           string
-	grpcMaxCallSendMsgSize     int
-	grpcMaxCallRecvMsgSize     int
+	grpcProxyListenAddr            string
+	grpcProxyMetricsListenAddr     string
+	grpcProxyEndpoints             []string
+	grpcProxyDNSCluster            string
+	grpcProxyDNSClusterServiceName string
+	grpcProxyInsecureDiscovery     bool
+	grpcProxyDataDir               string
+	grpcMaxCallSendMsgSize         int
+	grpcMaxCallRecvMsgSize         int
 
 	// tls for connecting to etcd
 
@@ -111,7 +112,8 @@ func newGRPCProxyStartCommand() *cobra.Command {
 	}
 
 	cmd.Flags().StringVar(&grpcProxyListenAddr, "listen-addr", "127.0.0.1:23790", "listen address")
-	cmd.Flags().StringVar(&grpcProxyDNSCluster, "discovery-srv", "", "DNS domain used to bootstrap initial cluster")
+	cmd.Flags().StringVar(&grpcProxyDNSCluster, "discovery-srv", "", "domain name to query for SRV records describing cluster endpoints")
+	cmd.Flags().StringVar(&grpcProxyDNSClusterServiceName, "discovery-srv-name", "", "service name to query when using DNS discovery")
 	cmd.Flags().StringVar(&grpcProxyMetricsListenAddr, "metrics-addr", "", "listen for /metrics requests on an additional interface")
 	cmd.Flags().BoolVar(&grpcProxyInsecureDiscovery, "insecure-discovery", false, "accept insecure SRV records")
 	cmd.Flags().StringSliceVar(&grpcProxyEndpoints, "endpoints", []string{"127.0.0.1:2379"}, "comma separated etcd cluster endpoints")
@@ -249,7 +251,7 @@ func checkArgs() {
 }
 
 func mustNewClient(lg *zap.Logger) *clientv3.Client {
-	srvs := discoverEndpoints(lg, grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery)
+	srvs := discoverEndpoints(lg, grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery, grpcProxyDNSClusterServiceName)
 	eps := srvs.Endpoints
 	if len(eps) == 0 {
 		eps = grpcProxyEndpoints

+ 2 - 2
etcdmain/util.go

@@ -24,11 +24,11 @@ import (
 	"go.uber.org/zap"
 )
 
-func discoverEndpoints(lg *zap.Logger, dns string, ca string, insecure bool) (s srv.SRVClients) {
+func discoverEndpoints(lg *zap.Logger, dns string, ca string, insecure bool, serviceName string) (s srv.SRVClients) {
 	if dns == "" {
 		return s
 	}
-	srvs, err := srv.GetClient("etcd-client", dns)
+	srvs, err := srv.GetClient("etcd-client", dns, serviceName)
 	if err != nil {
 		fmt.Fprintln(os.Stderr, err)
 		os.Exit(1)

+ 15 - 3
pkg/srv/srv.go

@@ -96,7 +96,7 @@ type SRVClients struct {
 }
 
 // GetClient looks up the client endpoints for a service and domain.
-func GetClient(service, domain string) (*SRVClients, error) {
+func GetClient(service, domain string, serviceName string) (*SRVClients, error) {
 	var urls []*url.URL
 	var srvs []*net.SRV
 
@@ -115,8 +115,8 @@ func GetClient(service, domain string) (*SRVClients, error) {
 		return nil
 	}
 
-	errHTTPS := updateURLs(service+"-ssl", "https")
-	errHTTP := updateURLs(service, "http")
+	errHTTPS := updateURLs(GetSRVService(service, serviceName, "https"), "https")
+	errHTTP := updateURLs(GetSRVService(service, serviceName, "http"), "http")
 
 	if errHTTPS != nil && errHTTP != nil {
 		return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP)
@@ -128,3 +128,15 @@ func GetClient(service, domain string) (*SRVClients, error) {
 	}
 	return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil
 }
+
+// GetSRVService generates a SRV service including an optional suffix.
+func GetSRVService(service, serviceName string, scheme string) (SRVService string) {
+	if scheme == "https" {
+		service = fmt.Sprintf("%s-ssl", service)
+	}
+
+	if serviceName != "" {
+		return fmt.Sprintf("%s-%s", service, serviceName)
+	}
+	return service
+}

+ 38 - 1
pkg/srv/srv_test.go

@@ -188,7 +188,7 @@ func TestSRVDiscover(t *testing.T) {
 			return "", nil, errors.New("Unknown service in mock")
 		}
 
-		srvs, err := GetClient("etcd-client", "example.com")
+		srvs, err := GetClient("etcd-client", "example.com", "")
 		if err != nil {
 			t.Fatalf("%d: err: %#v", i, err)
 		}
@@ -199,3 +199,40 @@ func TestSRVDiscover(t *testing.T) {
 
 	}
 }
+
+func TestGetSRVService(t *testing.T) {
+	tests := []struct {
+		scheme      string
+		serviceName string
+
+		expected string
+	}{
+		{
+			"https",
+			"",
+			"etcd-client-ssl",
+		},
+		{
+			"http",
+			"",
+			"etcd-client",
+		},
+		{
+			"https",
+			"foo",
+			"etcd-client-ssl-foo",
+		},
+		{
+			"http",
+			"bar",
+			"etcd-client-bar",
+		},
+	}
+
+	for i, tt := range tests {
+		service := GetSRVService("etcd-client", tt.serviceName, tt.scheme)
+		if strings.Compare(service, tt.expected) != 0 {
+			t.Errorf("#%d: service = %s, want %s", i, service, tt.expected)
+		}
+	}
+}