Forráskód Böngészése

netutil: use "context" and ctx-ize TCP addr resolution

Anthony Romano 8 éve
szülő
commit
85e87e8f6b
2 módosított fájl, 35 hozzáadás és 10 törlés
  1. 30 4
      pkg/netutil/netutil.go
  2. 5 6
      pkg/netutil/netutil_test.go

+ 30 - 4
pkg/netutil/netutil.go

@@ -16,14 +16,13 @@
 package netutil
 
 import (
+	"context"
 	"net"
 	"net/url"
 	"reflect"
 	"sort"
 	"time"
 
-	"golang.org/x/net/context"
-
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/pkg/capnslog"
 )
@@ -32,11 +31,38 @@ var (
 	plog = capnslog.NewPackageLogger("github.com/coreos/etcd", "pkg/netutil")
 
 	// indirection for testing
-	resolveTCPAddr = net.ResolveTCPAddr
+	resolveTCPAddr = resolveTCPAddrDefault
 )
 
 const retryInterval = time.Second
 
+// taken from go's ResolveTCP code but uses configurable ctx
+func resolveTCPAddrDefault(ctx context.Context, addr string) (*net.TCPAddr, error) {
+	host, port, serr := net.SplitHostPort(addr)
+	if serr != nil {
+		return nil, serr
+	}
+	portnum, perr := net.DefaultResolver.LookupPort(ctx, "tcp", port)
+	if perr != nil {
+		return nil, perr
+	}
+
+	var ips []net.IPAddr
+	if ip := net.ParseIP(host); ip != nil {
+		ips = []net.IPAddr{{IP: ip}}
+	} else {
+		// Try as a DNS name.
+		ipss, err := net.DefaultResolver.LookupIPAddr(ctx, host)
+		if err != nil {
+			return nil, err
+		}
+		ips = ipss
+	}
+	// randomize?
+	ip := ips[0]
+	return &net.TCPAddr{IP: ip.IP, Port: portnum, Zone: ip.Zone}, nil
+}
+
 // resolveTCPAddrs is a convenience wrapper for net.ResolveTCPAddr.
 // resolveTCPAddrs return a new set of url.URLs, in which all DNS hostnames
 // are resolved.
@@ -75,7 +101,7 @@ func resolveURL(ctx context.Context, u url.URL) (string, error) {
 		if host == "localhost" || net.ParseIP(host) != nil {
 			return "", nil
 		}
-		tcpAddr, err := resolveTCPAddr("tcp", u.Host)
+		tcpAddr, err := resolveTCPAddr(ctx, u.Host)
 		if err == nil {
 			plog.Infof("resolving %s to %s", u.Host, tcpAddr.String())
 			return tcpAddr.String(), nil

+ 5 - 6
pkg/netutil/netutil_test.go

@@ -15,6 +15,7 @@
 package netutil
 
 import (
+	"context"
 	"errors"
 	"net"
 	"net/url"
@@ -22,12 +23,10 @@ import (
 	"strconv"
 	"testing"
 	"time"
-
-	"golang.org/x/net/context"
 )
 
 func TestResolveTCPAddrs(t *testing.T) {
-	defer func() { resolveTCPAddr = net.ResolveTCPAddr }()
+	defer func() { resolveTCPAddr = resolveTCPAddrDefault }()
 	tests := []struct {
 		urls     [][]url.URL
 		expected [][]url.URL
@@ -113,7 +112,7 @@ func TestResolveTCPAddrs(t *testing.T) {
 		},
 	}
 	for _, tt := range tests {
-		resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
+		resolveTCPAddr = func(ctx context.Context, addr string) (*net.TCPAddr, error) {
 			host, port, err := net.SplitHostPort(addr)
 			if err != nil {
 				return nil, err
@@ -143,13 +142,13 @@ func TestResolveTCPAddrs(t *testing.T) {
 }
 
 func TestURLsEqual(t *testing.T) {
-	defer func() { resolveTCPAddr = net.ResolveTCPAddr }()
+	defer func() { resolveTCPAddr = resolveTCPAddrDefault }()
 	hostm := map[string]string{
 		"example.com": "10.0.10.1",
 		"first.com":   "10.0.11.1",
 		"second.com":  "10.0.11.2",
 	}
-	resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
+	resolveTCPAddr = func(ctx context.Context, addr string) (*net.TCPAddr, error) {
 		host, port, herr := net.SplitHostPort(addr)
 		if herr != nil {
 			return nil, herr