Browse Source

Merge pull request #7882 from heyitsanthony/srv-priority

gateway: DNS SRV priority
Anthony Romano 8 years ago
parent
commit
aac2292ab5

+ 19 - 0
client/discover.go

@@ -14,8 +14,27 @@
 
 
 package client
 package client
 
 
+import (
+	"github.com/coreos/etcd/pkg/srv"
+)
+
 // Discoverer is an interface that wraps the Discover method.
 // Discoverer is an interface that wraps the Discover method.
 type Discoverer interface {
 type Discoverer interface {
 	// Discover looks up the etcd servers for the domain.
 	// Discover looks up the etcd servers for the domain.
 	Discover(domain string) ([]string, error)
 	Discover(domain string) ([]string, error)
 }
 }
+
+type srvDiscover struct{}
+
+// NewSRVDiscover constructs a new Discoverer that uses the stdlib to lookup SRV records.
+func NewSRVDiscover() Discoverer {
+	return &srvDiscover{}
+}
+
+func (d *srvDiscover) Discover(domain string) ([]string, error) {
+	srvs, err := srv.GetClient("etcd-client", domain)
+	if err != nil {
+		return nil, err
+	}
+	return srvs.Endpoints, nil
+}

+ 0 - 65
client/srv.go

@@ -1,65 +0,0 @@
-// Copyright 2015 The etcd Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package client
-
-import (
-	"fmt"
-	"net"
-	"net/url"
-)
-
-var (
-	// indirection for testing
-	lookupSRV = net.LookupSRV
-)
-
-type srvDiscover struct{}
-
-// NewSRVDiscover constructs a new Discoverer that uses the stdlib to lookup SRV records.
-func NewSRVDiscover() Discoverer {
-	return &srvDiscover{}
-}
-
-// Discover looks up the etcd servers for the domain.
-func (d *srvDiscover) Discover(domain string) ([]string, error) {
-	var urls []*url.URL
-
-	updateURLs := func(service, scheme string) error {
-		_, addrs, err := lookupSRV(service, "tcp", domain)
-		if err != nil {
-			return err
-		}
-		for _, srv := range addrs {
-			urls = append(urls, &url.URL{
-				Scheme: scheme,
-				Host:   net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)),
-			})
-		}
-		return nil
-	}
-
-	errHTTPS := updateURLs("etcd-client-ssl", "https")
-	errHTTP := updateURLs("etcd-client", "http")
-
-	if errHTTPS != nil && errHTTP != nil {
-		return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP)
-	}
-
-	endpoints := make([]string, len(urls))
-	for i := range urls {
-		endpoints[i] = urls[i].String()
-	}
-	return endpoints, nil
-}

+ 0 - 102
client/srv_test.go

@@ -1,102 +0,0 @@
-// Copyright 2015 The etcd Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package client
-
-import (
-	"errors"
-	"net"
-	"reflect"
-	"testing"
-)
-
-func TestSRVDiscover(t *testing.T) {
-	defer func() { lookupSRV = net.LookupSRV }()
-
-	tests := []struct {
-		withSSL    []*net.SRV
-		withoutSSL []*net.SRV
-		expected   []string
-	}{
-		{
-			[]*net.SRV{},
-			[]*net.SRV{},
-			[]string{},
-		},
-		{
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 2480},
-				{Target: "10.0.0.2", Port: 2480},
-				{Target: "10.0.0.3", Port: 2480},
-			},
-			[]*net.SRV{},
-			[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480"},
-		},
-		{
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 2480},
-				{Target: "10.0.0.2", Port: 2480},
-				{Target: "10.0.0.3", Port: 2480},
-			},
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 7001},
-			},
-			[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"},
-		},
-		{
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 2480},
-				{Target: "10.0.0.2", Port: 2480},
-				{Target: "10.0.0.3", Port: 2480},
-			},
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 7001},
-			},
-			[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"},
-		},
-		{
-			[]*net.SRV{
-				{Target: "a.example.com", Port: 2480},
-				{Target: "b.example.com", Port: 2480},
-				{Target: "c.example.com", Port: 2480},
-			},
-			[]*net.SRV{},
-			[]string{"https://a.example.com:2480", "https://b.example.com:2480", "https://c.example.com:2480"},
-		},
-	}
-
-	for i, tt := range tests {
-		lookupSRV = func(service string, proto string, domain string) (string, []*net.SRV, error) {
-			if service == "etcd-client-ssl" {
-				return "", tt.withSSL, nil
-			}
-			if service == "etcd-client" {
-				return "", tt.withoutSSL, nil
-			}
-			return "", nil, errors.New("Unknown service in mock")
-		}
-
-		d := NewSRVDiscover()
-
-		endpoints, err := d.Discover("example.com")
-		if err != nil {
-			t.Fatalf("%d: err: %#v", i, err)
-		}
-
-		if !reflect.DeepEqual(endpoints, tt.expected) {
-			t.Errorf("#%d: endpoints = %v, want %v", i, endpoints, tt.expected)
-		}
-
-	}
-}

+ 9 - 5
embed/config.go

@@ -22,10 +22,10 @@ import (
 	"net/url"
 	"net/url"
 	"strings"
 	"strings"
 
 
-	"github.com/coreos/etcd/discovery"
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/pkg/cors"
 	"github.com/coreos/etcd/pkg/cors"
 	"github.com/coreos/etcd/pkg/netutil"
 	"github.com/coreos/etcd/pkg/netutil"
+	"github.com/coreos/etcd/pkg/srv"
 	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/pkg/types"
 
 
@@ -321,11 +321,15 @@ func (cfg *Config) PeerURLsMapAndToken(which string) (urlsmap types.URLsMap, tok
 		urlsmap[cfg.Name] = cfg.APUrls
 		urlsmap[cfg.Name] = cfg.APUrls
 		token = cfg.Durl
 		token = cfg.Durl
 	case cfg.DNSCluster != "":
 	case cfg.DNSCluster != "":
-		var clusterStr string
-		clusterStr, err = discovery.SRVGetCluster(cfg.Name, cfg.DNSCluster, cfg.APUrls)
-		if err != nil {
-			return nil, "", err
+		clusterStrs, cerr := srv.GetCluster("etcd-server", cfg.Name, cfg.DNSCluster, cfg.APUrls)
+		if cerr != nil {
+			plog.Errorf("couldn't resolve during SRV discovery (%v)", cerr)
+			return nil, "", cerr
+		}
+		for _, s := range clusterStrs {
+			plog.Noticef("got bootstrap from DNS for etcd-server at %s", s)
 		}
 		}
+		clusterStr := strings.Join(clusterStrs, ",")
 		if strings.Contains(clusterStr, "https://") && cfg.PeerTLSInfo.CAFile == "" {
 		if strings.Contains(clusterStr, "https://") && cfg.PeerTLSInfo.CAFile == "" {
 			cfg.PeerTLSInfo.ServerName = cfg.DNSCluster
 			cfg.PeerTLSInfo.ServerName = cfg.DNSCluster
 		}
 		}

+ 19 - 8
etcdmain/gateway.go

@@ -91,17 +91,28 @@ func stripSchema(eps []string) []string {
 
 
 	return endpoints
 	return endpoints
 }
 }
-func startGateway(cmd *cobra.Command, args []string) {
-	endpoints := gatewayEndpoints
 
 
-	if eps := discoverEndpoints(gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery); len(eps) != 0 {
-		endpoints = eps
+func startGateway(cmd *cobra.Command, args []string) {
+	srvs := discoverEndpoints(gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery)
+	if len(srvs.Endpoints) == 0 {
+		// no endpoints discovered, fall back to provided endpoints
+		srvs.Endpoints = gatewayEndpoints
 	}
 	}
-
 	// Strip the schema from the endpoints because we start just a TCP proxy
 	// Strip the schema from the endpoints because we start just a TCP proxy
-	endpoints = stripSchema(endpoints)
+	srvs.Endpoints = stripSchema(srvs.Endpoints)
+	if len(srvs.SRVs) == 0 {
+		for _, ep := range srvs.Endpoints {
+			h, p, err := net.SplitHostPort(ep)
+			if err != nil {
+				plog.Fatalf("error parsing endpoint %q", ep)
+			}
+			var port uint16
+			fmt.Sscanf(p, "%d", &port)
+			srvs.SRVs = append(srvs.SRVs, &net.SRV{Target: h, Port: port})
+		}
+	}
 
 
-	if len(endpoints) == 0 {
+	if len(srvs.Endpoints) == 0 {
 		plog.Fatalf("no endpoints found")
 		plog.Fatalf("no endpoints found")
 	}
 	}
 
 
@@ -113,7 +124,7 @@ func startGateway(cmd *cobra.Command, args []string) {
 
 
 	tp := tcpproxy.TCPProxy{
 	tp := tcpproxy.TCPProxy{
 		Listener:        l,
 		Listener:        l,
-		Endpoints:       endpoints,
+		Endpoints:       srvs.SRVs,
 		MonitorInterval: getewayRetryDelay,
 		MonitorInterval: getewayRetryDelay,
 	}
 	}
 
 

+ 3 - 2
etcdmain/grpc_proxy.go

@@ -106,8 +106,9 @@ func startGRPCProxy(cmd *cobra.Command, args []string) {
 		os.Exit(1)
 		os.Exit(1)
 	}
 	}
 
 
-	if eps := discoverEndpoints(grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery); len(eps) != 0 {
-		grpcProxyEndpoints = eps
+	srvs := discoverEndpoints(grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery)
+	if len(srvs.Endpoints) != 0 {
+		grpcProxyEndpoints = srvs.Endpoints
 	}
 	}
 
 
 	l, err := net.Listen("tcp", grpcProxyListenAddr)
 	l, err := net.Listen("tcp", grpcProxyListenAddr)

+ 21 - 6
etcdmain/util.go

@@ -18,22 +18,23 @@ import (
 	"fmt"
 	"fmt"
 	"os"
 	"os"
 
 
-	"github.com/coreos/etcd/client"
+	"github.com/coreos/etcd/pkg/srv"
 	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/transport"
 )
 )
 
 
-func discoverEndpoints(dns string, ca string, insecure bool) (endpoints []string) {
+func discoverEndpoints(dns string, ca string, insecure bool) (s srv.SRVClients) {
 	if dns == "" {
 	if dns == "" {
-		return nil
+		return s
 	}
 	}
-	endpoints, err := client.NewSRVDiscover().Discover(dns)
+	srvs, err := srv.GetClient("etcd-client", dns)
 	if err != nil {
 	if err != nil {
 		fmt.Fprintln(os.Stderr, err)
 		fmt.Fprintln(os.Stderr, err)
 		os.Exit(1)
 		os.Exit(1)
 	}
 	}
+	endpoints := srvs.Endpoints
 	plog.Infof("discovered the cluster %s from %s", endpoints, dns)
 	plog.Infof("discovered the cluster %s from %s", endpoints, dns)
 	if insecure {
 	if insecure {
-		return endpoints
+		return *srvs
 	}
 	}
 	// confirm TLS connections are good
 	// confirm TLS connections are good
 	tlsInfo := transport.TLSInfo{
 	tlsInfo := transport.TLSInfo{
@@ -46,5 +47,19 @@ func discoverEndpoints(dns string, ca string, insecure bool) (endpoints []string
 		plog.Warningf("%v", err)
 		plog.Warningf("%v", err)
 	}
 	}
 	plog.Infof("using discovered endpoints %v", endpoints)
 	plog.Infof("using discovered endpoints %v", endpoints)
-	return endpoints
+
+	// map endpoints back to SRVClients struct with SRV data
+	eps := make(map[string]struct{})
+	for _, ep := range endpoints {
+		eps[ep] = struct{}{}
+	}
+	for i := range srvs.Endpoints {
+		if _, ok := eps[srvs.Endpoints[i]]; !ok {
+			continue
+		}
+		s.Endpoints = append(s.Endpoints, srvs.Endpoints[i])
+		s.SRVs = append(s.SRVs, srvs.SRVs[i])
+	}
+
+	return s
 }
 }

+ 57 - 21
discovery/srv.go → pkg/srv/srv.go

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // See the License for the specific language governing permissions and
 // limitations under the License.
 // limitations under the License.
 
 
-package discovery
+package srv
 
 
 import (
 import (
 	"fmt"
 	"fmt"
@@ -25,14 +25,13 @@ import (
 
 
 var (
 var (
 	// indirection for testing
 	// indirection for testing
-	lookupSRV      = net.LookupSRV
+	lookupSRV      = net.LookupSRV // net.DefaultResolver.LookupSRV when ctxs don't conflict
 	resolveTCPAddr = net.ResolveTCPAddr
 	resolveTCPAddr = net.ResolveTCPAddr
 )
 )
 
 
-// SRVGetCluster gets the cluster information via DNS discovery.
-// TODO(barakmich): Currently ignores priority and weight (as they don't make as much sense for a bootstrap)
+// GetCluster gets the cluster information via DNS discovery.
 // Also sees each entry as a separate instance.
 // Also sees each entry as a separate instance.
-func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) {
+func GetCluster(service, name, dns string, apurls types.URLs) ([]string, error) {
 	tempName := int(0)
 	tempName := int(0)
 	tcp2ap := make(map[string]url.URL)
 	tcp2ap := make(map[string]url.URL)
 
 
@@ -40,8 +39,7 @@ func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) {
 	for _, url := range apurls {
 	for _, url := range apurls {
 		tcpAddr, err := resolveTCPAddr("tcp", url.Host)
 		tcpAddr, err := resolveTCPAddr("tcp", url.Host)
 		if err != nil {
 		if err != nil {
-			plog.Errorf("couldn't resolve host %s during SRV discovery", url.Host)
-			return "", err
+			return nil, err
 		}
 		}
 		tcp2ap[tcpAddr.String()] = url
 		tcp2ap[tcpAddr.String()] = url
 	}
 	}
@@ -55,9 +53,9 @@ func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) {
 		for _, srv := range addrs {
 		for _, srv := range addrs {
 			port := fmt.Sprintf("%d", srv.Port)
 			port := fmt.Sprintf("%d", srv.Port)
 			host := net.JoinHostPort(srv.Target, port)
 			host := net.JoinHostPort(srv.Target, port)
-			tcpAddr, err := resolveTCPAddr("tcp", host)
-			if err != nil {
-				plog.Warningf("couldn't resolve host %s during SRV discovery", host)
+			tcpAddr, terr := resolveTCPAddr("tcp", host)
+			if terr != nil {
+				terr = err
 				continue
 				continue
 			}
 			}
 			n := ""
 			n := ""
@@ -73,31 +71,69 @@ func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) {
 			shortHost := strings.TrimSuffix(srv.Target, ".")
 			shortHost := strings.TrimSuffix(srv.Target, ".")
 			urlHost := net.JoinHostPort(shortHost, port)
 			urlHost := net.JoinHostPort(shortHost, port)
 			stringParts = append(stringParts, fmt.Sprintf("%s=%s://%s", n, scheme, urlHost))
 			stringParts = append(stringParts, fmt.Sprintf("%s=%s://%s", n, scheme, urlHost))
-			plog.Noticef("got bootstrap from DNS for %s at %s://%s", service, scheme, urlHost)
 			if ok && url.Scheme != scheme {
 			if ok && url.Scheme != scheme {
-				plog.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String())
+				err = fmt.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String())
 			}
 			}
 		}
 		}
+		if len(stringParts) == 0 {
+			return err
+		}
 		return nil
 		return nil
 	}
 	}
 
 
 	failCount := 0
 	failCount := 0
-	err := updateNodeMap("etcd-server-ssl", "https")
+	err := updateNodeMap(service+"-ssl", "https")
 	srvErr := make([]string, 2)
 	srvErr := make([]string, 2)
 	if err != nil {
 	if err != nil {
-		srvErr[0] = fmt.Sprintf("error querying DNS SRV records for _etcd-server-ssl %s", err)
+		srvErr[0] = fmt.Sprintf("error querying DNS SRV records for _%s-ssl %s", service, err)
 		failCount++
 		failCount++
 	}
 	}
-	err = updateNodeMap("etcd-server", "http")
+	err = updateNodeMap(service, "http")
 	if err != nil {
 	if err != nil {
-		srvErr[1] = fmt.Sprintf("error querying DNS SRV records for _etcd-server %s", err)
+		srvErr[1] = fmt.Sprintf("error querying DNS SRV records for _%s %s", service, err)
 		failCount++
 		failCount++
 	}
 	}
 	if failCount == 2 {
 	if failCount == 2 {
-		plog.Warningf(srvErr[0])
-		plog.Warningf(srvErr[1])
-		plog.Errorf("SRV discovery failed: too many errors querying DNS SRV records")
-		return "", err
+		return nil, fmt.Errorf("srv: too many errors querying DNS SRV records (%q, %q)", srvErr[0], srvErr[1])
+	}
+	return stringParts, nil
+}
+
+type SRVClients struct {
+	Endpoints []string
+	SRVs      []*net.SRV
+}
+
+// GetClient looks up the client endpoints for a service and domain.
+func GetClient(service, domain string) (*SRVClients, error) {
+	var urls []*url.URL
+	var srvs []*net.SRV
+
+	updateURLs := func(service, scheme string) error {
+		_, addrs, err := lookupSRV(service, "tcp", domain)
+		if err != nil {
+			return err
+		}
+		for _, srv := range addrs {
+			urls = append(urls, &url.URL{
+				Scheme: scheme,
+				Host:   net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)),
+			})
+		}
+		srvs = append(srvs, addrs...)
+		return nil
+	}
+
+	errHTTPS := updateURLs(service+"-ssl", "https")
+	errHTTP := updateURLs(service, "http")
+
+	if errHTTPS != nil && errHTTP != nil {
+		return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP)
+	}
+
+	endpoints := make([]string, len(urls))
+	for i := range urls {
+		endpoints[i] = urls[i].String()
 	}
 	}
-	return strings.Join(stringParts, ","), nil
+	return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil
 }
 }

+ 82 - 3
discovery/srv_test.go → pkg/srv/srv_test.go

@@ -12,11 +12,12 @@
 // See the License for the specific language governing permissions and
 // See the License for the specific language governing permissions and
 // limitations under the License.
 // limitations under the License.
 
 
-package discovery
+package srv
 
 
 import (
 import (
 	"errors"
 	"errors"
 	"net"
 	"net"
+	"reflect"
 	"strings"
 	"strings"
 	"testing"
 	"testing"
 
 
@@ -110,12 +111,90 @@ func TestSRVGetCluster(t *testing.T) {
 			return "", nil, errors.New("Unknown service in mock")
 			return "", nil, errors.New("Unknown service in mock")
 		}
 		}
 		urls := testutil.MustNewURLs(t, tt.urls)
 		urls := testutil.MustNewURLs(t, tt.urls)
-		str, err := SRVGetCluster(name, "example.com", urls)
+		str, err := GetCluster("etcd-server", name, "example.com", urls)
 		if err != nil {
 		if err != nil {
 			t.Fatalf("%d: err: %#v", i, err)
 			t.Fatalf("%d: err: %#v", i, err)
 		}
 		}
-		if str != tt.expected {
+		if strings.Join(str, ",") != tt.expected {
 			t.Errorf("#%d: cluster = %s, want %s", i, str, tt.expected)
 			t.Errorf("#%d: cluster = %s, want %s", i, str, tt.expected)
 		}
 		}
 	}
 	}
 }
 }
+
+func TestSRVDiscover(t *testing.T) {
+	defer func() { lookupSRV = net.LookupSRV }()
+
+	tests := []struct {
+		withSSL    []*net.SRV
+		withoutSSL []*net.SRV
+		expected   []string
+	}{
+		{
+			[]*net.SRV{},
+			[]*net.SRV{},
+			[]string{},
+		},
+		{
+			[]*net.SRV{
+				{Target: "10.0.0.1", Port: 2480},
+				{Target: "10.0.0.2", Port: 2480},
+				{Target: "10.0.0.3", Port: 2480},
+			},
+			[]*net.SRV{},
+			[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480"},
+		},
+		{
+			[]*net.SRV{
+				{Target: "10.0.0.1", Port: 2480},
+				{Target: "10.0.0.2", Port: 2480},
+				{Target: "10.0.0.3", Port: 2480},
+			},
+			[]*net.SRV{
+				{Target: "10.0.0.1", Port: 7001},
+			},
+			[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"},
+		},
+		{
+			[]*net.SRV{
+				{Target: "10.0.0.1", Port: 2480},
+				{Target: "10.0.0.2", Port: 2480},
+				{Target: "10.0.0.3", Port: 2480},
+			},
+			[]*net.SRV{
+				{Target: "10.0.0.1", Port: 7001},
+			},
+			[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"},
+		},
+		{
+			[]*net.SRV{
+				{Target: "a.example.com", Port: 2480},
+				{Target: "b.example.com", Port: 2480},
+				{Target: "c.example.com", Port: 2480},
+			},
+			[]*net.SRV{},
+			[]string{"https://a.example.com:2480", "https://b.example.com:2480", "https://c.example.com:2480"},
+		},
+	}
+
+	for i, tt := range tests {
+		lookupSRV = func(service string, proto string, domain string) (string, []*net.SRV, error) {
+			if service == "etcd-client-ssl" {
+				return "", tt.withSSL, nil
+			}
+			if service == "etcd-client" {
+				return "", tt.withoutSSL, nil
+			}
+			return "", nil, errors.New("Unknown service in mock")
+		}
+
+		srvs, err := GetClient("etcd-client", "example.com")
+		if err != nil {
+			t.Fatalf("%d: err: %#v", i, err)
+		}
+
+		if !reflect.DeepEqual(srvs.Endpoints, tt.expected) {
+			t.Errorf("#%d: endpoints = %v, want %v", i, srvs.Endpoints, tt.expected)
+		}
+
+	}
+}

+ 71 - 24
proxy/tcpproxy/userspace.go

@@ -15,7 +15,9 @@
 package tcpproxy
 package tcpproxy
 
 
 import (
 import (
+	"fmt"
 	"io"
 	"io"
+	"math/rand"
 	"net"
 	"net"
 	"sync"
 	"sync"
 	"time"
 	"time"
@@ -29,6 +31,7 @@ var (
 
 
 type remote struct {
 type remote struct {
 	mu       sync.Mutex
 	mu       sync.Mutex
+	srv      *net.SRV
 	addr     string
 	addr     string
 	inactive bool
 	inactive bool
 }
 }
@@ -59,14 +62,14 @@ func (r *remote) isActive() bool {
 
 
 type TCPProxy struct {
 type TCPProxy struct {
 	Listener        net.Listener
 	Listener        net.Listener
-	Endpoints       []string
+	Endpoints       []*net.SRV
 	MonitorInterval time.Duration
 	MonitorInterval time.Duration
 
 
 	donec chan struct{}
 	donec chan struct{}
 
 
-	mu         sync.Mutex // guards the following fields
-	remotes    []*remote
-	nextRemote int
+	mu        sync.Mutex // guards the following fields
+	remotes   []*remote
+	pickCount int // for round robin
 }
 }
 
 
 func (tp *TCPProxy) Run() error {
 func (tp *TCPProxy) Run() error {
@@ -74,11 +77,12 @@ func (tp *TCPProxy) Run() error {
 	if tp.MonitorInterval == 0 {
 	if tp.MonitorInterval == 0 {
 		tp.MonitorInterval = 5 * time.Minute
 		tp.MonitorInterval = 5 * time.Minute
 	}
 	}
-	for _, ep := range tp.Endpoints {
-		tp.remotes = append(tp.remotes, &remote{addr: ep})
+	for _, srv := range tp.Endpoints {
+		addr := fmt.Sprintf("%s:%d", srv.Target, srv.Port)
+		tp.remotes = append(tp.remotes, &remote{srv: srv, addr: addr})
 	}
 	}
 
 
-	plog.Printf("ready to proxy client requests to %v", tp.Endpoints)
+	plog.Printf("ready to proxy client requests to %+v", tp.Endpoints)
 	go tp.runMonitor()
 	go tp.runMonitor()
 	for {
 	for {
 		in, err := tp.Listener.Accept()
 		in, err := tp.Listener.Accept()
@@ -90,10 +94,61 @@ func (tp *TCPProxy) Run() error {
 	}
 	}
 }
 }
 
 
-func (tp *TCPProxy) numRemotes() int {
-	tp.mu.Lock()
-	defer tp.mu.Unlock()
-	return len(tp.remotes)
+func (tp *TCPProxy) pick() *remote {
+	var weighted []*remote
+	var unweighted []*remote
+
+	bestPr := uint16(65535)
+	w := 0
+	// find best priority class
+	for _, r := range tp.remotes {
+		switch {
+		case !r.isActive():
+		case r.srv.Priority < bestPr:
+			bestPr = r.srv.Priority
+			w = 0
+			weighted, unweighted = nil, nil
+			unweighted = []*remote{r}
+			fallthrough
+		case r.srv.Priority == bestPr:
+			if r.srv.Weight > 0 {
+				weighted = append(weighted, r)
+				w += int(r.srv.Weight)
+			} else {
+				unweighted = append(unweighted, r)
+			}
+		}
+	}
+	if weighted != nil {
+		if len(unweighted) > 0 && rand.Intn(100) == 1 {
+			// In the presence of records containing weights greater
+			// than 0, records with weight 0 should have a very small
+			// chance of being selected.
+			r := unweighted[tp.pickCount%len(unweighted)]
+			tp.pickCount++
+			return r
+		}
+		// choose a uniform random number between 0 and the sum computed
+		// (inclusive), and select the RR whose running sum value is the
+		// first in the selected order
+		choose := rand.Intn(w)
+		for i := 0; i < len(weighted); i++ {
+			choose -= int(weighted[i].srv.Weight)
+			if choose <= 0 {
+				return weighted[i]
+			}
+		}
+	}
+	if unweighted != nil {
+		for i := 0; i < len(tp.remotes); i++ {
+			picked := tp.remotes[tp.pickCount%len(tp.remotes)]
+			tp.pickCount++
+			if picked.isActive() {
+				return picked
+			}
+		}
+	}
+	return nil
 }
 }
 
 
 func (tp *TCPProxy) serve(in net.Conn) {
 func (tp *TCPProxy) serve(in net.Conn) {
@@ -102,10 +157,12 @@ func (tp *TCPProxy) serve(in net.Conn) {
 		out net.Conn
 		out net.Conn
 	)
 	)
 
 
-	for i := 0; i < tp.numRemotes(); i++ {
+	for {
+		tp.mu.Lock()
 		remote := tp.pick()
 		remote := tp.pick()
-		if !remote.isActive() {
-			continue
+		tp.mu.Unlock()
+		if remote == nil {
+			break
 		}
 		}
 		// TODO: add timeout
 		// TODO: add timeout
 		out, err = net.Dial("tcp", remote.addr)
 		out, err = net.Dial("tcp", remote.addr)
@@ -132,16 +189,6 @@ func (tp *TCPProxy) serve(in net.Conn) {
 	in.Close()
 	in.Close()
 }
 }
 
 
-// pick picks a remote in round-robin fashion
-func (tp *TCPProxy) pick() *remote {
-	tp.mu.Lock()
-	defer tp.mu.Unlock()
-
-	picked := tp.remotes[tp.nextRemote]
-	tp.nextRemote = (tp.nextRemote + 1) % len(tp.remotes)
-	return picked
-}
-
 func (tp *TCPProxy) runMonitor() {
 func (tp *TCPProxy) runMonitor() {
 	for {
 	for {
 		select {
 		select {

+ 3 - 1
proxy/tcpproxy/userspace_test.go

@@ -42,9 +42,11 @@ func TestUserspaceProxy(t *testing.T) {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
+	var port uint16
+	fmt.Sscanf(u.Port(), "%d", &port)
 	p := TCPProxy{
 	p := TCPProxy{
 		Listener:  l,
 		Listener:  l,
-		Endpoints: []string{u.Host},
+		Endpoints: []*net.SRV{{Target: u.Hostname(), Port: port}},
 	}
 	}
 	go p.Run()
 	go p.Run()
 	defer p.Stop()
 	defer p.Stop()