Browse Source

Merge pull request #1055 from bcwaldon/proxy-monster

proxy mode, pt II
Brian Waldon 11 years ago
parent
commit
1037e7ce55
6 changed files with 531 additions and 119 deletions
  1. 106 0
      proxy/director.go
  2. 61 0
      proxy/director_test.go
  3. 17 48
      proxy/proxy.go
  4. 0 71
      proxy/proxy_test.go
  5. 120 0
      proxy/reverse.go
  6. 227 0
      proxy/reverse_test.go

+ 106 - 0
proxy/director.go

@@ -0,0 +1,106 @@
+package proxy
+
+import (
+	"errors"
+	"fmt"
+	"log"
+	"net/url"
+	"sync"
+	"time"
+)
+
+const (
+	// amount of time an endpoint will be held in a failed
+	// state before being reconsidered for proxied requests
+	endpointFailureWait = 5 * time.Second
+)
+
+func newDirector(urls []string) (*director, error) {
+	if len(urls) == 0 {
+		return nil, errors.New("one or more endpoints required")
+	}
+
+	endpoints := make([]*endpoint, len(urls))
+	for i, v := range urls {
+		u, err := url.Parse(v)
+		if err != nil {
+			return nil, fmt.Errorf("invalid endpoint %q: %v", v, err)
+		}
+
+		if u.Scheme == "" {
+			return nil, fmt.Errorf("invalid endpoint %q: scheme required", v)
+		}
+
+		if u.Host == "" {
+			return nil, fmt.Errorf("invalid endpoint %q: host empty", v)
+		}
+
+		endpoints[i] = newEndpoint(*u)
+	}
+
+	d := director{ep: endpoints}
+	return &d, nil
+}
+
+type director struct {
+	ep []*endpoint
+}
+
+func (d *director) endpoints() []*endpoint {
+	filtered := make([]*endpoint, 0)
+	for _, ep := range d.ep {
+		if ep.Available {
+			filtered = append(filtered, ep)
+		}
+	}
+
+	return filtered
+}
+
+func newEndpoint(u url.URL) *endpoint {
+	ep := endpoint{
+		URL:       u,
+		Available: true,
+		failFunc:  timedUnavailabilityFunc(endpointFailureWait),
+	}
+
+	return &ep
+}
+
+type endpoint struct {
+	sync.Mutex
+
+	URL       url.URL
+	Available bool
+
+	failFunc func(ep *endpoint)
+}
+
+func (ep *endpoint) Failed() {
+	ep.Lock()
+	if !ep.Available {
+		ep.Unlock()
+		return
+	}
+
+	ep.Available = false
+	ep.Unlock()
+
+	log.Printf("proxy: marked endpoint %s unavailable", ep.URL.String())
+
+	if ep.failFunc == nil {
+		log.Printf("proxy: no failFunc defined, endpoint %s will be unavailable forever.", ep.URL.String())
+		return
+	}
+
+	ep.failFunc(ep)
+}
+
+func timedUnavailabilityFunc(wait time.Duration) func(*endpoint) {
+	return func(ep *endpoint) {
+		time.AfterFunc(wait, func() {
+			ep.Available = true
+			log.Printf("proxy: marked endpoint %s available", ep.URL.String())
+		})
+	}
+}

+ 61 - 0
proxy/director_test.go

@@ -0,0 +1,61 @@
+package proxy
+
+import (
+	"net/url"
+	"reflect"
+	"testing"
+)
+
+func TestNewDirectorEndpointValidation(t *testing.T) {
+	tests := []struct {
+		good      bool
+		endpoints []string
+	}{
+		{true, []string{"http://192.0.2.8"}},
+		{true, []string{"http://192.0.2.8:8001"}},
+		{true, []string{"http://example.com"}},
+		{true, []string{"http://example.com:8001"}},
+		{true, []string{"http://192.0.2.8:8001", "http://example.com:8002"}},
+
+		{false, []string{"://"}},
+		{false, []string{"http://"}},
+		{false, []string{"192.0.2.8"}},
+		{false, []string{"192.0.2.8:8001"}},
+		{false, []string{""}},
+		{false, []string{}},
+	}
+
+	for i, tt := range tests {
+		_, err := newDirector(tt.endpoints)
+		if tt.good != (err == nil) {
+			t.Errorf("#%d: expected success = %t, got err = %v", i, tt.good, err)
+		}
+	}
+}
+
+func TestDirectorEndpointsFiltering(t *testing.T) {
+	d := director{
+		ep: []*endpoint{
+			&endpoint{
+				URL:       url.URL{Scheme: "http", Host: "192.0.2.5:5050"},
+				Available: false,
+			},
+			&endpoint{
+				URL:       url.URL{Scheme: "http", Host: "192.0.2.4:4000"},
+				Available: true,
+			},
+		},
+	}
+
+	got := d.endpoints()
+	want := []*endpoint{
+		&endpoint{
+			URL:       url.URL{Scheme: "http", Host: "192.0.2.4:4000"},
+			Available: true,
+		},
+	}
+
+	if !reflect.DeepEqual(want, got) {
+		t.Fatalf("directed to incorrect endpoint: want = %#v, got = %#v", want, got)
+	}
+}

+ 17 - 48
proxy/proxy.go

@@ -1,64 +1,33 @@
 package proxy
 
 import (
-	"errors"
-	"fmt"
+	"net"
 	"net/http"
-	"net/http/httputil"
-	"net/url"
+	"time"
 )
 
-func NewHandler(endpoints []string) (*httputil.ReverseProxy, error) {
+const (
+	dialTimeout           = 30 * time.Second
+	responseHeaderTimeout = 30 * time.Second
+)
+
+func NewHandler(endpoints []string) (http.Handler, error) {
 	d, err := newDirector(endpoints)
 	if err != nil {
 		return nil, err
 	}
 
-	proxy := httputil.ReverseProxy{
-		Director:      d.direct,
-		Transport:     &http.Transport{},
-		FlushInterval: 0,
+	tr := http.Transport{
+		Dial: func(network, address string) (net.Conn, error) {
+			return net.DialTimeout(network, address, dialTimeout)
+		},
+		ResponseHeaderTimeout: responseHeaderTimeout,
 	}
 
-	return &proxy, nil
-}
-
-func newDirector(endpoints []string) (*director, error) {
-	if len(endpoints) == 0 {
-		return nil, errors.New("one or more endpoints required")
+	rp := reverseProxy{
+		director:  d,
+		transport: &tr,
 	}
 
-	urls := make([]url.URL, len(endpoints))
-	for i, e := range endpoints {
-		u, err := url.Parse(e)
-		if err != nil {
-			return nil, fmt.Errorf("invalid endpoint %q: %v", e, err)
-		}
-
-		if u.Scheme == "" {
-			return nil, fmt.Errorf("invalid endpoint %q: scheme required", e)
-		}
-
-		if u.Host == "" {
-			return nil, fmt.Errorf("invalid endpoint %q: host empty", e)
-		}
-
-		urls[i] = *u
-	}
-
-	d := director{
-		endpoints: urls,
-	}
-
-	return &d, nil
-}
-
-type director struct {
-	endpoints []url.URL
-}
-
-func (d *director) direct(req *http.Request) {
-	choice := d.endpoints[0]
-	req.URL.Scheme = choice.Scheme
-	req.URL.Host = choice.Host
+	return &rp, nil
 }

+ 0 - 71
proxy/proxy_test.go

@@ -1,71 +0,0 @@
-package proxy
-
-import (
-	"net/http"
-	"net/url"
-	"reflect"
-	"testing"
-)
-
-func TestNewDirector(t *testing.T) {
-	tests := []struct {
-		good      bool
-		endpoints []string
-	}{
-		{true, []string{"http://192.0.2.8"}},
-		{true, []string{"http://192.0.2.8:8001"}},
-		{true, []string{"http://example.com"}},
-		{true, []string{"http://example.com:8001"}},
-		{true, []string{"http://192.0.2.8:8001", "http://example.com:8002"}},
-
-		{false, []string{"192.0.2.8"}},
-		{false, []string{"192.0.2.8:8001"}},
-		{false, []string{""}},
-	}
-
-	for i, tt := range tests {
-		_, err := newDirector(tt.endpoints)
-		if tt.good != (err == nil) {
-			t.Errorf("#%d: expected success = %t, got err = %v", i, tt.good, err)
-		}
-	}
-}
-
-func TestDirectorDirect(t *testing.T) {
-	d := &director{
-		endpoints: []url.URL{
-			url.URL{
-				Scheme: "http",
-				Host:   "bar.example.com",
-			},
-		},
-	}
-
-	req := &http.Request{
-		Method: "GET",
-		Host:   "foo.example.com",
-		URL: &url.URL{
-			Host: "foo.example.com",
-			Path: "/v2/keys/baz",
-		},
-	}
-
-	d.direct(req)
-
-	want := &http.Request{
-		Method: "GET",
-		// this field must not change
-		Host: "foo.example.com",
-		URL: &url.URL{
-			// the Scheme field is updated per the director's first endpoint
-			Scheme: "http",
-			// the Host field is updated per the director's first endpoint
-			Host: "bar.example.com",
-			Path: "/v2/keys/baz",
-		},
-	}
-
-	if !reflect.DeepEqual(want, req) {
-		t.Fatalf("HTTP request does not match expected criteria: want=%#v got=%#v", want, req)
-	}
-}

+ 120 - 0
proxy/reverse.go

@@ -0,0 +1,120 @@
+package proxy
+
+import (
+	"io"
+	"log"
+	"net"
+	"net/http"
+	"net/url"
+	"strings"
+)
+
+// Hop-by-hop headers. These are removed when sent to the backend.
+// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
+// This list of headers borrowed from stdlib httputil.ReverseProxy
+var singleHopHeaders = []string{
+	"Connection",
+	"Keep-Alive",
+	"Proxy-Authenticate",
+	"Proxy-Authorization",
+	"Te", // canonicalized version of "TE"
+	"Trailers",
+	"Transfer-Encoding",
+	"Upgrade",
+}
+
+func removeSingleHopHeaders(hdrs *http.Header) {
+	for _, h := range singleHopHeaders {
+		hdrs.Del(h)
+	}
+}
+
+type reverseProxy struct {
+	director  *director
+	transport http.RoundTripper
+}
+
+func (p *reverseProxy) ServeHTTP(rw http.ResponseWriter, clientreq *http.Request) {
+	proxyreq := new(http.Request)
+	*proxyreq = *clientreq
+
+	// deep-copy the headers, as these will be modified below
+	proxyreq.Header = make(http.Header)
+	copyHeader(proxyreq.Header, clientreq.Header)
+
+	normalizeRequest(proxyreq)
+	removeSingleHopHeaders(&proxyreq.Header)
+	maybeSetForwardedFor(proxyreq)
+
+	endpoints := p.director.endpoints()
+	if len(endpoints) == 0 {
+		log.Printf("proxy: zero endpoints currently available")
+		rw.WriteHeader(http.StatusServiceUnavailable)
+		return
+	}
+
+	var res *http.Response
+	var err error
+
+	for _, ep := range endpoints {
+		redirectRequest(proxyreq, ep.URL)
+
+		res, err = p.transport.RoundTrip(proxyreq)
+		if err != nil {
+			log.Printf("proxy: failed to direct request to %s: %v", ep.URL.String(), err)
+			ep.Failed()
+			continue
+		}
+
+		break
+	}
+
+	if res == nil {
+		log.Printf("proxy: unable to get response from %d endpoint(s)", len(endpoints))
+		rw.WriteHeader(http.StatusBadGateway)
+		return
+	}
+
+	defer res.Body.Close()
+
+	removeSingleHopHeaders(&res.Header)
+	copyHeader(rw.Header(), res.Header)
+
+	rw.WriteHeader(res.StatusCode)
+	io.Copy(rw, res.Body)
+}
+
+func copyHeader(dst, src http.Header) {
+	for k, vv := range src {
+		for _, v := range vv {
+			dst.Add(k, v)
+		}
+	}
+}
+
+func redirectRequest(req *http.Request, loc url.URL) {
+	req.URL.Scheme = loc.Scheme
+	req.URL.Host = loc.Host
+}
+
+func normalizeRequest(req *http.Request) {
+	req.Proto = "HTTP/1.1"
+	req.ProtoMajor = 1
+	req.ProtoMinor = 1
+	req.Close = false
+}
+
+func maybeSetForwardedFor(req *http.Request) {
+	clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
+	if err != nil {
+		return
+	}
+
+	// If we aren't the first proxy retain prior
+	// X-Forwarded-For information as a comma+space
+	// separated list and fold multiple headers into one.
+	if prior, ok := req.Header["X-Forwarded-For"]; ok {
+		clientIP = strings.Join(prior, ", ") + ", " + clientIP
+	}
+	req.Header.Set("X-Forwarded-For", clientIP)
+}

+ 227 - 0
proxy/reverse_test.go

@@ -0,0 +1,227 @@
+package proxy
+
+import (
+	"bytes"
+	"errors"
+	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"reflect"
+	"testing"
+)
+
+type staticRoundTripper struct {
+	res *http.Response
+	err error
+}
+
+func (srt *staticRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
+	return srt.res, srt.err
+}
+
+func TestReverseProxyServe(t *testing.T) {
+	u := url.URL{Scheme: "http", Host: "192.0.2.3:4040"}
+
+	tests := []struct {
+		eps  []*endpoint
+		rt   http.RoundTripper
+		want int
+	}{
+		// no endpoints available so no requests are even made
+		{
+			eps: []*endpoint{},
+			rt: &staticRoundTripper{
+				res: &http.Response{
+					StatusCode: http.StatusCreated,
+					Body:       ioutil.NopCloser(&bytes.Reader{}),
+				},
+			},
+			want: http.StatusServiceUnavailable,
+		},
+
+		// error is returned from one endpoint that should be available
+		{
+			eps:  []*endpoint{&endpoint{URL: u, Available: true}},
+			rt:   &staticRoundTripper{err: errors.New("what a bad trip")},
+			want: http.StatusBadGateway,
+		},
+
+		// endpoint is available and returns success
+		{
+			eps: []*endpoint{&endpoint{URL: u, Available: true}},
+			rt: &staticRoundTripper{
+				res: &http.Response{
+					StatusCode: http.StatusCreated,
+					Body:       ioutil.NopCloser(&bytes.Reader{}),
+				},
+			},
+			want: http.StatusCreated,
+		},
+	}
+
+	for i, tt := range tests {
+		rp := reverseProxy{
+			director:  &director{tt.eps},
+			transport: tt.rt,
+		}
+
+		req, _ := http.NewRequest("GET", "http://192.0.2.2:4001", nil)
+		rr := httptest.NewRecorder()
+		rp.ServeHTTP(rr, req)
+
+		if rr.Code != tt.want {
+			t.Errorf("#%d: unexpected HTTP status code: want = %d, got = %d", i, tt.want, rr.Code)
+		}
+	}
+}
+
+func TestRedirectRequest(t *testing.T) {
+	loc := url.URL{
+		Scheme: "http",
+		Host:   "bar.example.com",
+	}
+
+	req := &http.Request{
+		Method: "GET",
+		Host:   "foo.example.com",
+		URL: &url.URL{
+			Host: "foo.example.com",
+			Path: "/v2/keys/baz",
+		},
+	}
+
+	redirectRequest(req, loc)
+
+	want := &http.Request{
+		Method: "GET",
+		// this field must not change
+		Host: "foo.example.com",
+		URL: &url.URL{
+			// the Scheme field is updated to that of the provided URL
+			Scheme: "http",
+			// the Host field is updated to that of the provided URL
+			Host: "bar.example.com",
+			Path: "/v2/keys/baz",
+		},
+	}
+
+	if !reflect.DeepEqual(want, req) {
+		t.Fatalf("HTTP request does not match expected criteria: want=%#v got=%#v", want, req)
+	}
+}
+
+func TestMaybeSetForwardedFor(t *testing.T) {
+	tests := []struct {
+		raddr  string
+		fwdFor string
+		want   string
+	}{
+		{"192.0.2.3:8002", "", "192.0.2.3"},
+		{"192.0.2.3:8002", "192.0.2.2", "192.0.2.2, 192.0.2.3"},
+		{"192.0.2.3:8002", "192.0.2.1, 192.0.2.2", "192.0.2.1, 192.0.2.2, 192.0.2.3"},
+		{"example.com:8002", "", "example.com"},
+
+		// While these cases look valid, golang net/http will not let it happen
+		// The RemoteAddr field will always be a valid host:port
+		{":8002", "", ""},
+		{"192.0.2.3", "", ""},
+
+		// blatantly invalid host w/o a port
+		{"12", "", ""},
+		{"12", "192.0.2.3", "192.0.2.3"},
+	}
+
+	for i, tt := range tests {
+		req := &http.Request{
+			RemoteAddr: tt.raddr,
+			Header:     make(http.Header),
+		}
+
+		if tt.fwdFor != "" {
+			req.Header.Set("X-Forwarded-For", tt.fwdFor)
+		}
+
+		maybeSetForwardedFor(req)
+		got := req.Header.Get("X-Forwarded-For")
+		if tt.want != got {
+			t.Errorf("#%d: incorrect header: want = %q, got = %q", i, tt.want, got)
+		}
+	}
+}
+
+func TestRemoveSingleHopHeaders(t *testing.T) {
+	hdr := http.Header(map[string][]string{
+		// single-hop headers that should be removed
+		"Connection":          []string{"close"},
+		"Keep-Alive":          []string{"foo"},
+		"Proxy-Authenticate":  []string{"Basic realm=example.com"},
+		"Proxy-Authorization": []string{"foo"},
+		"Te":                []string{"deflate,gzip"},
+		"Trailers":          []string{"ETag"},
+		"Transfer-Encoding": []string{"chunked"},
+		"Upgrade":           []string{"WebSocket"},
+
+		// headers that should persist
+		"Accept": []string{"application/json"},
+		"X-Foo":  []string{"Bar"},
+	})
+
+	removeSingleHopHeaders(&hdr)
+
+	want := http.Header(map[string][]string{
+		"Accept": []string{"application/json"},
+		"X-Foo":  []string{"Bar"},
+	})
+
+	if !reflect.DeepEqual(want, hdr) {
+		t.Fatalf("unexpected result: want = %#v, got = %#v", want, hdr)
+	}
+}
+
+func TestCopyHeader(t *testing.T) {
+	tests := []struct {
+		src  http.Header
+		dst  http.Header
+		want http.Header
+	}{
+		{
+			src: http.Header(map[string][]string{
+				"Foo": []string{"bar", "baz"},
+			}),
+			dst: http.Header(map[string][]string{}),
+			want: http.Header(map[string][]string{
+				"Foo": []string{"bar", "baz"},
+			}),
+		},
+		{
+			src: http.Header(map[string][]string{
+				"Foo":  []string{"bar"},
+				"Ping": []string{"pong"},
+			}),
+			dst: http.Header(map[string][]string{}),
+			want: http.Header(map[string][]string{
+				"Foo":  []string{"bar"},
+				"Ping": []string{"pong"},
+			}),
+		},
+		{
+			src: http.Header(map[string][]string{
+				"Foo": []string{"bar", "baz"},
+			}),
+			dst: http.Header(map[string][]string{
+				"Foo": []string{"qux"},
+			}),
+			want: http.Header(map[string][]string{
+				"Foo": []string{"qux", "bar", "baz"},
+			}),
+		},
+	}
+
+	for i, tt := range tests {
+		copyHeader(tt.dst, tt.src)
+		if !reflect.DeepEqual(tt.dst, tt.want) {
+			t.Errorf("#%d: unexpected headers: want = %v, got = %v", i, tt.want, tt.dst)
+		}
+	}
+}