Browse Source

discovery: discovery will try forever when there is a timeout.

Perviously, etcd retries three times for timeout error. After this
commit, etcd retries forever. Also this commit make etcd client
aware of gateway timetout.
Xiang Li 11 years ago
parent
commit
7171410422
4 changed files with 80 additions and 41 deletions
  1. 15 7
      client/keys.go
  2. 17 7
      client/keys_test.go
  3. 21 20
      discovery/discovery.go
  4. 27 7
      discovery/discovery_test.go

+ 15 - 7
client/keys.go

@@ -71,6 +71,7 @@ type Response struct {
 	Action   string `json:"action"`
 	Node     *Node  `json:"node"`
 	PrevNode *Node  `json:"prevNode"`
+	Index    uint64
 }
 
 type Nodes []*Node
@@ -107,7 +108,7 @@ func (k *httpKeysAPI) Create(ctx context.Context, key, val string, ttl time.Dura
 		return nil, err
 	}
 
-	return unmarshalHTTPResponse(resp.StatusCode, body)
+	return unmarshalHTTPResponse(resp.StatusCode, resp.Header, body)
 }
 
 func (k *httpKeysAPI) Get(ctx context.Context, key string) (*Response, error) {
@@ -122,7 +123,7 @@ func (k *httpKeysAPI) Get(ctx context.Context, key string) (*Response, error) {
 		return nil, err
 	}
 
-	return unmarshalHTTPResponse(resp.StatusCode, body)
+	return unmarshalHTTPResponse(resp.StatusCode, resp.Header, body)
 }
 
 func (k *httpKeysAPI) Watch(key string, idx uint64) Watcher {
@@ -160,7 +161,7 @@ func (hw *httpWatcher) Next(ctx context.Context) (*Response, error) {
 		return nil, err
 	}
 
-	resp, err := unmarshalHTTPResponse(httpresp.StatusCode, body)
+	resp, err := unmarshalHTTPResponse(httpresp.StatusCode, httpresp.Header, body)
 	if err != nil {
 		return nil, err
 	}
@@ -243,10 +244,10 @@ func (c *createAction) HTTPRequest(ep url.URL) *http.Request {
 	return req
 }
 
-func unmarshalHTTPResponse(code int, body []byte) (res *Response, err error) {
+func unmarshalHTTPResponse(code int, header http.Header, body []byte) (res *Response, err error) {
 	switch code {
 	case http.StatusOK, http.StatusCreated:
-		res, err = unmarshalSuccessfulResponse(body)
+		res, err = unmarshalSuccessfulResponse(header, body)
 	default:
 		err = unmarshalErrorResponse(code)
 	}
@@ -254,13 +255,18 @@ func unmarshalHTTPResponse(code int, body []byte) (res *Response, err error) {
 	return
 }
 
-func unmarshalSuccessfulResponse(body []byte) (*Response, error) {
+func unmarshalSuccessfulResponse(header http.Header, body []byte) (*Response, error) {
 	var res Response
 	err := json.Unmarshal(body, &res)
 	if err != nil {
 		return nil, err
 	}
-
+	if header.Get("X-Etcd-Index") != "" {
+		res.Index, err = strconv.ParseUint(header.Get("X-Etcd-Index"), 10, 64)
+	}
+	if err != nil {
+		return nil, err
+	}
 	return &res, nil
 }
 
@@ -273,6 +279,8 @@ func unmarshalErrorResponse(code int) error {
 	case http.StatusInternalServerError:
 		// this isn't necessarily true
 		return ErrNoLeader
+	case http.StatusGatewayTimeout:
+		return ErrTimeout
 	default:
 	}
 

+ 17 - 7
client/keys_test.go

@@ -255,40 +255,46 @@ func assertResponse(got http.Request, wantURL *url.URL, wantHeader http.Header,
 
 func TestUnmarshalSuccessfulResponse(t *testing.T) {
 	tests := []struct {
+		indexHeader string
 		body        string
 		res         *Response
 		expectError bool
 	}{
 		// Neither PrevNode or Node
 		{
+			"1",
 			`{"action":"delete"}`,
-			&Response{Action: "delete"},
+			&Response{Action: "delete", Index: 1},
 			false,
 		},
 
 		// PrevNode
 		{
+			"15",
 			`{"action":"delete", "prevNode": {"key": "/foo", "value": "bar", "modifiedIndex": 12, "createdIndex": 10}}`,
-			&Response{Action: "delete", PrevNode: &Node{Key: "/foo", Value: "bar", ModifiedIndex: 12, CreatedIndex: 10}},
+			&Response{Action: "delete", Index: 15, PrevNode: &Node{Key: "/foo", Value: "bar", ModifiedIndex: 12, CreatedIndex: 10}},
 			false,
 		},
 
 		// Node
 		{
+			"15",
 			`{"action":"get", "node": {"key": "/foo", "value": "bar", "modifiedIndex": 12, "createdIndex": 10}}`,
-			&Response{Action: "get", Node: &Node{Key: "/foo", Value: "bar", ModifiedIndex: 12, CreatedIndex: 10}},
+			&Response{Action: "get", Index: 15, Node: &Node{Key: "/foo", Value: "bar", ModifiedIndex: 12, CreatedIndex: 10}},
 			false,
 		},
 
 		// PrevNode and Node
 		{
+			"15",
 			`{"action":"update", "prevNode": {"key": "/foo", "value": "baz", "modifiedIndex": 10, "createdIndex": 10}, "node": {"key": "/foo", "value": "bar", "modifiedIndex": 12, "createdIndex": 10}}`,
-			&Response{Action: "update", PrevNode: &Node{Key: "/foo", Value: "baz", ModifiedIndex: 10, CreatedIndex: 10}, Node: &Node{Key: "/foo", Value: "bar", ModifiedIndex: 12, CreatedIndex: 10}},
+			&Response{Action: "update", Index: 15, PrevNode: &Node{Key: "/foo", Value: "baz", ModifiedIndex: 10, CreatedIndex: 10}, Node: &Node{Key: "/foo", Value: "bar", ModifiedIndex: 12, CreatedIndex: 10}},
 			false,
 		},
 
 		// Garbage in body
 		{
+			"",
 			`garbage`,
 			nil,
 			true,
@@ -296,7 +302,9 @@ func TestUnmarshalSuccessfulResponse(t *testing.T) {
 	}
 
 	for i, tt := range tests {
-		res, err := unmarshalSuccessfulResponse([]byte(tt.body))
+		h := make(http.Header)
+		h.Add("X-Etcd-Index", tt.indexHeader)
+		res, err := unmarshalSuccessfulResponse(h, []byte(tt.body))
 		if tt.expectError != (err != nil) {
 			t.Errorf("#%d: expectError=%t, err=%v", i, tt.expectError, err)
 		}
@@ -312,7 +320,9 @@ func TestUnmarshalSuccessfulResponse(t *testing.T) {
 		if res.Action != tt.res.Action {
 			t.Errorf("#%d: Action=%s, expected %s", i, res.Action, tt.res.Action)
 		}
-
+		if res.Index != tt.res.Index {
+			t.Errorf("#%d: Index=%d, expected %d", i, res.Index, tt.res.Index)
+		}
 		if !reflect.DeepEqual(res.Node, tt.res.Node) {
 			t.Errorf("#%d: Node=%v, expected %v", i, res.Node, tt.res.Node)
 		}
@@ -350,7 +360,7 @@ func TestUnmarshalErrorResponse(t *testing.T) {
 		{http.StatusNotImplemented, unrecognized},
 		{http.StatusBadGateway, unrecognized},
 		{http.StatusServiceUnavailable, unrecognized},
-		{http.StatusGatewayTimeout, unrecognized},
+		{http.StatusGatewayTimeout, ErrTimeout},
 		{http.StatusHTTPVersionNotSupported, unrecognized},
 	}
 

+ 21 - 20
discovery/discovery.go

@@ -20,6 +20,7 @@ import (
 	"errors"
 	"fmt"
 	"log"
+	"math"
 	"net/http"
 	"net/url"
 	"path"
@@ -44,9 +45,9 @@ var (
 	ErrTooManyRetries = errors.New("discovery: too many retries")
 )
 
-const (
+var (
 	// Number of retries discovery will attempt before giving up and erroring out.
-	nRetries = uint(3)
+	nRetries = uint(math.MaxUint32)
 )
 
 // JoinCluster will connect to the discovery service at the given url, and
@@ -135,7 +136,7 @@ func newDiscovery(durl, dproxyurl string, id types.ID) (*discovery, error) {
 func (d *discovery) joinCluster(config string) (string, error) {
 	// fast path: if the cluster is full, return the error
 	// do not need to register to the cluster in this case.
-	if _, _, err := d.checkCluster(); err != nil {
+	if _, _, _, err := d.checkCluster(); err != nil {
 		return "", err
 	}
 
@@ -146,12 +147,12 @@ func (d *discovery) joinCluster(config string) (string, error) {
 		return "", err
 	}
 
-	nodes, size, err := d.checkCluster()
+	nodes, size, index, err := d.checkCluster()
 	if err != nil {
 		return "", err
 	}
 
-	all, err := d.waitNodes(nodes, size)
+	all, err := d.waitNodes(nodes, size, index)
 	if err != nil {
 		return "", err
 	}
@@ -160,7 +161,7 @@ func (d *discovery) joinCluster(config string) (string, error) {
 }
 
 func (d *discovery) getCluster() (string, error) {
-	nodes, size, err := d.checkCluster()
+	nodes, size, index, err := d.checkCluster()
 	if err != nil {
 		if err == ErrFullCluster {
 			return nodesToCluster(nodes), nil
@@ -168,7 +169,7 @@ func (d *discovery) getCluster() (string, error) {
 		return "", err
 	}
 
-	all, err := d.waitNodes(nodes, size)
+	all, err := d.waitNodes(nodes, size, index)
 	if err != nil {
 		return "", err
 	}
@@ -189,7 +190,7 @@ func (d *discovery) createSelf(contents string) error {
 	return err
 }
 
-func (d *discovery) checkCluster() (client.Nodes, int, error) {
+func (d *discovery) checkCluster() (client.Nodes, int, uint64, error) {
 	configKey := path.Join("/", d.cluster, "_config")
 	ctx, cancel := context.WithTimeout(context.Background(), client.DefaultRequestTimeout)
 	// find cluster size
@@ -197,16 +198,16 @@ func (d *discovery) checkCluster() (client.Nodes, int, error) {
 	cancel()
 	if err != nil {
 		if err == client.ErrKeyNoExist {
-			return nil, 0, ErrSizeNotFound
+			return nil, 0, 0, ErrSizeNotFound
 		}
 		if err == client.ErrTimeout {
 			return d.checkClusterRetry()
 		}
-		return nil, 0, err
+		return nil, 0, 0, err
 	}
 	size, err := strconv.Atoi(resp.Node.Value)
 	if err != nil {
-		return nil, 0, ErrBadSizeKey
+		return nil, 0, 0, ErrBadSizeKey
 	}
 
 	ctx, cancel = context.WithTimeout(context.Background(), client.DefaultRequestTimeout)
@@ -216,7 +217,7 @@ func (d *discovery) checkCluster() (client.Nodes, int, error) {
 		if err == client.ErrTimeout {
 			return d.checkClusterRetry()
 		}
-		return nil, 0, err
+		return nil, 0, 0, err
 	}
 	nodes := make(client.Nodes, 0)
 	// append non-config keys to nodes
@@ -235,10 +236,10 @@ func (d *discovery) checkCluster() (client.Nodes, int, error) {
 			break
 		}
 		if i >= size-1 {
-			return nodes[:size], size, ErrFullCluster
+			return nodes[:size], size, resp.Index, ErrFullCluster
 		}
 	}
-	return nodes, size, nil
+	return nodes, size, resp.Index, nil
 }
 
 func (d *discovery) logAndBackoffForRetry(step string) {
@@ -248,31 +249,31 @@ func (d *discovery) logAndBackoffForRetry(step string) {
 	d.clock.Sleep(retryTime)
 }
 
-func (d *discovery) checkClusterRetry() (client.Nodes, int, error) {
+func (d *discovery) checkClusterRetry() (client.Nodes, int, uint64, error) {
 	if d.retries < nRetries {
 		d.logAndBackoffForRetry("cluster status check")
 		return d.checkCluster()
 	}
-	return nil, 0, ErrTooManyRetries
+	return nil, 0, 0, ErrTooManyRetries
 }
 
 func (d *discovery) waitNodesRetry() (client.Nodes, error) {
 	if d.retries < nRetries {
 		d.logAndBackoffForRetry("waiting for other nodes")
-		nodes, n, err := d.checkCluster()
+		nodes, n, index, err := d.checkCluster()
 		if err != nil {
 			return nil, err
 		}
-		return d.waitNodes(nodes, n)
+		return d.waitNodes(nodes, n, index)
 	}
 	return nil, ErrTooManyRetries
 }
 
-func (d *discovery) waitNodes(nodes client.Nodes, size int) (client.Nodes, error) {
+func (d *discovery) waitNodes(nodes client.Nodes, size int, index uint64) (client.Nodes, error) {
 	if len(nodes) > size {
 		nodes = nodes[:size]
 	}
-	w := d.c.RecursiveWatch(d.cluster, nodes[len(nodes)-1].ModifiedIndex+1)
+	w := d.c.RecursiveWatch(d.cluster, index)
 	all := make(client.Nodes, len(nodes))
 	copy(all, nodes)
 	for _, n := range all {

+ 27 - 7
discovery/discovery_test.go

@@ -18,6 +18,7 @@ package discovery
 
 import (
 	"errors"
+	"math"
 	"math/rand"
 	"net/http"
 	"reflect"
@@ -31,6 +32,10 @@ import (
 	"github.com/coreos/etcd/client"
 )
 
+const (
+	maxRetryInTest = 3
+)
+
 func TestNewProxyFuncUnset(t *testing.T) {
 	pf, err := newProxyFunc("")
 	if pf != nil {
@@ -89,6 +94,7 @@ func TestCheckCluster(t *testing.T) {
 
 	tests := []struct {
 		nodes []*client.Node
+		index uint64
 		werr  error
 		wsize int
 	}{
@@ -102,6 +108,7 @@ func TestCheckCluster(t *testing.T) {
 				{Key: "/1000/3", CreatedIndex: 4},
 				{Key: "/1000/4", CreatedIndex: 5},
 			},
+			5,
 			nil,
 			3,
 		},
@@ -115,6 +122,7 @@ func TestCheckCluster(t *testing.T) {
 				{Key: self, CreatedIndex: 4},
 				{Key: "/1000/4", CreatedIndex: 5},
 			},
+			5,
 			nil,
 			3,
 		},
@@ -128,6 +136,7 @@ func TestCheckCluster(t *testing.T) {
 				{Key: "/1000/4", CreatedIndex: 4},
 				{Key: self, CreatedIndex: 5},
 			},
+			5,
 			ErrFullCluster,
 			3,
 		},
@@ -139,6 +148,7 @@ func TestCheckCluster(t *testing.T) {
 				{Key: "/1000/2", CreatedIndex: 2},
 				{Key: "/1000/3", CreatedIndex: 3},
 			},
+			3,
 			nil,
 			3,
 		},
@@ -150,6 +160,7 @@ func TestCheckCluster(t *testing.T) {
 				{Key: "/1000/3", CreatedIndex: 3},
 				{Key: "/1000/4", CreatedIndex: 4},
 			},
+			3,
 			ErrFullCluster,
 			3,
 		},
@@ -158,12 +169,14 @@ func TestCheckCluster(t *testing.T) {
 			[]*client.Node{
 				{Key: "/1000/_config/size", Value: "bad", CreatedIndex: 1},
 			},
+			0,
 			ErrBadSizeKey,
 			0,
 		},
 		{
 			// no size key
 			[]*client.Node{},
+			0,
 			ErrSizeNotFound,
 			0,
 		},
@@ -172,12 +185,13 @@ func TestCheckCluster(t *testing.T) {
 	for i, tt := range tests {
 		rs := make([]*client.Response, 0)
 		if len(tt.nodes) > 0 {
-			rs = append(rs, &client.Response{Node: tt.nodes[0]})
+			rs = append(rs, &client.Response{Node: tt.nodes[0], Index: tt.index})
 			rs = append(rs, &client.Response{
 				Node: &client.Node{
 					Key:   cluster,
 					Nodes: tt.nodes[1:],
 				},
+				Index: tt.index,
 			})
 		}
 		c := &clientWithResp{rs: rs}
@@ -190,12 +204,12 @@ func TestCheckCluster(t *testing.T) {
 
 		for _, d := range []discovery{d, dRetry} {
 			go func() {
-				for i := uint(1); i <= nRetries; i++ {
+				for i := uint(1); i <= maxRetryInTest; i++ {
 					fc.BlockUntil(1)
 					fc.Advance(time.Second * (0x1 << i))
 				}
 			}()
-			ns, size, err := d.checkCluster()
+			ns, size, index, err := d.checkCluster()
 			if err != tt.werr {
 				t.Errorf("#%d: err = %v, want %v", i, err, tt.werr)
 			}
@@ -205,6 +219,9 @@ func TestCheckCluster(t *testing.T) {
 			if size != tt.wsize {
 				t.Errorf("#%d: size = %v, want %d", i, size, tt.wsize)
 			}
+			if index != tt.index {
+				t.Errorf("#%d: index = %v, want %d", i, index, tt.index)
+			}
 		}
 	}
 }
@@ -278,12 +295,12 @@ func TestWaitNodes(t *testing.T) {
 
 		for _, d := range []*discovery{d, dRetry} {
 			go func() {
-				for i := uint(1); i <= nRetries; i++ {
+				for i := uint(1); i <= maxRetryInTest; i++ {
 					fc.BlockUntil(1)
 					fc.Advance(time.Second * (0x1 << i))
 				}
 			}()
-			g, err := d.waitNodes(tt.nodes, 3)
+			g, err := d.waitNodes(tt.nodes, 3, 0) // we do not care about index in this test
 			if err != nil {
 				t.Errorf("#%d: err = %v, want %v", i, err, nil)
 			}
@@ -368,6 +385,9 @@ func TestSortableNodes(t *testing.T) {
 }
 
 func TestRetryFailure(t *testing.T) {
+	nRetries = maxRetryInTest
+	defer func() { nRetries = math.MaxUint32 }()
+
 	cluster := "1000"
 	c := &clientWithRetry{failTimes: 4}
 	fc := clockwork.NewFakeClock()
@@ -378,12 +398,12 @@ func TestRetryFailure(t *testing.T) {
 		clock:   fc,
 	}
 	go func() {
-		for i := uint(1); i <= nRetries; i++ {
+		for i := uint(1); i <= maxRetryInTest; i++ {
 			fc.BlockUntil(1)
 			fc.Advance(time.Second * (0x1 << i))
 		}
 	}()
-	if _, _, err := d.checkCluster(); err != ErrTooManyRetries {
+	if _, _, _, err := d.checkCluster(); err != ErrTooManyRetries {
 		t.Errorf("err = %v, want %v", err, ErrTooManyRetries)
 	}
 }