Quellcode durchsuchen

discovery: add a test case for srv

During srv discovery, it should try to match local member with
resolved addr and return unresolved hostnames for the cluster.
Xiang Li vor 11 Jahren
Ursprung
Commit
f5d4c86153
2 geänderte Dateien mit 36 neuen und 5 gelöschten Zeilen
  1. 4 3
      discovery/srv.go
  2. 32 2
      discovery/srv_test.go

+ 4 - 3
discovery/srv.go

@@ -25,7 +25,8 @@ import (
 
 var (
 	// indirection for testing
-	lookupSRV = net.LookupSRV
+	lookupSRV      = net.LookupSRV
+	resolveTCPAddr = net.ResolveTCPAddr
 )
 
 // TODO(barakmich): Currently ignores priority and weight (as they don't make as much sense for a bootstrap)
@@ -38,7 +39,7 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
 
 	// First, resolve the apurls
 	for _, url := range apurls {
-		tcpAddr, err := net.ResolveTCPAddr("tcp", url.Host)
+		tcpAddr, err := resolveTCPAddr("tcp", url.Host)
 		if err != nil {
 			log.Printf("discovery: Couldn't resolve host %s during SRV discovery", url.Host)
 			return "", "", err
@@ -53,7 +54,7 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
 		}
 		for _, srv := range addrs {
 			host := net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port))
-			tcpAddr, err := net.ResolveTCPAddr("tcp", host)
+			tcpAddr, err := resolveTCPAddr("tcp", host)
 			if err != nil {
 				log.Printf("discovery: Couldn't resolve host %s during SRV discovery", host)
 				continue

+ 32 - 2
discovery/srv_test.go

@@ -23,19 +23,26 @@ import (
 )
 
 func TestSRVGetCluster(t *testing.T) {
-	defer func() { lookupSRV = net.LookupSRV }()
+	defer func() {
+		lookupSRV = net.LookupSRV
+		resolveTCPAddr = net.ResolveTCPAddr
+	}()
 
 	name := "dnsClusterTest"
 	tests := []struct {
 		withSSL    []*net.SRV
 		withoutSSL []*net.SRV
 		urls       []string
-		expected   string
+		dns        map[string]string
+
+		expected string
 	}{
 		{
 			[]*net.SRV{},
 			[]*net.SRV{},
 			nil,
+			nil,
+
 			"",
 		},
 		{
@@ -46,6 +53,8 @@ func TestSRVGetCluster(t *testing.T) {
 			},
 			[]*net.SRV{},
 			nil,
+			nil,
+
 			"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480",
 		},
 		{
@@ -58,6 +67,7 @@ func TestSRVGetCluster(t *testing.T) {
 				&net.SRV{Target: "10.0.0.1", Port: 2380},
 			},
 			nil,
+			nil,
 			"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480,3=http://10.0.0.1:2380",
 		},
 		{
@@ -70,8 +80,22 @@ func TestSRVGetCluster(t *testing.T) {
 				&net.SRV{Target: "10.0.0.1", Port: 2380},
 			},
 			[]string{"https://10.0.0.1:2480"},
+			nil,
 			"dnsClusterTest=https://10.0.0.1:2480,0=https://10.0.0.2:2480,1=https://10.0.0.3:2480,2=http://10.0.0.1:2380",
 		},
+		// matching local member with resolved addr and return unresolved hostnames
+		{
+			[]*net.SRV{
+				&net.SRV{Target: "1.example.com.", Port: 2480},
+				&net.SRV{Target: "2.example.com.", Port: 2480},
+				&net.SRV{Target: "3.example.com.", Port: 2480},
+			},
+			nil,
+			[]string{"https://10.0.0.1:2480"},
+			map[string]string{"1.example.com:2480": "10.0.0.1:2480", "2.example.com:2480": "10.0.0.2:2480", "3.example.com:2480": "10.0.0.3:2480"},
+
+			"dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480",
+		},
 	}
 
 	for i, tt := range tests {
@@ -84,6 +108,12 @@ func TestSRVGetCluster(t *testing.T) {
 			}
 			return "", nil, errors.New("Unkown service in mock")
 		}
+		resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
+			if tt.dns == nil || tt.dns[addr] == "" {
+				return net.ResolveTCPAddr(network, addr)
+			}
+			return net.ResolveTCPAddr(network, tt.dns[addr])
+		}
 		urls := testutil.MustNewURLs(t, tt.urls)
 		str, token, err := SRVGetCluster(name, "example.com", "token", urls)
 		if err != nil {