Browse Source

pkg/srv: package for SRV utilities

Trying to decouple the v2 client from SRV code. Can't move
into discovery/ since that creates a circular dependency. So,
give up and move all the SRV code into a new package.
Anthony Romano 8 years ago
parent
commit
07ad18178d
7 changed files with 170 additions and 198 deletions
  1. 19 0
      client/discover.go
  2. 0 65
      client/srv.go
  3. 0 102
      client/srv_test.go
  4. 9 5
      embed/config.go
  5. 3 2
      etcdmain/util.go
  6. 57 21
      pkg/srv/srv.go
  7. 82 3
      pkg/srv/srv_test.go

+ 19 - 0
client/discover.go

@@ -14,8 +14,27 @@
 
 package client
 
+import (
+	"github.com/coreos/etcd/pkg/srv"
+)
+
 // Discoverer is an interface that wraps the Discover method.
 type Discoverer interface {
 	// Discover looks up the etcd servers for the domain.
 	Discover(domain string) ([]string, error)
 }
+
+type srvDiscover struct{}
+
+// NewSRVDiscover constructs a new Discoverer that uses the stdlib to lookup SRV records.
+func NewSRVDiscover() Discoverer {
+	return &srvDiscover{}
+}
+
+func (d *srvDiscover) Discover(domain string) ([]string, error) {
+	srvs, err := srv.GetClient("etcd-client", domain)
+	if err != nil {
+		return nil, err
+	}
+	return srvs.Endpoints, nil
+}

+ 0 - 65
client/srv.go

@@ -1,65 +0,0 @@
-// Copyright 2015 The etcd Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package client
-
-import (
-	"fmt"
-	"net"
-	"net/url"
-)
-
-var (
-	// indirection for testing
-	lookupSRV = net.LookupSRV
-)
-
-type srvDiscover struct{}
-
-// NewSRVDiscover constructs a new Discoverer that uses the stdlib to lookup SRV records.
-func NewSRVDiscover() Discoverer {
-	return &srvDiscover{}
-}
-
-// Discover looks up the etcd servers for the domain.
-func (d *srvDiscover) Discover(domain string) ([]string, error) {
-	var urls []*url.URL
-
-	updateURLs := func(service, scheme string) error {
-		_, addrs, err := lookupSRV(service, "tcp", domain)
-		if err != nil {
-			return err
-		}
-		for _, srv := range addrs {
-			urls = append(urls, &url.URL{
-				Scheme: scheme,
-				Host:   net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)),
-			})
-		}
-		return nil
-	}
-
-	errHTTPS := updateURLs("etcd-client-ssl", "https")
-	errHTTP := updateURLs("etcd-client", "http")
-
-	if errHTTPS != nil && errHTTP != nil {
-		return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP)
-	}
-
-	endpoints := make([]string, len(urls))
-	for i := range urls {
-		endpoints[i] = urls[i].String()
-	}
-	return endpoints, nil
-}

+ 0 - 102
client/srv_test.go

@@ -1,102 +0,0 @@
-// Copyright 2015 The etcd Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package client
-
-import (
-	"errors"
-	"net"
-	"reflect"
-	"testing"
-)
-
-func TestSRVDiscover(t *testing.T) {
-	defer func() { lookupSRV = net.LookupSRV }()
-
-	tests := []struct {
-		withSSL    []*net.SRV
-		withoutSSL []*net.SRV
-		expected   []string
-	}{
-		{
-			[]*net.SRV{},
-			[]*net.SRV{},
-			[]string{},
-		},
-		{
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 2480},
-				{Target: "10.0.0.2", Port: 2480},
-				{Target: "10.0.0.3", Port: 2480},
-			},
-			[]*net.SRV{},
-			[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480"},
-		},
-		{
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 2480},
-				{Target: "10.0.0.2", Port: 2480},
-				{Target: "10.0.0.3", Port: 2480},
-			},
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 7001},
-			},
-			[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"},
-		},
-		{
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 2480},
-				{Target: "10.0.0.2", Port: 2480},
-				{Target: "10.0.0.3", Port: 2480},
-			},
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 7001},
-			},
-			[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"},
-		},
-		{
-			[]*net.SRV{
-				{Target: "a.example.com", Port: 2480},
-				{Target: "b.example.com", Port: 2480},
-				{Target: "c.example.com", Port: 2480},
-			},
-			[]*net.SRV{},
-			[]string{"https://a.example.com:2480", "https://b.example.com:2480", "https://c.example.com:2480"},
-		},
-	}
-
-	for i, tt := range tests {
-		lookupSRV = func(service string, proto string, domain string) (string, []*net.SRV, error) {
-			if service == "etcd-client-ssl" {
-				return "", tt.withSSL, nil
-			}
-			if service == "etcd-client" {
-				return "", tt.withoutSSL, nil
-			}
-			return "", nil, errors.New("Unknown service in mock")
-		}
-
-		d := NewSRVDiscover()
-
-		endpoints, err := d.Discover("example.com")
-		if err != nil {
-			t.Fatalf("%d: err: %#v", i, err)
-		}
-
-		if !reflect.DeepEqual(endpoints, tt.expected) {
-			t.Errorf("#%d: endpoints = %v, want %v", i, endpoints, tt.expected)
-		}
-
-	}
-}

+ 9 - 5
embed/config.go

@@ -22,10 +22,10 @@ import (
 	"net/url"
 	"strings"
 
-	"github.com/coreos/etcd/discovery"
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/pkg/cors"
 	"github.com/coreos/etcd/pkg/netutil"
+	"github.com/coreos/etcd/pkg/srv"
 	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/types"
 
@@ -321,11 +321,15 @@ func (cfg *Config) PeerURLsMapAndToken(which string) (urlsmap types.URLsMap, tok
 		urlsmap[cfg.Name] = cfg.APUrls
 		token = cfg.Durl
 	case cfg.DNSCluster != "":
-		var clusterStr string
-		clusterStr, err = discovery.SRVGetCluster(cfg.Name, cfg.DNSCluster, cfg.APUrls)
-		if err != nil {
-			return nil, "", err
+		clusterStrs, cerr := srv.GetCluster("etcd-server", cfg.Name, cfg.DNSCluster, cfg.APUrls)
+		if cerr != nil {
+			plog.Errorf("couldn't resolve during SRV discovery (%v)", cerr)
+			return nil, "", cerr
+		}
+		for _, s := range clusterStrs {
+			plog.Noticef("got bootstrap from DNS for etcd-server at %s", s)
 		}
+		clusterStr := strings.Join(clusterStrs, ",")
 		if strings.Contains(clusterStr, "https://") && cfg.PeerTLSInfo.CAFile == "" {
 			cfg.PeerTLSInfo.ServerName = cfg.DNSCluster
 		}

+ 3 - 2
etcdmain/util.go

@@ -18,7 +18,7 @@ import (
 	"fmt"
 	"os"
 
-	"github.com/coreos/etcd/client"
+	"github.com/coreos/etcd/pkg/srv"
 	"github.com/coreos/etcd/pkg/transport"
 )
 
@@ -26,11 +26,12 @@ func discoverEndpoints(dns string, ca string, insecure bool) (endpoints []string
 	if dns == "" {
 		return nil
 	}
-	endpoints, err := client.NewSRVDiscover().Discover(dns)
+	srvs, err := srv.GetClient("etcd-client", dns)
 	if err != nil {
 		fmt.Fprintln(os.Stderr, err)
 		os.Exit(1)
 	}
+	endpoints = srvs.Endpoints
 	plog.Infof("discovered the cluster %s from %s", endpoints, dns)
 	if insecure {
 		return endpoints

+ 57 - 21
discovery/srv.go → pkg/srv/srv.go

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package discovery
+package srv
 
 import (
 	"fmt"
@@ -25,14 +25,13 @@ import (
 
 var (
 	// indirection for testing
-	lookupSRV      = net.LookupSRV
+	lookupSRV      = net.LookupSRV // net.DefaultResolver.LookupSRV when ctxs don't conflict
 	resolveTCPAddr = net.ResolveTCPAddr
 )
 
-// SRVGetCluster gets the cluster information via DNS discovery.
-// TODO(barakmich): Currently ignores priority and weight (as they don't make as much sense for a bootstrap)
+// GetCluster gets the cluster information via DNS discovery.
 // Also sees each entry as a separate instance.
-func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) {
+func GetCluster(service, name, dns string, apurls types.URLs) ([]string, error) {
 	tempName := int(0)
 	tcp2ap := make(map[string]url.URL)
 
@@ -40,8 +39,7 @@ func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) {
 	for _, url := range apurls {
 		tcpAddr, err := resolveTCPAddr("tcp", url.Host)
 		if err != nil {
-			plog.Errorf("couldn't resolve host %s during SRV discovery", url.Host)
-			return "", err
+			return nil, err
 		}
 		tcp2ap[tcpAddr.String()] = url
 	}
@@ -55,9 +53,9 @@ func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) {
 		for _, srv := range addrs {
 			port := fmt.Sprintf("%d", srv.Port)
 			host := net.JoinHostPort(srv.Target, port)
-			tcpAddr, err := resolveTCPAddr("tcp", host)
-			if err != nil {
-				plog.Warningf("couldn't resolve host %s during SRV discovery", host)
+			tcpAddr, terr := resolveTCPAddr("tcp", host)
+			if terr != nil {
+				terr = err
 				continue
 			}
 			n := ""
@@ -73,31 +71,69 @@ func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) {
 			shortHost := strings.TrimSuffix(srv.Target, ".")
 			urlHost := net.JoinHostPort(shortHost, port)
 			stringParts = append(stringParts, fmt.Sprintf("%s=%s://%s", n, scheme, urlHost))
-			plog.Noticef("got bootstrap from DNS for %s at %s://%s", service, scheme, urlHost)
 			if ok && url.Scheme != scheme {
-				plog.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String())
+				err = fmt.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String())
 			}
 		}
+		if len(stringParts) == 0 {
+			return err
+		}
 		return nil
 	}
 
 	failCount := 0
-	err := updateNodeMap("etcd-server-ssl", "https")
+	err := updateNodeMap(service+"-ssl", "https")
 	srvErr := make([]string, 2)
 	if err != nil {
-		srvErr[0] = fmt.Sprintf("error querying DNS SRV records for _etcd-server-ssl %s", err)
+		srvErr[0] = fmt.Sprintf("error querying DNS SRV records for _%s-ssl %s", service, err)
 		failCount++
 	}
-	err = updateNodeMap("etcd-server", "http")
+	err = updateNodeMap(service, "http")
 	if err != nil {
-		srvErr[1] = fmt.Sprintf("error querying DNS SRV records for _etcd-server %s", err)
+		srvErr[1] = fmt.Sprintf("error querying DNS SRV records for _%s %s", service, err)
 		failCount++
 	}
 	if failCount == 2 {
-		plog.Warningf(srvErr[0])
-		plog.Warningf(srvErr[1])
-		plog.Errorf("SRV discovery failed: too many errors querying DNS SRV records")
-		return "", err
+		return nil, fmt.Errorf("srv: too many errors querying DNS SRV records (%q, %q)", srvErr[0], srvErr[1])
+	}
+	return stringParts, nil
+}
+
+type SRVClients struct {
+	Endpoints []string
+	SRVs      []*net.SRV
+}
+
+// GetClient looks up the client endpoints for a service and domain.
+func GetClient(service, domain string) (*SRVClients, error) {
+	var urls []*url.URL
+	var srvs []*net.SRV
+
+	updateURLs := func(service, scheme string) error {
+		_, addrs, err := lookupSRV(service, "tcp", domain)
+		if err != nil {
+			return err
+		}
+		for _, srv := range addrs {
+			urls = append(urls, &url.URL{
+				Scheme: scheme,
+				Host:   net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)),
+			})
+		}
+		srvs = append(srvs, addrs...)
+		return nil
+	}
+
+	errHTTPS := updateURLs(service+"-ssl", "https")
+	errHTTP := updateURLs(service, "http")
+
+	if errHTTPS != nil && errHTTP != nil {
+		return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP)
+	}
+
+	endpoints := make([]string, len(urls))
+	for i := range urls {
+		endpoints[i] = urls[i].String()
 	}
-	return strings.Join(stringParts, ","), nil
+	return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil
 }

+ 82 - 3
discovery/srv_test.go → pkg/srv/srv_test.go

@@ -12,11 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package discovery
+package srv
 
 import (
 	"errors"
 	"net"
+	"reflect"
 	"strings"
 	"testing"
 
@@ -110,12 +111,90 @@ func TestSRVGetCluster(t *testing.T) {
 			return "", nil, errors.New("Unknown service in mock")
 		}
 		urls := testutil.MustNewURLs(t, tt.urls)
-		str, err := SRVGetCluster(name, "example.com", urls)
+		str, err := GetCluster("etcd-server", name, "example.com", urls)
 		if err != nil {
 			t.Fatalf("%d: err: %#v", i, err)
 		}
-		if str != tt.expected {
+		if strings.Join(str, ",") != tt.expected {
 			t.Errorf("#%d: cluster = %s, want %s", i, str, tt.expected)
 		}
 	}
 }
+
+func TestSRVDiscover(t *testing.T) {
+	defer func() { lookupSRV = net.LookupSRV }()
+
+	tests := []struct {
+		withSSL    []*net.SRV
+		withoutSSL []*net.SRV
+		expected   []string
+	}{
+		{
+			[]*net.SRV{},
+			[]*net.SRV{},
+			[]string{},
+		},
+		{
+			[]*net.SRV{
+				{Target: "10.0.0.1", Port: 2480},
+				{Target: "10.0.0.2", Port: 2480},
+				{Target: "10.0.0.3", Port: 2480},
+			},
+			[]*net.SRV{},
+			[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480"},
+		},
+		{
+			[]*net.SRV{
+				{Target: "10.0.0.1", Port: 2480},
+				{Target: "10.0.0.2", Port: 2480},
+				{Target: "10.0.0.3", Port: 2480},
+			},
+			[]*net.SRV{
+				{Target: "10.0.0.1", Port: 7001},
+			},
+			[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"},
+		},
+		{
+			[]*net.SRV{
+				{Target: "10.0.0.1", Port: 2480},
+				{Target: "10.0.0.2", Port: 2480},
+				{Target: "10.0.0.3", Port: 2480},
+			},
+			[]*net.SRV{
+				{Target: "10.0.0.1", Port: 7001},
+			},
+			[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"},
+		},
+		{
+			[]*net.SRV{
+				{Target: "a.example.com", Port: 2480},
+				{Target: "b.example.com", Port: 2480},
+				{Target: "c.example.com", Port: 2480},
+			},
+			[]*net.SRV{},
+			[]string{"https://a.example.com:2480", "https://b.example.com:2480", "https://c.example.com:2480"},
+		},
+	}
+
+	for i, tt := range tests {
+		lookupSRV = func(service string, proto string, domain string) (string, []*net.SRV, error) {
+			if service == "etcd-client-ssl" {
+				return "", tt.withSSL, nil
+			}
+			if service == "etcd-client" {
+				return "", tt.withoutSSL, nil
+			}
+			return "", nil, errors.New("Unknown service in mock")
+		}
+
+		srvs, err := GetClient("etcd-client", "example.com")
+		if err != nil {
+			t.Fatalf("%d: err: %#v", i, err)
+		}
+
+		if !reflect.DeepEqual(srvs.Endpoints, tt.expected) {
+			t.Errorf("#%d: endpoints = %v, want %v", i, srvs.Endpoints, tt.expected)
+		}
+
+	}
+}