Browse Source

client: don't cache httpClients in httpClusterClient

Brian Waldon 11 years ago
parent
commit
62054dfb5e
2 changed files with 103 additions and 56 deletions
  1. 49 31
      client/http.go
  2. 54 25
      client/http_test.go

+ 49 - 31
client/http.go

@@ -36,13 +36,23 @@ var (
 	DefaultMaxRedirects   = 10
 )
 
+func defaultHTTPClientFactory(tr CancelableTransport, ep url.URL) HTTPClient {
+	return &redirectFollowingHTTPClient{
+		max: DefaultMaxRedirects,
+		client: &httpClient{
+			transport: tr,
+			endpoint:  ep,
+		},
+	}
+}
+
 type ClientConfig struct {
 	Endpoints []string
 	Transport CancelableTransport
 }
 
 func New(cfg ClientConfig) (SyncableHTTPClient, error) {
-	return newHTTPClusterClient(cfg.Transport, cfg.Endpoints)
+	return newHTTPClusterClient(cfg.Transport, cfg.Endpoints, defaultHTTPClientFactory)
 }
 
 type SyncableHTTPClient interface {
@@ -55,6 +65,8 @@ type HTTPClient interface {
 	Do(context.Context, HTTPAction) (*http.Response, []byte, error)
 }
 
+type httpClientFactory func(CancelableTransport, url.URL) HTTPClient
+
 type HTTPAction interface {
 	HTTPRequest(url.URL) *http.Request
 }
@@ -67,8 +79,8 @@ type CancelableTransport interface {
 	CancelRequest(req *http.Request)
 }
 
-func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterClient, error) {
-	c := &httpClusterClient{}
+func newHTTPClusterClient(tr CancelableTransport, eps []string, cf httpClientFactory) (*httpClusterClient, error) {
+	c := &httpClusterClient{clientFactory: cf}
 	if err := c.reset(tr, eps); err != nil {
 		return nil, err
 	}
@@ -76,37 +88,27 @@ func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterCli
 }
 
 type httpClusterClient struct {
-	transport CancelableTransport
-	endpoints []string
-	clients   []HTTPClient
+	clientFactory httpClientFactory
+	transport     CancelableTransport
+	endpoints     []url.URL
 	sync.RWMutex
 }
 
 func (c *httpClusterClient) reset(tr CancelableTransport, eps []string) error {
-	le := len(eps)
-	ne := make([]string, le)
-	if copy(ne, eps) != le {
-		return errors.New("copy call failed")
+	if len(eps) == 0 {
+		return ErrNoEndpoints
 	}
 
-	nc := make([]HTTPClient, len(ne))
-	for i, e := range ne {
-		u, err := url.Parse(e)
+	neps := make([]url.URL, len(eps))
+	for i, ep := range eps {
+		u, err := url.Parse(ep)
 		if err != nil {
 			return err
 		}
-
-		nc[i] = &redirectFollowingHTTPClient{
-			max: DefaultMaxRedirects,
-			client: &httpClient{
-				transport: tr,
-				endpoint:  *u,
-			},
-		}
+		neps[i] = *u
 	}
 
-	c.endpoints = ne
-	c.clients = nc
+	c.endpoints = neps
 	c.transport = tr
 
 	return nil
@@ -114,12 +116,24 @@ func (c *httpClusterClient) reset(tr CancelableTransport, eps []string) error {
 
 func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (resp *http.Response, body []byte, err error) {
 	c.RLock()
-	defer c.RUnlock()
+	leps := len(c.endpoints)
+	eps := make([]url.URL, leps)
+	n := copy(eps, c.endpoints)
+	tr := c.transport
+	c.RUnlock()
+
+	if leps == 0 {
+		err = ErrNoEndpoints
+		return
+	}
 
-	if len(c.clients) == 0 {
-		return nil, nil, ErrNoEndpoints
+	if leps != n {
+		err = errors.New("unable to pick endpoint: copy failed")
+		return
 	}
-	for _, hc := range c.clients {
+
+	for _, ep := range eps {
+		hc := c.clientFactory(tr, ep)
 		resp, body, err = hc.Do(ctx, act)
 		if err != nil {
 			if err == ErrTimeout || err == ErrCanceled {
@@ -132,13 +146,20 @@ func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (resp *http.
 		}
 		break
 	}
+
 	return
 }
 
 func (c *httpClusterClient) Endpoints() []string {
 	c.RLock()
 	defer c.RUnlock()
-	return c.endpoints
+
+	eps := make([]string, len(c.endpoints))
+	for i, ep := range c.endpoints {
+		eps[i] = ep.String()
+	}
+
+	return eps
 }
 
 func (c *httpClusterClient) Sync(ctx context.Context) error {
@@ -155,9 +176,6 @@ func (c *httpClusterClient) Sync(ctx context.Context) error {
 	for _, m := range ms {
 		eps = append(eps, m.ClientURLs...)
 	}
-	if len(eps) == 0 {
-		return ErrNoEndpoints
-	}
 
 	return c.reset(c.transport, eps)
 }

+ 54 - 25
client/http_test.go

@@ -60,6 +60,15 @@ func (s *multiStaticHTTPClient) Do(context.Context, HTTPAction) (*http.Response,
 	return &r.resp, nil, r.err
 }
 
+func newStaticHTTPClientFactory(responses []staticHTTPResponse) httpClientFactory {
+	var cur int
+	return func(CancelableTransport, url.URL) HTTPClient {
+		r := responses[cur]
+		cur++
+		return &staticHTTPClient{resp: r.resp, err: r.err}
+	}
+}
+
 type fakeTransport struct {
 	respchan     chan *http.Response
 	errchan      chan error
@@ -183,6 +192,7 @@ func TestHTTPClientDoCancelContextWaitForRoundTrip(t *testing.T) {
 
 func TestHTTPClusterClientDo(t *testing.T) {
 	fakeErr := errors.New("fake!")
+	fakeURL := url.URL{}
 	tests := []struct {
 		client   *httpClusterClient
 		wantCode int
@@ -191,10 +201,13 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		// first good response short-circuits Do
 		{
 			client: &httpClusterClient{
-				clients: []HTTPClient{
-					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
-					&staticHTTPClient{err: fakeErr},
-				},
+				endpoints: []url.URL{fakeURL, fakeURL},
+				clientFactory: newStaticHTTPClientFactory(
+					[]staticHTTPResponse{
+						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
+						staticHTTPResponse{err: fakeErr},
+					},
+				),
 			},
 			wantCode: http.StatusTeapot,
 		},
@@ -202,10 +215,13 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		// fall through to good endpoint if err is arbitrary
 		{
 			client: &httpClusterClient{
-				clients: []HTTPClient{
-					&staticHTTPClient{err: fakeErr},
-					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
-				},
+				endpoints: []url.URL{fakeURL, fakeURL},
+				clientFactory: newStaticHTTPClientFactory(
+					[]staticHTTPResponse{
+						staticHTTPResponse{err: fakeErr},
+						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
+					},
+				),
 			},
 			wantCode: http.StatusTeapot,
 		},
@@ -213,10 +229,13 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		// ErrTimeout short-circuits Do
 		{
 			client: &httpClusterClient{
-				clients: []HTTPClient{
-					&staticHTTPClient{err: ErrTimeout},
-					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
-				},
+				endpoints: []url.URL{fakeURL, fakeURL},
+				clientFactory: newStaticHTTPClientFactory(
+					[]staticHTTPResponse{
+						staticHTTPResponse{err: ErrTimeout},
+						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
+					},
+				),
 			},
 			wantErr: ErrTimeout,
 		},
@@ -224,10 +243,13 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		// ErrCanceled short-circuits Do
 		{
 			client: &httpClusterClient{
-				clients: []HTTPClient{
-					&staticHTTPClient{err: ErrCanceled},
-					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
-				},
+				endpoints: []url.URL{fakeURL, fakeURL},
+				clientFactory: newStaticHTTPClientFactory(
+					[]staticHTTPResponse{
+						staticHTTPResponse{err: ErrCanceled},
+						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
+					},
+				),
 			},
 			wantErr: ErrCanceled,
 		},
@@ -235,7 +257,8 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		// return err if there are no endpoints
 		{
 			client: &httpClusterClient{
-				clients: []HTTPClient{},
+				endpoints:     []url.URL{},
+				clientFactory: defaultHTTPClientFactory,
 			},
 			wantErr: ErrNoEndpoints,
 		},
@@ -243,10 +266,13 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		// return err if all endpoints return arbitrary errors
 		{
 			client: &httpClusterClient{
-				clients: []HTTPClient{
-					&staticHTTPClient{err: fakeErr},
-					&staticHTTPClient{err: fakeErr},
-				},
+				endpoints: []url.URL{fakeURL, fakeURL},
+				clientFactory: newStaticHTTPClientFactory(
+					[]staticHTTPResponse{
+						staticHTTPResponse{err: fakeErr},
+						staticHTTPResponse{err: fakeErr},
+					},
+				),
 			},
 			wantErr: fakeErr,
 		},
@@ -254,10 +280,13 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		// 500-level errors cause Do to fallthrough to next endpoint
 		{
 			client: &httpClusterClient{
-				clients: []HTTPClient{
-					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusBadGateway}},
-					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
-				},
+				endpoints: []url.URL{fakeURL, fakeURL},
+				clientFactory: newStaticHTTPClientFactory(
+					[]staticHTTPResponse{
+						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusBadGateway}},
+						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
+					},
+				),
 			},
 			wantCode: http.StatusTeapot,
 		},