Browse Source

proxy: rewrite stdlib ReverseProxy

The ReverseProxy code from the standard library doesn't actually
give us the control that we want. Pull it down and rip out what
we don't need, adding tests in the process.

All available endpoints are attempted when proxying a request. If a
proxied request fails, the upstream will be considered unavailable
for 5s and no more requests will be proxied to it. After the 5s is
up, the endpoint will be put back to rotation.
Brian Waldon 11 years ago
parent
commit
a155f0bda6
6 changed files with 531 additions and 119 deletions
  1. 106 0
      proxy/director.go
  2. 61 0
      proxy/director_test.go
  3. 17 48
      proxy/proxy.go
  4. 0 71
      proxy/proxy_test.go
  5. 120 0
      proxy/reverse.go
  6. 227 0
      proxy/reverse_test.go

+ 106 - 0
proxy/director.go

@@ -0,0 +1,106 @@
+package proxy
+
+import (
+	"errors"
+	"fmt"
+	"log"
+	"net/url"
+	"sync"
+	"time"
+)
+
+const (
+	// amount of time an endpoint will be held in a failed
+	// state before being reconsidered for proxied requests
+	endpointFailureWait = 5 * time.Second
+)
+
+func newDirector(urls []string) (*director, error) {
+	if len(urls) == 0 {
+		return nil, errors.New("one or more endpoints required")
+	}
+
+	endpoints := make([]*endpoint, len(urls))
+	for i, v := range urls {
+		u, err := url.Parse(v)
+		if err != nil {
+			return nil, fmt.Errorf("invalid endpoint %q: %v", v, err)
+		}
+
+		if u.Scheme == "" {
+			return nil, fmt.Errorf("invalid endpoint %q: scheme required", v)
+		}
+
+		if u.Host == "" {
+			return nil, fmt.Errorf("invalid endpoint %q: host empty", v)
+		}
+
+		endpoints[i] = newEndpoint(*u)
+	}
+
+	d := director{ep: endpoints}
+	return &d, nil
+}
+
+type director struct {
+	ep []*endpoint
+}
+
+func (d *director) endpoints() []*endpoint {
+	filtered := make([]*endpoint, 0)
+	for _, ep := range d.ep {
+		if ep.Available {
+			filtered = append(filtered, ep)
+		}
+	}
+
+	return filtered
+}
+
+func newEndpoint(u url.URL) *endpoint {
+	ep := endpoint{
+		URL:       u,
+		Available: true,
+		failFunc:  timedUnavailabilityFunc(endpointFailureWait),
+	}
+
+	return &ep
+}
+
+type endpoint struct {
+	sync.Mutex
+
+	URL       url.URL
+	Available bool
+
+	failFunc func(ep *endpoint)
+}
+
+func (ep *endpoint) Failed() {
+	ep.Lock()
+	if !ep.Available {
+		ep.Unlock()
+		return
+	}
+
+	ep.Available = false
+	ep.Unlock()
+
+	log.Printf("proxy: marked endpoint %s unavailable", ep.URL.String())
+
+	if ep.failFunc == nil {
+		log.Printf("proxy: no failFunc defined, endpoint %s will be unavailable forever.", ep.URL.String())
+		return
+	}
+
+	ep.failFunc(ep)
+}
+
+func timedUnavailabilityFunc(wait time.Duration) func(*endpoint) {
+	return func(ep *endpoint) {
+		time.AfterFunc(wait, func() {
+			ep.Available = true
+			log.Printf("proxy: marked endpoint %s available", ep.URL.String())
+		})
+	}
+}

+ 61 - 0
proxy/director_test.go

@@ -0,0 +1,61 @@
+package proxy
+
+import (
+	"net/url"
+	"reflect"
+	"testing"
+)
+
+func TestNewDirectorEndpointValidation(t *testing.T) {
+	tests := []struct {
+		good      bool
+		endpoints []string
+	}{
+		{true, []string{"http://192.0.2.8"}},
+		{true, []string{"http://192.0.2.8:8001"}},
+		{true, []string{"http://example.com"}},
+		{true, []string{"http://example.com:8001"}},
+		{true, []string{"http://192.0.2.8:8001", "http://example.com:8002"}},
+
+		{false, []string{"://"}},
+		{false, []string{"http://"}},
+		{false, []string{"192.0.2.8"}},
+		{false, []string{"192.0.2.8:8001"}},
+		{false, []string{""}},
+		{false, []string{}},
+	}
+
+	for i, tt := range tests {
+		_, err := newDirector(tt.endpoints)
+		if tt.good != (err == nil) {
+			t.Errorf("#%d: expected success = %t, got err = %v", i, tt.good, err)
+		}
+	}
+}
+
+func TestDirectorEndpointsFiltering(t *testing.T) {
+	d := director{
+		ep: []*endpoint{
+			&endpoint{
+				URL:       url.URL{Scheme: "http", Host: "192.0.2.5:5050"},
+				Available: false,
+			},
+			&endpoint{
+				URL:       url.URL{Scheme: "http", Host: "192.0.2.4:4000"},
+				Available: true,
+			},
+		},
+	}
+
+	got := d.endpoints()
+	want := []*endpoint{
+		&endpoint{
+			URL:       url.URL{Scheme: "http", Host: "192.0.2.4:4000"},
+			Available: true,
+		},
+	}
+
+	if !reflect.DeepEqual(want, got) {
+		t.Fatalf("directed to incorrect endpoint: want = %#v, got = %#v", want, got)
+	}
+}

+ 17 - 48
proxy/proxy.go

@@ -1,64 +1,33 @@
 package proxy
 package proxy
 
 
 import (
 import (
-	"errors"
-	"fmt"
+	"net"
 	"net/http"
 	"net/http"
-	"net/http/httputil"
-	"net/url"
+	"time"
 )
 )
 
 
-func NewHandler(endpoints []string) (*httputil.ReverseProxy, error) {
+const (
+	dialTimeout           = 30 * time.Second
+	responseHeaderTimeout = 30 * time.Second
+)
+
+func NewHandler(endpoints []string) (http.Handler, error) {
 	d, err := newDirector(endpoints)
 	d, err := newDirector(endpoints)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	proxy := httputil.ReverseProxy{
-		Director:      d.direct,
-		Transport:     &http.Transport{},
-		FlushInterval: 0,
+	tr := http.Transport{
+		Dial: func(network, address string) (net.Conn, error) {
+			return net.DialTimeout(network, address, dialTimeout)
+		},
+		ResponseHeaderTimeout: responseHeaderTimeout,
 	}
 	}
 
 
-	return &proxy, nil
-}
-
-func newDirector(endpoints []string) (*director, error) {
-	if len(endpoints) == 0 {
-		return nil, errors.New("one or more endpoints required")
+	rp := reverseProxy{
+		director:  d,
+		transport: &tr,
 	}
 	}
 
 
-	urls := make([]url.URL, len(endpoints))
-	for i, e := range endpoints {
-		u, err := url.Parse(e)
-		if err != nil {
-			return nil, fmt.Errorf("invalid endpoint %q: %v", e, err)
-		}
-
-		if u.Scheme == "" {
-			return nil, fmt.Errorf("invalid endpoint %q: scheme required", e)
-		}
-
-		if u.Host == "" {
-			return nil, fmt.Errorf("invalid endpoint %q: host empty", e)
-		}
-
-		urls[i] = *u
-	}
-
-	d := director{
-		endpoints: urls,
-	}
-
-	return &d, nil
-}
-
-type director struct {
-	endpoints []url.URL
-}
-
-func (d *director) direct(req *http.Request) {
-	choice := d.endpoints[0]
-	req.URL.Scheme = choice.Scheme
-	req.URL.Host = choice.Host
+	return &rp, nil
 }
 }

+ 0 - 71
proxy/proxy_test.go

@@ -1,71 +0,0 @@
-package proxy
-
-import (
-	"net/http"
-	"net/url"
-	"reflect"
-	"testing"
-)
-
-func TestNewDirector(t *testing.T) {
-	tests := []struct {
-		good      bool
-		endpoints []string
-	}{
-		{true, []string{"http://192.0.2.8"}},
-		{true, []string{"http://192.0.2.8:8001"}},
-		{true, []string{"http://example.com"}},
-		{true, []string{"http://example.com:8001"}},
-		{true, []string{"http://192.0.2.8:8001", "http://example.com:8002"}},
-
-		{false, []string{"192.0.2.8"}},
-		{false, []string{"192.0.2.8:8001"}},
-		{false, []string{""}},
-	}
-
-	for i, tt := range tests {
-		_, err := newDirector(tt.endpoints)
-		if tt.good != (err == nil) {
-			t.Errorf("#%d: expected success = %t, got err = %v", i, tt.good, err)
-		}
-	}
-}
-
-func TestDirectorDirect(t *testing.T) {
-	d := &director{
-		endpoints: []url.URL{
-			url.URL{
-				Scheme: "http",
-				Host:   "bar.example.com",
-			},
-		},
-	}
-
-	req := &http.Request{
-		Method: "GET",
-		Host:   "foo.example.com",
-		URL: &url.URL{
-			Host: "foo.example.com",
-			Path: "/v2/keys/baz",
-		},
-	}
-
-	d.direct(req)
-
-	want := &http.Request{
-		Method: "GET",
-		// this field must not change
-		Host: "foo.example.com",
-		URL: &url.URL{
-			// the Scheme field is updated per the director's first endpoint
-			Scheme: "http",
-			// the Host field is updated per the director's first endpoint
-			Host: "bar.example.com",
-			Path: "/v2/keys/baz",
-		},
-	}
-
-	if !reflect.DeepEqual(want, req) {
-		t.Fatalf("HTTP request does not match expected criteria: want=%#v got=%#v", want, req)
-	}
-}

+ 120 - 0
proxy/reverse.go

@@ -0,0 +1,120 @@
+package proxy
+
+import (
+	"io"
+	"log"
+	"net"
+	"net/http"
+	"net/url"
+	"strings"
+)
+
+// Hop-by-hop headers. These are removed when sent to the backend.
+// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
+// This list of headers borrowed from stdlib httputil.ReverseProxy
+var singleHopHeaders = []string{
+	"Connection",
+	"Keep-Alive",
+	"Proxy-Authenticate",
+	"Proxy-Authorization",
+	"Te", // canonicalized version of "TE"
+	"Trailers",
+	"Transfer-Encoding",
+	"Upgrade",
+}
+
+func removeSingleHopHeaders(hdrs *http.Header) {
+	for _, h := range singleHopHeaders {
+		hdrs.Del(h)
+	}
+}
+
+type reverseProxy struct {
+	director  *director
+	transport http.RoundTripper
+}
+
+func (p *reverseProxy) ServeHTTP(rw http.ResponseWriter, clientreq *http.Request) {
+	proxyreq := new(http.Request)
+	*proxyreq = *clientreq
+
+	// deep-copy the headers, as these will be modified below
+	proxyreq.Header = make(http.Header)
+	copyHeader(proxyreq.Header, clientreq.Header)
+
+	normalizeRequest(proxyreq)
+	removeSingleHopHeaders(&proxyreq.Header)
+	maybeSetForwardedFor(proxyreq)
+
+	endpoints := p.director.endpoints()
+	if len(endpoints) == 0 {
+		log.Printf("proxy: zero endpoints currently available")
+		rw.WriteHeader(http.StatusServiceUnavailable)
+		return
+	}
+
+	var res *http.Response
+	var err error
+
+	for _, ep := range endpoints {
+		redirectRequest(proxyreq, ep.URL)
+
+		res, err = p.transport.RoundTrip(proxyreq)
+		if err != nil {
+			log.Printf("proxy: failed to direct request to %s: %v", ep.URL.String(), err)
+			ep.Failed()
+			continue
+		}
+
+		break
+	}
+
+	if res == nil {
+		log.Printf("proxy: unable to get response from %d endpoint(s)", len(endpoints))
+		rw.WriteHeader(http.StatusBadGateway)
+		return
+	}
+
+	defer res.Body.Close()
+
+	removeSingleHopHeaders(&res.Header)
+	copyHeader(rw.Header(), res.Header)
+
+	rw.WriteHeader(res.StatusCode)
+	io.Copy(rw, res.Body)
+}
+
+func copyHeader(dst, src http.Header) {
+	for k, vv := range src {
+		for _, v := range vv {
+			dst.Add(k, v)
+		}
+	}
+}
+
+func redirectRequest(req *http.Request, loc url.URL) {
+	req.URL.Scheme = loc.Scheme
+	req.URL.Host = loc.Host
+}
+
+func normalizeRequest(req *http.Request) {
+	req.Proto = "HTTP/1.1"
+	req.ProtoMajor = 1
+	req.ProtoMinor = 1
+	req.Close = false
+}
+
+func maybeSetForwardedFor(req *http.Request) {
+	clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
+	if err != nil {
+		return
+	}
+
+	// If we aren't the first proxy retain prior
+	// X-Forwarded-For information as a comma+space
+	// separated list and fold multiple headers into one.
+	if prior, ok := req.Header["X-Forwarded-For"]; ok {
+		clientIP = strings.Join(prior, ", ") + ", " + clientIP
+	}
+	req.Header.Set("X-Forwarded-For", clientIP)
+}

+ 227 - 0
proxy/reverse_test.go

@@ -0,0 +1,227 @@
+package proxy
+
+import (
+	"bytes"
+	"errors"
+	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"reflect"
+	"testing"
+)
+
+type staticRoundTripper struct {
+	res *http.Response
+	err error
+}
+
+func (srt *staticRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
+	return srt.res, srt.err
+}
+
+func TestReverseProxyServe(t *testing.T) {
+	u := url.URL{Scheme: "http", Host: "192.0.2.3:4040"}
+
+	tests := []struct {
+		eps  []*endpoint
+		rt   http.RoundTripper
+		want int
+	}{
+		// no endpoints available so no requests are even made
+		{
+			eps: []*endpoint{},
+			rt: &staticRoundTripper{
+				res: &http.Response{
+					StatusCode: http.StatusCreated,
+					Body:       ioutil.NopCloser(&bytes.Reader{}),
+				},
+			},
+			want: http.StatusServiceUnavailable,
+		},
+
+		// error is returned from one endpoint that should be available
+		{
+			eps:  []*endpoint{&endpoint{URL: u, Available: true}},
+			rt:   &staticRoundTripper{err: errors.New("what a bad trip")},
+			want: http.StatusBadGateway,
+		},
+
+		// endpoint is available and returns success
+		{
+			eps: []*endpoint{&endpoint{URL: u, Available: true}},
+			rt: &staticRoundTripper{
+				res: &http.Response{
+					StatusCode: http.StatusCreated,
+					Body:       ioutil.NopCloser(&bytes.Reader{}),
+				},
+			},
+			want: http.StatusCreated,
+		},
+	}
+
+	for i, tt := range tests {
+		rp := reverseProxy{
+			director:  &director{tt.eps},
+			transport: tt.rt,
+		}
+
+		req, _ := http.NewRequest("GET", "http://192.0.2.2:4001", nil)
+		rr := httptest.NewRecorder()
+		rp.ServeHTTP(rr, req)
+
+		if rr.Code != tt.want {
+			t.Errorf("#%d: unexpected HTTP status code: want = %d, got = %d", i, tt.want, rr.Code)
+		}
+	}
+}
+
+func TestRedirectRequest(t *testing.T) {
+	loc := url.URL{
+		Scheme: "http",
+		Host:   "bar.example.com",
+	}
+
+	req := &http.Request{
+		Method: "GET",
+		Host:   "foo.example.com",
+		URL: &url.URL{
+			Host: "foo.example.com",
+			Path: "/v2/keys/baz",
+		},
+	}
+
+	redirectRequest(req, loc)
+
+	want := &http.Request{
+		Method: "GET",
+		// this field must not change
+		Host: "foo.example.com",
+		URL: &url.URL{
+			// the Scheme field is updated to that of the provided URL
+			Scheme: "http",
+			// the Host field is updated to that of the provided URL
+			Host: "bar.example.com",
+			Path: "/v2/keys/baz",
+		},
+	}
+
+	if !reflect.DeepEqual(want, req) {
+		t.Fatalf("HTTP request does not match expected criteria: want=%#v got=%#v", want, req)
+	}
+}
+
+func TestMaybeSetForwardedFor(t *testing.T) {
+	tests := []struct {
+		raddr  string
+		fwdFor string
+		want   string
+	}{
+		{"192.0.2.3:8002", "", "192.0.2.3"},
+		{"192.0.2.3:8002", "192.0.2.2", "192.0.2.2, 192.0.2.3"},
+		{"192.0.2.3:8002", "192.0.2.1, 192.0.2.2", "192.0.2.1, 192.0.2.2, 192.0.2.3"},
+		{"example.com:8002", "", "example.com"},
+
+		// While these cases look valid, golang net/http will not let it happen
+		// The RemoteAddr field will always be a valid host:port
+		{":8002", "", ""},
+		{"192.0.2.3", "", ""},
+
+		// blatantly invalid host w/o a port
+		{"12", "", ""},
+		{"12", "192.0.2.3", "192.0.2.3"},
+	}
+
+	for i, tt := range tests {
+		req := &http.Request{
+			RemoteAddr: tt.raddr,
+			Header:     make(http.Header),
+		}
+
+		if tt.fwdFor != "" {
+			req.Header.Set("X-Forwarded-For", tt.fwdFor)
+		}
+
+		maybeSetForwardedFor(req)
+		got := req.Header.Get("X-Forwarded-For")
+		if tt.want != got {
+			t.Errorf("#%d: incorrect header: want = %q, got = %q", i, tt.want, got)
+		}
+	}
+}
+
+func TestRemoveSingleHopHeaders(t *testing.T) {
+	hdr := http.Header(map[string][]string{
+		// single-hop headers that should be removed
+		"Connection":          []string{"close"},
+		"Keep-Alive":          []string{"foo"},
+		"Proxy-Authenticate":  []string{"Basic realm=example.com"},
+		"Proxy-Authorization": []string{"foo"},
+		"Te":                []string{"deflate,gzip"},
+		"Trailers":          []string{"ETag"},
+		"Transfer-Encoding": []string{"chunked"},
+		"Upgrade":           []string{"WebSocket"},
+
+		// headers that should persist
+		"Accept": []string{"application/json"},
+		"X-Foo":  []string{"Bar"},
+	})
+
+	removeSingleHopHeaders(&hdr)
+
+	want := http.Header(map[string][]string{
+		"Accept": []string{"application/json"},
+		"X-Foo":  []string{"Bar"},
+	})
+
+	if !reflect.DeepEqual(want, hdr) {
+		t.Fatalf("unexpected result: want = %#v, got = %#v", want, hdr)
+	}
+}
+
+func TestCopyHeader(t *testing.T) {
+	tests := []struct {
+		src  http.Header
+		dst  http.Header
+		want http.Header
+	}{
+		{
+			src: http.Header(map[string][]string{
+				"Foo": []string{"bar", "baz"},
+			}),
+			dst: http.Header(map[string][]string{}),
+			want: http.Header(map[string][]string{
+				"Foo": []string{"bar", "baz"},
+			}),
+		},
+		{
+			src: http.Header(map[string][]string{
+				"Foo":  []string{"bar"},
+				"Ping": []string{"pong"},
+			}),
+			dst: http.Header(map[string][]string{}),
+			want: http.Header(map[string][]string{
+				"Foo":  []string{"bar"},
+				"Ping": []string{"pong"},
+			}),
+		},
+		{
+			src: http.Header(map[string][]string{
+				"Foo": []string{"bar", "baz"},
+			}),
+			dst: http.Header(map[string][]string{
+				"Foo": []string{"qux"},
+			}),
+			want: http.Header(map[string][]string{
+				"Foo": []string{"qux", "bar", "baz"},
+			}),
+		},
+	}
+
+	for i, tt := range tests {
+		copyHeader(tt.dst, tt.src)
+		if !reflect.DeepEqual(tt.dst, tt.want) {
+			t.Errorf("#%d: unexpected headers: want = %v, got = %v", i, tt.want, tt.dst)
+		}
+	}
+}