Browse Source

Merge pull request #1570 from bcwaldon/client-endpoints

client: use all endpoints
Brian Waldon 11 years ago
parent
commit
729770f32a
2 changed files with 122 additions and 4 deletions
  1. 18 4
      client/http.go
  2. 104 0
      client/http_test.go

+ 18 - 4
client/http.go

@@ -26,7 +26,9 @@ import (
 )
 
 var (
-	ErrTimeout            = context.DeadlineExceeded
+	ErrTimeout  = context.DeadlineExceeded
+	ErrCanceled = context.Canceled
+
 	DefaultRequestTimeout = 5 * time.Second
 )
 
@@ -81,9 +83,21 @@ type httpClusterClient struct {
 	endpoints []HTTPClient
 }
 
-func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (*http.Response, []byte, error) {
-	//TODO(bcwaldon): introduce retry logic so all endpoints are attempted
-	return c.endpoints[0].Do(ctx, act)
+func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (resp *http.Response, body []byte, err error) {
+	for _, hc := range c.endpoints {
+		resp, body, err = hc.Do(ctx, act)
+		if err != nil {
+			if err == ErrTimeout || err == ErrCanceled {
+				return nil, nil, err
+			}
+			continue
+		}
+		if resp.StatusCode/100 == 5 {
+			continue
+		}
+		break
+	}
+	return
 }
 
 func (c *httpClusterClient) Sync(ctx context.Context) error {

+ 104 - 0
client/http_test.go

@@ -29,6 +29,15 @@ import (
 	"github.com/coreos/etcd/Godeps/_workspace/src/code.google.com/p/go.net/context"
 )
 
+type staticHTTPClient struct {
+	resp http.Response
+	err  error
+}
+
+func (s *staticHTTPClient) Do(context.Context, HTTPAction) (*http.Response, []byte, error) {
+	return &s.resp, nil, s.err
+}
+
 type fakeTransport struct {
 	respchan     chan *http.Response
 	errchan      chan error
@@ -149,3 +158,98 @@ func TestHTTPClientDoCancelContextWaitForRoundTrip(t *testing.T) {
 		t.Fatalf("httpClient.do did not exit within 1s")
 	}
 }
+
+func TestHTTPClusterClientDo(t *testing.T) {
+	fakeErr := errors.New("fake!")
+	tests := []struct {
+		client   *httpClusterClient
+		wantCode int
+		wantErr  error
+	}{
+		// first good response short-circuits Do
+		{
+			client: &httpClusterClient{
+				endpoints: []HTTPClient{
+					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
+					&staticHTTPClient{err: fakeErr},
+				},
+			},
+			wantCode: http.StatusTeapot,
+		},
+
+		// fall through to good endpoint if err is arbitrary
+		{
+			client: &httpClusterClient{
+				endpoints: []HTTPClient{
+					&staticHTTPClient{err: fakeErr},
+					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
+				},
+			},
+			wantCode: http.StatusTeapot,
+		},
+
+		// ErrTimeout short-circuits Do
+		{
+			client: &httpClusterClient{
+				endpoints: []HTTPClient{
+					&staticHTTPClient{err: ErrTimeout},
+					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
+				},
+			},
+			wantErr: ErrTimeout,
+		},
+
+		// ErrCanceled short-circuits Do
+		{
+			client: &httpClusterClient{
+				endpoints: []HTTPClient{
+					&staticHTTPClient{err: ErrCanceled},
+					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
+				},
+			},
+			wantErr: ErrCanceled,
+		},
+
+		// return err if all endpoints return arbitrary errors
+		{
+			client: &httpClusterClient{
+				endpoints: []HTTPClient{
+					&staticHTTPClient{err: fakeErr},
+					&staticHTTPClient{err: fakeErr},
+				},
+			},
+			wantErr: fakeErr,
+		},
+
+		// 500-level errors cause Do to fallthrough to next endpoint
+		{
+			client: &httpClusterClient{
+				endpoints: []HTTPClient{
+					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusBadGateway}},
+					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
+				},
+			},
+			wantCode: http.StatusTeapot,
+		},
+	}
+
+	for i, tt := range tests {
+		resp, _, err := tt.client.Do(context.Background(), nil)
+		if !reflect.DeepEqual(tt.wantErr, err) {
+			t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)
+			continue
+		}
+
+		if resp == nil {
+			if tt.wantCode != 0 {
+				t.Errorf("#%d: resp is nil, want=%d", i, tt.wantCode)
+			}
+			continue
+		}
+
+		if resp.StatusCode != tt.wantCode {
+			t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode)
+			continue
+		}
+	}
+}