Browse Source

discovery: add ability to proxy discovery requests

Jonathan Boulle 11 năm trước cách đây
mục cha
commit
7f8f371b0e
2 tập tin đã thay đổi với 94 bổ sung1 xóa
  1. 37 1
      discovery/discovery.go
  2. 57 0
      discovery/discovery_test.go

+ 37 - 1
discovery/discovery.go

@@ -6,6 +6,7 @@ import (
 	"log"
 	"net/http"
 	"net/url"
+	"os"
 	"path"
 	"sort"
 	"strconv"
@@ -26,6 +27,8 @@ var (
 )
 
 const (
+	// Environment variable used to configure an HTTP proxy for discovery
+	DiscoveryProxyEnv = "ETCD_DISCOVERY_PROXY"
 	// Number of retries discovery will attempt before giving up and erroring out.
 	nRetries = uint(3)
 )
@@ -46,6 +49,35 @@ type discovery struct {
 	timeoutTimescale time.Duration
 }
 
+// proxyFuncFromEnv builds a proxy function if the appropriate environment
+// variable is set. It performs basic sanitization of the environment variable
+// and returns any error encountered.
+func proxyFuncFromEnv() (func(*http.Request) (*url.URL, error), error) {
+	proxy := os.Getenv(DiscoveryProxyEnv)
+	if proxy == "" {
+		return nil, nil
+	}
+	// Do a small amount of URL sanitization to help the user
+	// Derived from net/http.ProxyFromEnvironment
+	proxyURL, err := url.Parse(proxy)
+	if err != nil || !strings.HasPrefix(proxyURL.Scheme, "http") {
+		// proxy was bogus. Try prepending "http://" to it and
+		// see if that parses correctly. If not, we ignore the
+		// error and complain about the original one
+		var err2 error
+		proxyURL, err2 = url.Parse("http://" + proxy)
+		if err2 == nil {
+			err = nil
+		}
+	}
+	if err != nil {
+		return nil, fmt.Errorf("invalid proxy address %q: %v", proxy, err)
+	}
+
+	log.Printf("discovery: using proxy %q", proxyURL.String())
+	return http.ProxyURL(proxyURL), nil
+}
+
 func New(durl string, id uint64, config string) (Discoverer, error) {
 	u, err := url.Parse(durl)
 	if err != nil {
@@ -53,7 +85,11 @@ func New(durl string, id uint64, config string) (Discoverer, error) {
 	}
 	token := u.Path
 	u.Path = ""
-	c, err := client.NewHTTPClient(&http.Transport{}, u.String(), time.Second*5)
+	pf, err := proxyFuncFromEnv()
+	if err != nil {
+		return nil, err
+	}
+	c, err := client.NewHTTPClient(&http.Transport{Proxy: pf}, u.String(), time.Second*5)
 	if err != nil {
 		return nil, err
 	}

+ 57 - 0
discovery/discovery_test.go

@@ -3,6 +3,8 @@ package discovery
 import (
 	"errors"
 	"math/rand"
+	"net/http"
+	"os"
 	"sort"
 	"strconv"
 
@@ -13,6 +15,61 @@ import (
 	"github.com/coreos/etcd/client"
 )
 
+func TestProxyFuncFromEnvUnset(t *testing.T) {
+	os.Setenv(DiscoveryProxyEnv, "")
+	pf, err := proxyFuncFromEnv()
+	if pf != nil {
+		t.Fatal("unexpected non-nil proxyFunc")
+	}
+	if err != nil {
+		t.Fatalf("unexpected non-nil err: %v", err)
+	}
+}
+
+func TestProxyFuncFromEnvBad(t *testing.T) {
+	tests := []string{
+		"%%",
+		"http://foo.com/%1",
+	}
+	for i, in := range tests {
+		os.Setenv(DiscoveryProxyEnv, in)
+		pf, err := proxyFuncFromEnv()
+		if pf != nil {
+			t.Errorf("#%d: unexpected non-nil proxyFunc", i)
+		}
+		if err == nil {
+			t.Errorf("#%d: unexpected nil err", i)
+		}
+	}
+}
+
+func TestProxyFuncFromEnv(t *testing.T) {
+	tests := map[string]string{
+		"bar.com":              "http://bar.com",
+		"http://disco.foo.bar": "http://disco.foo.bar",
+	}
+	for in, w := range tests {
+		os.Setenv(DiscoveryProxyEnv, in)
+		pf, err := proxyFuncFromEnv()
+		if pf == nil {
+			t.Errorf("%s: unexpected nil proxyFunc", in)
+			continue
+		}
+		if err != nil {
+			t.Errorf("%s: unexpected non-nil err: %v", in, err)
+			continue
+		}
+		g, err := pf(&http.Request{})
+		if err != nil {
+			t.Errorf("%s: unexpected non-nil err: %v", in, err)
+		}
+		if g.String() != w {
+			t.Errorf("%s: proxyURL=%q, want %q", g, w)
+		}
+
+	}
+}
+
 func TestCheckCluster(t *testing.T) {
 	cluster := "1000"
 	self := "/1000/1"