Browse Source

client: pin itself to an endpoint that given

1. When reset endpoints, client will choose a random endpoint to pin.
2. If the pinned endpoint is healthy, client will keep using it.
3. If the pinned endpoint becomes unhealthy, client will attempt other
endpoints and update its pin.
Yicheng Qin 10 năm trước cách đây
mục cha
commit
ea2347a40f
2 tập tin đã thay đổi với 78 bổ sung12 xóa
  1. 25 3
      client/client.go
  2. 53 9
      client/client_test.go

+ 25 - 3
client/client.go

@@ -18,6 +18,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
+	"math/rand"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
@@ -131,6 +132,7 @@ type Client interface {
 func New(cfg Config) (Client, error) {
 func New(cfg Config) (Client, error) {
 	c := &httpClusterClient{
 	c := &httpClusterClient{
 		clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect()),
 		clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect()),
+		rand:          rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
 	}
 	}
 	if cfg.Username != "" {
 	if cfg.Username != "" {
 		c.credentials = &credentials{
 		c.credentials = &credentials{
@@ -174,8 +176,10 @@ type httpAction interface {
 type httpClusterClient struct {
 type httpClusterClient struct {
 	clientFactory httpClientFactory
 	clientFactory httpClientFactory
 	endpoints     []url.URL
 	endpoints     []url.URL
+	pinned        int
 	credentials   *credentials
 	credentials   *credentials
 	sync.RWMutex
 	sync.RWMutex
+	rand *rand.Rand
 }
 }
 
 
 func (c *httpClusterClient) reset(eps []string) error {
 func (c *httpClusterClient) reset(eps []string) error {
@@ -192,7 +196,9 @@ func (c *httpClusterClient) reset(eps []string) error {
 		neps[i] = *u
 		neps[i] = *u
 	}
 	}
 
 
-	c.endpoints = neps
+	c.endpoints = shuffleEndpoints(c.rand, neps)
+	// TODO: pin old endpoint if possible, and rebalance when new endpoint appears
+	c.pinned = 0
 
 
 	return nil
 	return nil
 }
 }
@@ -203,6 +209,7 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo
 	leps := len(c.endpoints)
 	leps := len(c.endpoints)
 	eps := make([]url.URL, leps)
 	eps := make([]url.URL, leps)
 	n := copy(eps, c.endpoints)
 	n := copy(eps, c.endpoints)
+	pinned := c.pinned
 
 
 	if c.credentials != nil {
 	if c.credentials != nil {
 		action = &authedAction{
 		action = &authedAction{
@@ -224,8 +231,9 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo
 	var body []byte
 	var body []byte
 	var err error
 	var err error
 
 
-	for _, ep := range eps {
-		hc := c.clientFactory(ep)
+	for i := pinned; i < leps+pinned; i++ {
+		k := i % leps
+		hc := c.clientFactory(eps[k])
 		resp, body, err = hc.Do(ctx, action)
 		resp, body, err = hc.Do(ctx, action)
 		if err != nil {
 		if err != nil {
 			if err == context.DeadlineExceeded || err == context.Canceled {
 			if err == context.DeadlineExceeded || err == context.Canceled {
@@ -236,6 +244,11 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo
 		if resp.StatusCode/100 == 5 {
 		if resp.StatusCode/100 == 5 {
 			continue
 			continue
 		}
 		}
+		if k != pinned {
+			c.Lock()
+			c.pinned = k
+			c.Unlock()
+		}
 		break
 		break
 	}
 	}
 
 
@@ -401,3 +414,12 @@ func (r *redirectedHTTPAction) HTTPRequest(ep url.URL) *http.Request {
 	orig.URL = &r.location
 	orig.URL = &r.location
 	return orig
 	return orig
 }
 }
+
+func shuffleEndpoints(r *rand.Rand, eps []url.URL) []url.URL {
+	p := r.Perm(len(eps))
+	neps := make([]url.URL, len(eps))
+	for i, k := range p {
+		neps[i] = eps[k]
+	}
+	return neps
+}

+ 53 - 9
client/client_test.go

@@ -18,9 +18,11 @@ import (
 	"errors"
 	"errors"
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
+	"math/rand"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
 	"reflect"
 	"reflect"
+	"sort"
 	"strings"
 	"strings"
 	"testing"
 	"testing"
 	"time"
 	"time"
@@ -299,9 +301,10 @@ func TestHTTPClusterClientDo(t *testing.T) {
 	fakeErr := errors.New("fake!")
 	fakeErr := errors.New("fake!")
 	fakeURL := url.URL{}
 	fakeURL := url.URL{}
 	tests := []struct {
 	tests := []struct {
-		client   *httpClusterClient
-		wantCode int
-		wantErr  error
+		client     *httpClusterClient
+		wantCode   int
+		wantErr    error
+		wantPinned int
 	}{
 	}{
 		// first good response short-circuits Do
 		// first good response short-circuits Do
 		{
 		{
@@ -313,6 +316,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 						staticHTTPResponse{err: fakeErr},
 						staticHTTPResponse{err: fakeErr},
 					},
 					},
 				),
 				),
+				rand: rand.New(rand.NewSource(0)),
 			},
 			},
 			wantCode: http.StatusTeapot,
 			wantCode: http.StatusTeapot,
 		},
 		},
@@ -327,8 +331,10 @@ func TestHTTPClusterClientDo(t *testing.T) {
 						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
 						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
 					},
 					},
 				),
 				),
+				rand: rand.New(rand.NewSource(0)),
 			},
 			},
-			wantCode: http.StatusTeapot,
+			wantCode:   http.StatusTeapot,
+			wantPinned: 1,
 		},
 		},
 
 
 		// context.DeadlineExceeded short-circuits Do
 		// context.DeadlineExceeded short-circuits Do
@@ -341,6 +347,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
 						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
 					},
 					},
 				),
 				),
+				rand: rand.New(rand.NewSource(0)),
 			},
 			},
 			wantErr: context.DeadlineExceeded,
 			wantErr: context.DeadlineExceeded,
 		},
 		},
@@ -355,6 +362,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
 						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
 					},
 					},
 				),
 				),
+				rand: rand.New(rand.NewSource(0)),
 			},
 			},
 			wantErr: context.Canceled,
 			wantErr: context.Canceled,
 		},
 		},
@@ -364,6 +372,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 			client: &httpClusterClient{
 			client: &httpClusterClient{
 				endpoints:     []url.URL{},
 				endpoints:     []url.URL{},
 				clientFactory: newHTTPClientFactory(nil, nil),
 				clientFactory: newHTTPClientFactory(nil, nil),
+				rand:          rand.New(rand.NewSource(0)),
 			},
 			},
 			wantErr: ErrNoEndpoints,
 			wantErr: ErrNoEndpoints,
 		},
 		},
@@ -378,6 +387,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
 						staticHTTPResponse{err: fakeErr},
 						staticHTTPResponse{err: fakeErr},
 					},
 					},
 				),
 				),
+				rand: rand.New(rand.NewSource(0)),
 			},
 			},
 			wantErr: fakeErr,
 			wantErr: fakeErr,
 		},
 		},
@@ -392,8 +402,10 @@ func TestHTTPClusterClientDo(t *testing.T) {
 						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
 						staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
 					},
 					},
 				),
 				),
+				rand: rand.New(rand.NewSource(0)),
 			},
 			},
-			wantCode: http.StatusTeapot,
+			wantCode:   http.StatusTeapot,
+			wantPinned: 1,
 		},
 		},
 	}
 	}
 
 
@@ -415,6 +427,10 @@ func TestHTTPClusterClientDo(t *testing.T) {
 			t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode)
 			t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode)
 			continue
 			continue
 		}
 		}
+
+		if tt.client.pinned != tt.wantPinned {
+			t.Errorf("#%d: pinned=%d, want=%d", i, tt.client.pinned, tt.wantPinned)
+		}
 	}
 	}
 }
 }
 
 
@@ -671,7 +687,10 @@ func TestHTTPClusterClientSync(t *testing.T) {
 		},
 		},
 	})
 	})
 
 
-	hc := &httpClusterClient{clientFactory: cf}
+	hc := &httpClusterClient{
+		clientFactory: cf,
+		rand:          rand.New(rand.NewSource(0)),
+	}
 	err := hc.reset([]string{"http://127.0.0.1:2379"})
 	err := hc.reset([]string{"http://127.0.0.1:2379"})
 	if err != nil {
 	if err != nil {
 		t.Fatalf("unexpected error during setup: %#v", err)
 		t.Fatalf("unexpected error during setup: %#v", err)
@@ -688,8 +707,9 @@ func TestHTTPClusterClientSync(t *testing.T) {
 		t.Fatalf("unexpected error during Sync: %#v", err)
 		t.Fatalf("unexpected error during Sync: %#v", err)
 	}
 	}
 
 
-	want = []string{"http://127.0.0.1:4003", "http://127.0.0.1:2379", "http://127.0.0.1:4001", "http://127.0.0.1:4002"}
+	want = []string{"http://127.0.0.1:2379", "http://127.0.0.1:4001", "http://127.0.0.1:4002", "http://127.0.0.1:4003"}
 	got = hc.Endpoints()
 	got = hc.Endpoints()
+	sort.Sort(sort.StringSlice(got))
 	if !reflect.DeepEqual(want, got) {
 	if !reflect.DeepEqual(want, got) {
 		t.Fatalf("incorrect endpoints post-Sync: want=%#v got=%#v", want, got)
 		t.Fatalf("incorrect endpoints post-Sync: want=%#v got=%#v", want, got)
 	}
 	}
@@ -711,7 +731,10 @@ func TestHTTPClusterClientSyncFail(t *testing.T) {
 		staticHTTPResponse{err: errors.New("fail!")},
 		staticHTTPResponse{err: errors.New("fail!")},
 	})
 	})
 
 
-	hc := &httpClusterClient{clientFactory: cf}
+	hc := &httpClusterClient{
+		clientFactory: cf,
+		rand:          rand.New(rand.NewSource(0)),
+	}
 	err := hc.reset([]string{"http://127.0.0.1:2379"})
 	err := hc.reset([]string{"http://127.0.0.1:2379"})
 	if err != nil {
 	if err != nil {
 		t.Fatalf("unexpected error during setup: %#v", err)
 		t.Fatalf("unexpected error during setup: %#v", err)
@@ -744,10 +767,31 @@ func TestHTTPClusterClientResetFail(t *testing.T) {
 	}
 	}
 
 
 	for i, tt := range tests {
 	for i, tt := range tests {
-		hc := &httpClusterClient{}
+		hc := &httpClusterClient{rand: rand.New(rand.NewSource(0))}
 		err := hc.reset(tt)
 		err := hc.reset(tt)
 		if err == nil {
 		if err == nil {
 			t.Errorf("#%d: expected non-nil error", i)
 			t.Errorf("#%d: expected non-nil error", i)
 		}
 		}
 	}
 	}
 }
 }
+
+func TestHTTPClusterClientResetPinRandom(t *testing.T) {
+	round := 2000
+	pinNum := 0
+	for i := 0; i < round; i++ {
+		hc := &httpClusterClient{rand: rand.New(rand.NewSource(int64(i)))}
+		err := hc.reset([]string{"http://127.0.0.1:4001", "http://127.0.0.1:4002", "http://127.0.0.1:4003"})
+		if err != nil {
+			t.Fatalf("#%d: reset error (%v)", i, err)
+		}
+		if hc.endpoints[hc.pinned].String() == "http://127.0.0.1:4001" {
+			pinNum++
+		}
+	}
+
+	min := 1.0/3.0 - 0.05
+	max := 1.0/3.0 + 0.05
+	if ratio := float64(pinNum) / float64(round); ratio > max || ratio < min {
+		t.Errorf("pinned ratio = %v, want [%v, %v]", ratio, min, max)
+	}
+}