Browse Source

Merge pull request #8519 from heyitsanthony/client-oneshot-failover

client: fail over to next endpoint on oneshot failure
Anthony Romano 8 years ago
parent
commit
80aa810309
2 changed files with 37 additions and 15 deletions
  1. 10 9
      client/client.go
  2. 27 6
      client/client_test.go

+ 10 - 9
client/client.go

@@ -371,12 +371,7 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo
 			if err == context.Canceled || err == context.DeadlineExceeded {
 			if err == context.Canceled || err == context.DeadlineExceeded {
 				return nil, nil, err
 				return nil, nil, err
 			}
 			}
-			if isOneShot {
-				return nil, nil, err
-			}
-			continue
-		}
-		if resp.StatusCode/100 == 5 {
+		} else if resp.StatusCode/100 == 5 {
 			switch resp.StatusCode {
 			switch resp.StatusCode {
 			case http.StatusInternalServerError, http.StatusServiceUnavailable:
 			case http.StatusInternalServerError, http.StatusServiceUnavailable:
 				// TODO: make sure this is a no leader response
 				// TODO: make sure this is a no leader response
@@ -384,10 +379,16 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo
 			default:
 			default:
 				cerr.Errors = append(cerr.Errors, fmt.Errorf("client: etcd member %s returns server error [%s]", eps[k].String(), http.StatusText(resp.StatusCode)))
 				cerr.Errors = append(cerr.Errors, fmt.Errorf("client: etcd member %s returns server error [%s]", eps[k].String(), http.StatusText(resp.StatusCode)))
 			}
 			}
-			if isOneShot {
-				return nil, nil, cerr.Errors[0]
+			err = cerr.Errors[0]
+		}
+		if err != nil {
+			if !isOneShot {
+				continue
 			}
 			}
-			continue
+			c.Lock()
+			c.pinned = (k + 1) % leps
+			c.Unlock()
+			return nil, nil, err
 		}
 		}
 		if k != pinned {
 		if k != pinned {
 			c.Lock()
 			c.Lock()

+ 27 - 6
client/client_test.go

@@ -17,6 +17,7 @@ package client
 import (
 import (
 	"context"
 	"context"
 	"errors"
 	"errors"
+	"fmt"
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"math/rand"
 	"math/rand"
@@ -304,7 +305,9 @@ func TestHTTPClusterClientDo(t *testing.T) {
 	fakeErr := errors.New("fake!")
 	fakeErr := errors.New("fake!")
 	fakeURL := url.URL{}
 	fakeURL := url.URL{}
 	tests := []struct {
 	tests := []struct {
-		client     *httpClusterClient
+		client *httpClusterClient
+		ctx    context.Context
+
 		wantCode   int
 		wantCode   int
 		wantErr    error
 		wantErr    error
 		wantPinned int
 		wantPinned int
@@ -395,10 +398,30 @@ func TestHTTPClusterClientDo(t *testing.T) {
 			wantCode:   http.StatusTeapot,
 			wantCode:   http.StatusTeapot,
 			wantPinned: 1,
 			wantPinned: 1,
 		},
 		},
+
+		// 500-level errors cause one shot Do to fallthrough to next endpoint
+		{
+			client: &httpClusterClient{
+				endpoints: []url.URL{fakeURL, fakeURL},
+				clientFactory: newStaticHTTPClientFactory(
+					[]staticHTTPResponse{
+						{resp: http.Response{StatusCode: http.StatusBadGateway}},
+						{resp: http.Response{StatusCode: http.StatusTeapot}},
+					},
+				),
+				rand: rand.New(rand.NewSource(0)),
+			},
+			ctx:        context.WithValue(context.Background(), &oneShotCtxValue, &oneShotCtxValue),
+			wantErr:    fmt.Errorf("client: etcd member  returns server error [Bad Gateway]"),
+			wantPinned: 1,
+		},
 	}
 	}
 
 
 	for i, tt := range tests {
 	for i, tt := range tests {
-		resp, _, err := tt.client.Do(context.Background(), nil)
+		if tt.ctx == nil {
+			tt.ctx = context.Background()
+		}
+		resp, _, err := tt.client.Do(tt.ctx, nil)
 		if !reflect.DeepEqual(tt.wantErr, err) {
 		if !reflect.DeepEqual(tt.wantErr, err) {
 			t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)
 			t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)
 			continue
 			continue
@@ -407,11 +430,9 @@ func TestHTTPClusterClientDo(t *testing.T) {
 		if resp == nil {
 		if resp == nil {
 			if tt.wantCode != 0 {
 			if tt.wantCode != 0 {
 				t.Errorf("#%d: resp is nil, want=%d", i, tt.wantCode)
 				t.Errorf("#%d: resp is nil, want=%d", i, tt.wantCode)
+				continue
 			}
 			}
-			continue
-		}
-
-		if resp.StatusCode != tt.wantCode {
+		} else if resp.StatusCode != tt.wantCode {
 			t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode)
 			t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode)
 			continue
 			continue
 		}
 		}