Browse Source

Merge pull request #3164 from yichengq/pin-endpoint

client: pin itself to an endpoint that given
Yicheng Qin 10 years ago
parent
commit
6184e271a4
2 changed files with 78 additions and 12 deletions
  1. 25 3
      client/client.go
  2. 53 9
      client/client_test.go

+ 25 - 3
client/client.go

@@ -18,6 +18,7 @@ import (
 	"errors"
 	"fmt"
 	"io/ioutil"
+	"math/rand"
 	"net"
 	"net/http"
 	"net/url"
@@ -131,6 +132,7 @@ type Client interface {
 func New(cfg Config) (Client, error) {
 	c := &httpClusterClient{
 		clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect()),
+		rand:          rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
 	}
 	if cfg.Username != "" {
 		c.credentials = &credentials{
@@ -174,8 +176,10 @@ type httpAction interface {
 type httpClusterClient struct {
 	clientFactory httpClientFactory
 	endpoints     []url.URL
+	pinned        int
 	credentials   *credentials
 	sync.RWMutex
+	rand *rand.Rand
 }
 
 func (c *httpClusterClient) reset(eps []string) error {
@@ -192,7 +196,9 @@ func (c *httpClusterClient) reset(eps []string) error {
 		neps[i] = *u
 	}
 
-	c.endpoints = neps
+	c.endpoints = shuffleEndpoints(c.rand, neps)
+	// TODO: pin old endpoint if possible, and rebalance when new endpoint appears
+	c.pinned = 0
 
 	return nil
 }
@@ -203,6 +209,7 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo
 	leps := len(c.endpoints)
 	eps := make([]url.URL, leps)
 	n := copy(eps, c.endpoints)
+	pinned := c.pinned
 
 	if c.credentials != nil {
 		action = &authedAction{
@@ -224,8 +231,9 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo
 	var body []byte
 	var err error
 
-	for _, ep := range eps {
-		hc := c.clientFactory(ep)
+	for i := pinned; i < leps+pinned; i++ {
+		k := i % leps
+		hc := c.clientFactory(eps[k])
 		resp, body, err = hc.Do(ctx, action)
 		if err != nil {
 			if err == context.DeadlineExceeded || err == context.Canceled {
@@ -236,6 +244,11 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo
 		if resp.StatusCode/100 == 5 {
 			continue
 		}
+		if k != pinned {
+			c.Lock()
+			c.pinned = k
+			c.Unlock()
+		}
 		break
 	}
 
@@ -401,3 +414,12 @@ func (r *redirectedHTTPAction) HTTPRequest(ep url.URL) *http.Request {
 	orig.URL = &r.location
 	return orig
 }
+
+func shuffleEndpoints(r *rand.Rand, eps []url.URL) []url.URL {
+	p := r.Perm(len(eps))
+	neps := make([]url.URL, len(eps))
+	for i, k := range p {
+		neps[i] = eps[k]
+	}
+	return neps
+}

+ 53 - 9
client/client_test.go

@@ -18,9 +18,11 @@ import (
 	"errors"
 	"io"
 	"io/ioutil"
+	"math/rand"
 	"net/http"
 	"net/url"
 	"reflect"
+	"sort"
 	"strings"
 	"testing"
 	"time"
@@ -299,9 +301,10 @@ func TestHTTPClusterClientDo(t *testing.T) {
 	fakeErr := errors.New("fake!")
 	fakeURL := url.URL{}
 	tests := []struct {
-		client   *httpClusterClient
-		wantCode int
-		wantErr  error
+		client     *httpClusterClient
+		wantCode   int
+		wantErr    error
+		wantPinned int
 	}{
 		// first good response short-circuits Do
 		{
@@ -313,6 +316,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 						staticHTTPResponse{err: fakeErr},
 					},
 				),
+				rand: rand.New(rand.NewSource(0)),
 			},
 			wantCode: http.StatusTeapot,
 		},
@@ -327,8 +331,10 @@ func TestHTTPClusterClientDo(t *testing.T) {
 						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
 					},
 				),
+				rand: rand.New(rand.NewSource(0)),
 			},
-			wantCode: http.StatusTeapot,
+			wantCode:   http.StatusTeapot,
+			wantPinned: 1,
 		},
 
 		// context.DeadlineExceeded short-circuits Do
@@ -341,6 +347,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
 					},
 				),
+				rand: rand.New(rand.NewSource(0)),
 			},
 			wantErr: context.DeadlineExceeded,
 		},
@@ -355,6 +362,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
 					},
 				),
+				rand: rand.New(rand.NewSource(0)),
 			},
 			wantErr: context.Canceled,
 		},
@@ -364,6 +372,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 			client: &httpClusterClient{
 				endpoints:     []url.URL{},
 				clientFactory: newHTTPClientFactory(nil, nil),
+				rand:          rand.New(rand.NewSource(0)),
 			},
 			wantErr: ErrNoEndpoints,
 		},
@@ -378,6 +387,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 						staticHTTPResponse{err: fakeErr},
 					},
 				),
+				rand: rand.New(rand.NewSource(0)),
 			},
 			wantErr: fakeErr,
 		},
@@ -392,8 +402,10 @@ func TestHTTPClusterClientDo(t *testing.T) {
 						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
 					},
 				),
+				rand: rand.New(rand.NewSource(0)),
 			},
-			wantCode: http.StatusTeapot,
+			wantCode:   http.StatusTeapot,
+			wantPinned: 1,
 		},
 	}
 
@@ -415,6 +427,10 @@ func TestHTTPClusterClientDo(t *testing.T) {
 			t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode)
 			continue
 		}
+
+		if tt.client.pinned != tt.wantPinned {
+			t.Errorf("#%d: pinned=%d, want=%d", i, tt.client.pinned, tt.wantPinned)
+		}
 	}
 }
 
@@ -671,7 +687,10 @@ func TestHTTPClusterClientSync(t *testing.T) {
 		},
 	})
 
-	hc := &httpClusterClient{clientFactory: cf}
+	hc := &httpClusterClient{
+		clientFactory: cf,
+		rand:          rand.New(rand.NewSource(0)),
+	}
 	err := hc.reset([]string{"http://127.0.0.1:2379"})
 	if err != nil {
 		t.Fatalf("unexpected error during setup: %#v", err)
@@ -688,8 +707,9 @@ func TestHTTPClusterClientSync(t *testing.T) {
 		t.Fatalf("unexpected error during Sync: %#v", err)
 	}
 
-	want = []string{"http://127.0.0.1:4003", "http://127.0.0.1:2379", "http://127.0.0.1:4001", "http://127.0.0.1:4002"}
+	want = []string{"http://127.0.0.1:2379", "http://127.0.0.1:4001", "http://127.0.0.1:4002", "http://127.0.0.1:4003"}
 	got = hc.Endpoints()
+	sort.Sort(sort.StringSlice(got))
 	if !reflect.DeepEqual(want, got) {
 		t.Fatalf("incorrect endpoints post-Sync: want=%#v got=%#v", want, got)
 	}
@@ -711,7 +731,10 @@ func TestHTTPClusterClientSyncFail(t *testing.T) {
 		staticHTTPResponse{err: errors.New("fail!")},
 	})
 
-	hc := &httpClusterClient{clientFactory: cf}
+	hc := &httpClusterClient{
+		clientFactory: cf,
+		rand:          rand.New(rand.NewSource(0)),
+	}
 	err := hc.reset([]string{"http://127.0.0.1:2379"})
 	if err != nil {
 		t.Fatalf("unexpected error during setup: %#v", err)
@@ -744,10 +767,31 @@ func TestHTTPClusterClientResetFail(t *testing.T) {
 	}
 
 	for i, tt := range tests {
-		hc := &httpClusterClient{}
+		hc := &httpClusterClient{rand: rand.New(rand.NewSource(0))}
 		err := hc.reset(tt)
 		if err == nil {
 			t.Errorf("#%d: expected non-nil error", i)
 		}
 	}
 }
+
+func TestHTTPClusterClientResetPinRandom(t *testing.T) {
+	round := 2000
+	pinNum := 0
+	for i := 0; i < round; i++ {
+		hc := &httpClusterClient{rand: rand.New(rand.NewSource(int64(i)))}
+		err := hc.reset([]string{"http://127.0.0.1:4001", "http://127.0.0.1:4002", "http://127.0.0.1:4003"})
+		if err != nil {
+			t.Fatalf("#%d: reset error (%v)", i, err)
+		}
+		if hc.endpoints[hc.pinned].String() == "http://127.0.0.1:4001" {
+			pinNum++
+		}
+	}
+
+	min := 1.0/3.0 - 0.05
+	max := 1.0/3.0 + 0.05
+	if ratio := float64(pinNum) / float64(round); ratio > max || ratio < min {
+		t.Errorf("pinned ratio = %v, want [%v, %v]", ratio, min, max)
+	}
+}