Browse Source

client: allow caller to decide HTTP redirect policy

Brian Waldon 11 years ago
parent
commit
9b334e07a6
2 changed files with 68 additions and 22 deletions
  1. 39 8
      client/client.go
  2. 29 14
      client/client_test.go

+ 39 - 8
client/client.go

@@ -39,7 +39,6 @@ var (
 	ErrKeyExists  = errors.New("client: key already exists")
 
 	DefaultRequestTimeout = 5 * time.Second
-	DefaultMaxRedirects   = 10
 )
 
 var DefaultTransport CancelableTransport = &http.Transport{
@@ -72,6 +71,17 @@ type Config struct {
 	// Transport is used by the Client to drive HTTP requests. If not
 	// provided, DefaultTransport will be used.
 	Transport CancelableTransport
+
+	// CheckRedirect specifies the policy for handling HTTP redirects.
+	// If CheckRedirect is not nil, the Client calls it before
+	// following an HTTP redirect. The sole argument is the number of
+	// requests that have alrady been made. If CheckRedirect returns
+	// an error, Client.Do will not make any further requests and return
+	// the error back it to the caller.
+	//
+	// If CheckRedirect is nil, the Client uses its default policy,
+	// which is to stop after 10 consecutive requests.
+	CheckRedirect CheckRedirectFunc
 }
 
 func (cfg *Config) transport() CancelableTransport {
@@ -81,6 +91,13 @@ func (cfg *Config) transport() CancelableTransport {
 	return cfg.Transport
 }
 
+func (cfg *Config) checkRedirect() CheckRedirectFunc {
+	if cfg.CheckRedirect == nil {
+		return DefaultCheckRedirect
+	}
+	return cfg.CheckRedirect
+}
+
 // CancelableTransport mimics net/http.Transport, but requires that
 // the object also support request cancellation.
 type CancelableTransport interface {
@@ -88,6 +105,16 @@ type CancelableTransport interface {
 	CancelRequest(req *http.Request)
 }
 
+type CheckRedirectFunc func(via int) error
+
+// DefaultCheckRedirect follows up to 10 redirects, but no more.
+var DefaultCheckRedirect CheckRedirectFunc = func(via int) error {
+	if via > 10 {
+		return ErrTooManyRedirects
+	}
+	return nil
+}
+
 type Client interface {
 	// Sync updates the internal cache of the etcd cluster's membership.
 	Sync(context.Context) error
@@ -101,7 +128,7 @@ type Client interface {
 }
 
 func New(cfg Config) (Client, error) {
-	c := &httpClusterClient{clientFactory: newHTTPClientFactory(cfg.transport())}
+	c := &httpClusterClient{clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect())}
 	if err := c.reset(cfg.Endpoints); err != nil {
 		return nil, err
 	}
@@ -112,10 +139,10 @@ type httpClient interface {
 	Do(context.Context, httpAction) (*http.Response, []byte, error)
 }
 
-func newHTTPClientFactory(tr CancelableTransport) httpClientFactory {
+func newHTTPClientFactory(tr CancelableTransport, cr CheckRedirectFunc) httpClientFactory {
 	return func(ep url.URL) httpClient {
 		return &redirectFollowingHTTPClient{
-			max: DefaultMaxRedirects,
+			checkRedirect: cr,
 			client: &simpleHTTPClient{
 				transport: tr,
 				endpoint:  ep,
@@ -270,12 +297,17 @@ func (c *simpleHTTPClient) Do(ctx context.Context, act httpAction) (*http.Respon
 }
 
 type redirectFollowingHTTPClient struct {
-	client httpClient
-	max    int
+	client        httpClient
+	checkRedirect CheckRedirectFunc
 }
 
 func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act httpAction) (*http.Response, []byte, error) {
-	for i := 0; i <= r.max; i++ {
+	for i := 0; ; i++ {
+		if i > 0 {
+			if err := r.checkRedirect(i); err != nil {
+				return nil, nil, err
+			}
+		}
 		resp, body, err := r.client.Do(ctx, act)
 		if err != nil {
 			return nil, nil, err
@@ -297,7 +329,6 @@ func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act httpAction) (*
 		}
 		return resp, body, nil
 	}
-	return nil, nil, ErrTooManyRedirects
 }
 
 type redirectedHTTPAction struct {

+ 29 - 14
client/client_test.go

@@ -258,7 +258,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		{
 			client: &httpClusterClient{
 				endpoints:     []url.URL{},
-				clientFactory: newHTTPClientFactory(nil),
+				clientFactory: newHTTPClientFactory(nil, nil),
 			},
 			wantErr: ErrNoEndpoints,
 		},
@@ -349,14 +349,14 @@ func TestRedirectedHTTPAction(t *testing.T) {
 
 func TestRedirectFollowingHTTPClient(t *testing.T) {
 	tests := []struct {
-		max      int
-		client   httpClient
-		wantCode int
-		wantErr  error
+		checkRedirect CheckRedirectFunc
+		client        httpClient
+		wantCode      int
+		wantErr       error
 	}{
 		// errors bubbled up
 		{
-			max: 2,
+			checkRedirect: func(int) error { return ErrTooManyRedirects },
 			client: &multiStaticHTTPClient{
 				responses: []staticHTTPResponse{
 					staticHTTPResponse{
@@ -369,7 +369,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
 
 		// no need to follow redirect if none given
 		{
-			max: 2,
+			checkRedirect: func(int) error { return ErrTooManyRedirects },
 			client: &multiStaticHTTPClient{
 				responses: []staticHTTPResponse{
 					staticHTTPResponse{
@@ -384,7 +384,12 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
 
 		// redirects if less than max
 		{
-			max: 2,
+			checkRedirect: func(via int) error {
+				if via >= 2 {
+					return ErrTooManyRedirects
+				}
+				return nil
+			},
 			client: &multiStaticHTTPClient{
 				responses: []staticHTTPResponse{
 					staticHTTPResponse{
@@ -405,7 +410,12 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
 
 		// succeed after reaching max redirects
 		{
-			max: 2,
+			checkRedirect: func(via int) error {
+				if via >= 3 {
+					return ErrTooManyRedirects
+				}
+				return nil
+			},
 			client: &multiStaticHTTPClient{
 				responses: []staticHTTPResponse{
 					staticHTTPResponse{
@@ -430,9 +440,14 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
 			wantCode: http.StatusTeapot,
 		},
 
-		// fail at max+1 redirects
+		// fail if too many redirects
 		{
-			max: 1,
+			checkRedirect: func(via int) error {
+				if via >= 2 {
+					return ErrTooManyRedirects
+				}
+				return nil
+			},
 			client: &multiStaticHTTPClient{
 				responses: []staticHTTPResponse{
 					staticHTTPResponse{
@@ -459,7 +474,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
 
 		// fail if Location header not set
 		{
-			max: 1,
+			checkRedirect: func(int) error { return ErrTooManyRedirects },
 			client: &multiStaticHTTPClient{
 				responses: []staticHTTPResponse{
 					staticHTTPResponse{
@@ -474,7 +489,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
 
 		// fail if Location header is invalid
 		{
-			max: 1,
+			checkRedirect: func(int) error { return ErrTooManyRedirects },
 			client: &multiStaticHTTPClient{
 				responses: []staticHTTPResponse{
 					staticHTTPResponse{
@@ -490,7 +505,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
 	}
 
 	for i, tt := range tests {
-		client := &redirectFollowingHTTPClient{client: tt.client, max: tt.max}
+		client := &redirectFollowingHTTPClient{client: tt.client, checkRedirect: tt.checkRedirect}
 		resp, _, err := client.Do(context.Background(), nil)
 		if !reflect.DeepEqual(tt.wantErr, err) {
 			t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)