Переглянути джерело

discovery: add a test case for srv

During srv discovery, it should try to match local member with
resolved addr and return unresolved hostnames for the cluster.
Xiang Li 11 роки тому
батько
коміт
f5d4c86153
2 змінених файлів з 36 додано та 5 видалено
  1. 4 3
      discovery/srv.go
  2. 32 2
      discovery/srv_test.go

+ 4 - 3
discovery/srv.go

@@ -25,7 +25,8 @@ import (
 
 var (
 	// indirection for testing
-	lookupSRV = net.LookupSRV
+	lookupSRV      = net.LookupSRV
+	resolveTCPAddr = net.ResolveTCPAddr
 )
 
 // TODO(barakmich): Currently ignores priority and weight (as they don't make as much sense for a bootstrap)
@@ -38,7 +39,7 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
 
 	// First, resolve the apurls
 	for _, url := range apurls {
-		tcpAddr, err := net.ResolveTCPAddr("tcp", url.Host)
+		tcpAddr, err := resolveTCPAddr("tcp", url.Host)
 		if err != nil {
 			log.Printf("discovery: Couldn't resolve host %s during SRV discovery", url.Host)
 			return "", "", err
@@ -53,7 +54,7 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
 		}
 		for _, srv := range addrs {
 			host := net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port))
-			tcpAddr, err := net.ResolveTCPAddr("tcp", host)
+			tcpAddr, err := resolveTCPAddr("tcp", host)
 			if err != nil {
 				log.Printf("discovery: Couldn't resolve host %s during SRV discovery", host)
 				continue

+ 32 - 2
discovery/srv_test.go

@@ -23,19 +23,26 @@ import (
 )
 
 func TestSRVGetCluster(t *testing.T) {
-	defer func() { lookupSRV = net.LookupSRV }()
+	defer func() {
+		lookupSRV = net.LookupSRV
+		resolveTCPAddr = net.ResolveTCPAddr
+	}()
 
 	name := "dnsClusterTest"
 	tests := []struct {
 		withSSL    []*net.SRV
 		withoutSSL []*net.SRV
 		urls       []string
-		expected   string
+		dns        map[string]string
+
+		expected string
 	}{
 		{
 			[]*net.SRV{},
 			[]*net.SRV{},
 			nil,
+			nil,
+
 			"",
 		},
 		{
@@ -46,6 +53,8 @@ func TestSRVGetCluster(t *testing.T) {
 			},
 			[]*net.SRV{},
 			nil,
+			nil,
+
 			"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480",
 		},
 		{
@@ -58,6 +67,7 @@ func TestSRVGetCluster(t *testing.T) {
 				&net.SRV{Target: "10.0.0.1", Port: 2380},
 			},
 			nil,
+			nil,
 			"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480,3=http://10.0.0.1:2380",
 		},
 		{
@@ -70,8 +80,22 @@ func TestSRVGetCluster(t *testing.T) {
 				&net.SRV{Target: "10.0.0.1", Port: 2380},
 			},
 			[]string{"https://10.0.0.1:2480"},
+			nil,
 			"dnsClusterTest=https://10.0.0.1:2480,0=https://10.0.0.2:2480,1=https://10.0.0.3:2480,2=http://10.0.0.1:2380",
 		},
+		// matching local member with resolved addr and return unresolved hostnames
+		{
+			[]*net.SRV{
+				&net.SRV{Target: "1.example.com.", Port: 2480},
+				&net.SRV{Target: "2.example.com.", Port: 2480},
+				&net.SRV{Target: "3.example.com.", Port: 2480},
+			},
+			nil,
+			[]string{"https://10.0.0.1:2480"},
+			map[string]string{"1.example.com:2480": "10.0.0.1:2480", "2.example.com:2480": "10.0.0.2:2480", "3.example.com:2480": "10.0.0.3:2480"},
+
+			"dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480",
+		},
 	}
 
 	for i, tt := range tests {
@@ -84,6 +108,12 @@ func TestSRVGetCluster(t *testing.T) {
 			}
 			return "", nil, errors.New("Unkown service in mock")
 		}
+		resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
+			if tt.dns == nil || tt.dns[addr] == "" {
+				return net.ResolveTCPAddr(network, addr)
+			}
+			return net.ResolveTCPAddr(network, tt.dns[addr])
+		}
 		urls := testutil.MustNewURLs(t, tt.urls)
 		str, token, err := SRVGetCluster(name, "example.com", "token", urls)
 		if err != nil {