Ver Fonte

client: follow redirects

Brian Waldon há 11 anos atrás
pai
commit
6dd4944e62
2 ficheiros alterados com 276 adições e 5 exclusões
  1. 54 5
      client/http.go
  2. 222 0
      client/http_test.go

+ 54 - 5
client/http.go

@@ -17,6 +17,8 @@
 package client
 
 import (
+	"errors"
+	"fmt"
 	"io/ioutil"
 	"net/http"
 	"net/url"
@@ -26,10 +28,12 @@ import (
 )
 
 var (
-	ErrTimeout  = context.DeadlineExceeded
-	ErrCanceled = context.Canceled
+	ErrTimeout          = context.DeadlineExceeded
+	ErrCanceled         = context.Canceled
+	ErrTooManyRedirects = errors.New("too many redirects")
 
 	DefaultRequestTimeout = 5 * time.Second
+	DefaultMaxRedirects   = 10
 )
 
 type SyncableHTTPClient interface {
@@ -69,9 +73,12 @@ func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterCli
 			return nil, err
 		}
 
-		c.endpoints[i] = &httpClient{
-			transport: tr,
-			endpoint:  *u,
+		c.endpoints[i] = &redirectFollowingHTTPClient{
+			max: DefaultMaxRedirects,
+			client: &httpClient{
+				transport: tr,
+				endpoint:  *u,
+			},
 		}
 	}
 
@@ -168,3 +175,45 @@ func (c *httpClient) Do(ctx context.Context, act HTTPAction) (*http.Response, []
 	body, err := ioutil.ReadAll(resp.Body)
 	return resp, body, err
 }
+
+type redirectFollowingHTTPClient struct {
+	client HTTPClient
+	max    int
+}
+
+func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act HTTPAction) (*http.Response, []byte, error) {
+	for i := 0; i <= r.max; i++ {
+		resp, body, err := r.client.Do(ctx, act)
+		if err != nil {
+			return nil, nil, err
+		}
+		if resp.StatusCode/100 == 3 {
+			hdr := resp.Header.Get("Location")
+			if hdr == "" {
+				return nil, nil, fmt.Errorf("Location header not set")
+			}
+			loc, err := url.Parse(hdr)
+			if err != nil {
+				return nil, nil, fmt.Errorf("Location header not valid URL: %s", hdr)
+			}
+			act = &redirectedHTTPAction{
+				action:   act,
+				location: *loc,
+			}
+			continue
+		}
+		return resp, body, nil
+	}
+	return nil, nil, ErrTooManyRedirects
+}
+
+type redirectedHTTPAction struct {
+	action   HTTPAction
+	location url.URL
+}
+
+func (r *redirectedHTTPAction) HTTPRequest(ep url.URL) *http.Request {
+	orig := r.action.HTTPRequest(ep)
+	orig.URL = &r.location
+	return orig
+}

+ 222 - 0
client/http_test.go

@@ -38,6 +38,30 @@ func (s *staticHTTPClient) Do(context.Context, HTTPAction) (*http.Response, []by
 	return &s.resp, nil, s.err
 }
 
+type staticHTTPAction struct {
+	request http.Request
+}
+
+type staticHTTPResponse struct {
+	resp http.Response
+	err  error
+}
+
+func (s *staticHTTPAction) HTTPRequest(url.URL) *http.Request {
+	return &s.request
+}
+
+type multiStaticHTTPClient struct {
+	responses []staticHTTPResponse
+	cur       int
+}
+
+func (s *multiStaticHTTPClient) Do(context.Context, HTTPAction) (*http.Response, []byte, error) {
+	r := s.responses[s.cur]
+	s.cur++
+	return &r.resp, nil, r.err
+}
+
 type fakeTransport struct {
 	respchan     chan *http.Response
 	errchan      chan error
@@ -253,3 +277,201 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		}
 	}
 }
+
+func TestRedirectedHTTPAction(t *testing.T) {
+	act := &redirectedHTTPAction{
+		action: &staticHTTPAction{
+			request: http.Request{
+				Method: "DELETE",
+				URL: &url.URL{
+					Scheme: "https",
+					Host:   "foo.example.com",
+					Path:   "/ping",
+				},
+			},
+		},
+		location: url.URL{
+			Scheme: "https",
+			Host:   "bar.example.com",
+			Path:   "/pong",
+		},
+	}
+
+	want := &http.Request{
+		Method: "DELETE",
+		URL: &url.URL{
+			Scheme: "https",
+			Host:   "bar.example.com",
+			Path:   "/pong",
+		},
+	}
+	got := act.HTTPRequest(url.URL{Scheme: "http", Host: "baz.example.com", Path: "/pang"})
+
+	if !reflect.DeepEqual(want, got) {
+		t.Fatalf("HTTPRequest is %#v, want %#v", want, got)
+	}
+}
+
+func TestRedirectFollowingHTTPClient(t *testing.T) {
+	tests := []struct {
+		max      int
+		client   HTTPClient
+		wantCode int
+		wantErr  error
+	}{
+		// errors bubbled up
+		{
+			max: 2,
+			client: &multiStaticHTTPClient{
+				responses: []staticHTTPResponse{
+					staticHTTPResponse{
+						err: errors.New("fail!"),
+					},
+				},
+			},
+			wantErr: errors.New("fail!"),
+		},
+
+		// no need to follow redirect if none given
+		{
+			max: 2,
+			client: &multiStaticHTTPClient{
+				responses: []staticHTTPResponse{
+					staticHTTPResponse{
+						resp: http.Response{
+							StatusCode: http.StatusTeapot,
+						},
+					},
+				},
+			},
+			wantCode: http.StatusTeapot,
+		},
+
+		// redirects if less than max
+		{
+			max: 2,
+			client: &multiStaticHTTPClient{
+				responses: []staticHTTPResponse{
+					staticHTTPResponse{
+						resp: http.Response{
+							StatusCode: http.StatusTemporaryRedirect,
+							Header:     http.Header{"Location": []string{"http://example.com"}},
+						},
+					},
+					staticHTTPResponse{
+						resp: http.Response{
+							StatusCode: http.StatusTeapot,
+						},
+					},
+				},
+			},
+			wantCode: http.StatusTeapot,
+		},
+
+		// succeed after reaching max redirects
+		{
+			max: 2,
+			client: &multiStaticHTTPClient{
+				responses: []staticHTTPResponse{
+					staticHTTPResponse{
+						resp: http.Response{
+							StatusCode: http.StatusTemporaryRedirect,
+							Header:     http.Header{"Location": []string{"http://example.com"}},
+						},
+					},
+					staticHTTPResponse{
+						resp: http.Response{
+							StatusCode: http.StatusTemporaryRedirect,
+							Header:     http.Header{"Location": []string{"http://example.com"}},
+						},
+					},
+					staticHTTPResponse{
+						resp: http.Response{
+							StatusCode: http.StatusTeapot,
+						},
+					},
+				},
+			},
+			wantCode: http.StatusTeapot,
+		},
+
+		// fail at max+1 redirects
+		{
+			max: 1,
+			client: &multiStaticHTTPClient{
+				responses: []staticHTTPResponse{
+					staticHTTPResponse{
+						resp: http.Response{
+							StatusCode: http.StatusTemporaryRedirect,
+							Header:     http.Header{"Location": []string{"http://example.com"}},
+						},
+					},
+					staticHTTPResponse{
+						resp: http.Response{
+							StatusCode: http.StatusTemporaryRedirect,
+							Header:     http.Header{"Location": []string{"http://example.com"}},
+						},
+					},
+					staticHTTPResponse{
+						resp: http.Response{
+							StatusCode: http.StatusTeapot,
+						},
+					},
+				},
+			},
+			wantErr: ErrTooManyRedirects,
+		},
+
+		// fail if Location header not set
+		{
+			max: 1,
+			client: &multiStaticHTTPClient{
+				responses: []staticHTTPResponse{
+					staticHTTPResponse{
+						resp: http.Response{
+							StatusCode: http.StatusTemporaryRedirect,
+						},
+					},
+				},
+			},
+			wantErr: errors.New("Location header not set"),
+		},
+
+		// fail if Location header is invalid
+		{
+			max: 1,
+			client: &multiStaticHTTPClient{
+				responses: []staticHTTPResponse{
+					staticHTTPResponse{
+						resp: http.Response{
+							StatusCode: http.StatusTemporaryRedirect,
+							Header:     http.Header{"Location": []string{":"}},
+						},
+					},
+				},
+			},
+			wantErr: errors.New("Location header not valid URL: :"),
+		},
+	}
+
+	for i, tt := range tests {
+		client := &redirectFollowingHTTPClient{client: tt.client, max: tt.max}
+		resp, _, err := client.Do(context.Background(), nil)
+		if !reflect.DeepEqual(tt.wantErr, err) {
+			t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)
+			continue
+		}
+
+		if resp == nil {
+			if tt.wantCode != 0 {
+				t.Errorf("#%d: resp is nil, want=%d", i, tt.wantCode)
+			}
+			continue
+		}
+
+		if resp.StatusCode != tt.wantCode {
+			t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode)
+			continue
+		}
+	}
+}