ソースを参照

Merge pull request #3245 from yichengq/client_timeout

client: set timeout for each request
Yicheng Qin 10 年 前
コミット
1fe52e1ec3
2 ファイル変更90 行追加24 行削除
  1. 40 8
      client/client.go
  2. 50 16
      client/client_test.go

+ 40 - 8
client/client.go

@@ -87,6 +87,23 @@ type Config struct {
 	// Password is the password for the specified user to add as an authorization header
 	// to the request.
 	Password string
+
+	// HeaderTimeoutPerRequest specifies the time limit to wait for response
+	// header in a single request made by the Client. The timeout includes
+	// connection time, any redirects, and header wait time.
+	//
+	// For non-watch GET request, server returns the response body immediately.
+	// For PUT/POST/DELETE request, server will attempt to commit request
+	// before responding, which is expected to take `100ms + 2 * RTT`.
+	// For watch request, server returns the header immediately to notify Client
+	// watch start. But if server is behind some kind of proxy, the response
+	// header may be cached at proxy, and Client cannot rely on this behavior.
+	//
+	// One API call may send multiple requests to different etcd servers until it
+	// succeeds. Use context of the API to specify the overall timeout.
+	//
+	// A HeaderTimeoutPerRequest of zero means no timeout.
+	HeaderTimeoutPerRequest time.Duration
 }
 
 func (cfg *Config) transport() CancelableTransport {
@@ -150,7 +167,7 @@ type Client interface {
 
 func New(cfg Config) (Client, error) {
 	c := &httpClusterClient{
-		clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect()),
+		clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect(), cfg.HeaderTimeoutPerRequest),
 		rand:          rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
 	}
 	if cfg.Username != "" {
@@ -169,13 +186,14 @@ type httpClient interface {
 	Do(context.Context, httpAction) (*http.Response, []byte, error)
 }
 
-func newHTTPClientFactory(tr CancelableTransport, cr CheckRedirectFunc) httpClientFactory {
+func newHTTPClientFactory(tr CancelableTransport, cr CheckRedirectFunc, headerTimeout time.Duration) httpClientFactory {
 	return func(ep url.URL) httpClient {
 		return &redirectFollowingHTTPClient{
 			checkRedirect: cr,
 			client: &simpleHTTPClient{
-				transport: tr,
-				endpoint:  ep,
+				transport:     tr,
+				endpoint:      ep,
+				headerTimeout: headerTimeout,
 			},
 		}
 	}
@@ -353,8 +371,9 @@ type roundTripResponse struct {
 }
 
 type simpleHTTPClient struct {
-	transport CancelableTransport
-	endpoint  url.URL
+	transport     CancelableTransport
+	endpoint      url.URL
+	headerTimeout time.Duration
 }
 
 func (c *simpleHTTPClient) Do(ctx context.Context, act httpAction) (*http.Response, []byte, error) {
@@ -364,6 +383,12 @@ func (c *simpleHTTPClient) Do(ctx context.Context, act httpAction) (*http.Respon
 		return nil, nil, err
 	}
 
+	hctx, hcancel := context.WithCancel(ctx)
+	if c.headerTimeout > 0 {
+		hctx, hcancel = context.WithTimeout(ctx, c.headerTimeout)
+	}
+	defer hcancel()
+
 	rtchan := make(chan roundTripResponse, 1)
 	go func() {
 		resp, err := c.transport.RoundTrip(req)
@@ -377,12 +402,19 @@ func (c *simpleHTTPClient) Do(ctx context.Context, act httpAction) (*http.Respon
 	select {
 	case rtresp := <-rtchan:
 		resp, err = rtresp.resp, rtresp.err
-	case <-ctx.Done():
+	case <-hctx.Done():
 		// cancel and wait for request to actually exit before continuing
 		c.transport.CancelRequest(req)
 		rtresp := <-rtchan
 		resp = rtresp.resp
-		err = ctx.Err()
+		switch {
+		case ctx.Err() != nil:
+			err = ctx.Err()
+		case hctx.Err() != nil:
+			err = fmt.Errorf("client: endpoint %s exceeded header timeout", c.endpoint)
+		default:
+			panic("failed to get error from context")
+		}
 	}
 
 	// always check for resp nil-ness to deal with possible

+ 50 - 16
client/client_test.go

@@ -297,6 +297,27 @@ func TestSimpleHTTPClientDoCancelContextWaitForRoundTrip(t *testing.T) {
 	}
 }
 
+func TestSimpleHTTPClientDoHeaderTimeout(t *testing.T) {
+	tr := newFakeTransport()
+	tr.finishCancel <- struct{}{}
+	c := &simpleHTTPClient{transport: tr, headerTimeout: time.Millisecond}
+
+	errc := make(chan error)
+	go func() {
+		_, _, err := c.Do(context.Background(), &fakeAction{})
+		errc <- err
+	}()
+
+	select {
+	case err := <-errc:
+		if err == nil {
+			t.Fatalf("expected non-nil error, got nil")
+		}
+	case <-time.After(time.Second):
+		t.Fatalf("unexpected timeout when waitting for the test to finish")
+	}
+}
+
 func TestHTTPClusterClientDo(t *testing.T) {
 	fakeErr := errors.New("fake!")
 	fakeURL := url.URL{}
@@ -337,21 +358,6 @@ func TestHTTPClusterClientDo(t *testing.T) {
 			wantPinned: 1,
 		},
 
-		// context.DeadlineExceeded short-circuits Do
-		{
-			client: &httpClusterClient{
-				endpoints: []url.URL{fakeURL, fakeURL},
-				clientFactory: newStaticHTTPClientFactory(
-					[]staticHTTPResponse{
-						staticHTTPResponse{err: context.DeadlineExceeded},
-						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
-					},
-				),
-				rand: rand.New(rand.NewSource(0)),
-			},
-			wantErr: &ClusterError{Errors: []error{context.DeadlineExceeded}},
-		},
-
 		// context.Canceled short-circuits Do
 		{
 			client: &httpClusterClient{
@@ -371,7 +377,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		{
 			client: &httpClusterClient{
 				endpoints:     []url.URL{},
-				clientFactory: newHTTPClientFactory(nil, nil),
+				clientFactory: newHTTPClientFactory(nil, nil, 0),
 				rand:          rand.New(rand.NewSource(0)),
 			},
 			wantErr: ErrNoEndpoints,
@@ -434,6 +440,34 @@ func TestHTTPClusterClientDo(t *testing.T) {
 	}
 }
 
+func TestHTTPClusterClientDoDeadlineExceedContext(t *testing.T) {
+	fakeURL := url.URL{}
+	tr := newFakeTransport()
+	tr.finishCancel <- struct{}{}
+	c := &httpClusterClient{
+		clientFactory: newHTTPClientFactory(tr, DefaultCheckRedirect, 0),
+		endpoints:     []url.URL{fakeURL},
+	}
+
+	errc := make(chan error)
+	go func() {
+		ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+		defer cancel()
+		_, _, err := c.Do(ctx, &fakeAction{})
+		errc <- err
+	}()
+
+	select {
+	case err := <-errc:
+		werr := &ClusterError{Errors: []error{context.DeadlineExceeded}}
+		if !reflect.DeepEqual(err, werr) {
+			t.Errorf("err = %+v, want %+v", err, werr)
+		}
+	case <-time.After(time.Second):
+		t.Fatalf("unexpected timeout when waitting for request to deadline exceed")
+	}
+}
+
 func TestRedirectedHTTPAction(t *testing.T) {
 	act := &redirectedHTTPAction{
 		action: &staticHTTPAction{