Browse Source

Merge pull request #1700 from yichengq/222

etcdserver: add sender tests
Yicheng Qin 11 years ago
parent
commit
9716def94b
2 changed files with 194 additions and 7 deletions
  1. 8 7
      etcdserver/sender.go
  2. 186 0
      etcdserver/sender_test.go

+ 8 - 7
etcdserver/sender.go

@@ -100,9 +100,8 @@ func (h *sendHub) Add(m *Member) {
 	}
 	// TODO: considering how to switch between all available peer urls
 	u := fmt.Sprintf("%s%s", m.PickPeerURL(), raftPrefix)
-	c := &http.Client{Transport: h.tr}
 	fs := h.ls.Follower(m.ID.String())
-	s := newSender(u, h.cl.ID(), c, fs)
+	s := newSender(h.tr, u, h.cl.ID(), fs)
 	h.senders[m.ID] = s
 }
 
@@ -129,19 +128,19 @@ func (h *sendHub) Update(m *Member) {
 }
 
 type sender struct {
+	tr  http.RoundTripper
 	u   string
 	cid types.ID
-	c   *http.Client
 	fs  *stats.FollowerStats
 	q   chan []byte
 	mu  sync.RWMutex
 }
 
-func newSender(u string, cid types.ID, c *http.Client, fs *stats.FollowerStats) *sender {
+func newSender(tr http.RoundTripper, u string, cid types.ID, fs *stats.FollowerStats) *sender {
 	s := &sender{
+		tr:  tr,
 		u:   u,
 		cid: cid,
-		c:   c,
 		fs:  fs,
 		q:   make(chan []byte),
 	}
@@ -151,11 +150,13 @@ func newSender(u string, cid types.ID, c *http.Client, fs *stats.FollowerStats)
 	return s
 }
 
-func (s *sender) send(data []byte) {
+func (s *sender) send(data []byte) error {
 	select {
 	case s.q <- data:
+		return nil
 	default:
 		log.Printf("sender: reach the maximal serving to %s", s.u)
+		return fmt.Errorf("reach maximal serving")
 	}
 }
 
@@ -188,7 +189,7 @@ func (s *sender) post(data []byte) error {
 	}
 	req.Header.Set("Content-Type", "application/protobuf")
 	req.Header.Set("X-Etcd-Cluster-ID", s.cid.String())
-	resp, err := s.c.Do(req)
+	resp, err := s.tr.RoundTrip(req)
 	if err != nil {
 		return fmt.Errorf("error posting to %q: %v", req.URL.String(), err)
 	}

+ 186 - 0
etcdserver/sender_test.go

@@ -17,9 +17,15 @@
 package etcdserver
 
 import (
+	"errors"
+	"io/ioutil"
+	"net/http"
+	"sync"
 	"testing"
+	"time"
 
 	"github.com/coreos/etcd/etcdserver/stats"
+	"github.com/coreos/etcd/pkg/testutil"
 	"github.com/coreos/etcd/pkg/types"
 )
 
@@ -82,3 +88,183 @@ func TestSendHubRemove(t *testing.T) {
 		t.Fatalf("senders[1] exists, want removed")
 	}
 }
+
+// TestSenderSend tests that send func could post data using roundtripper
+// and increase success count in stats.
+func TestSenderSend(t *testing.T) {
+	tr := &roundTripperRecorder{}
+	fs := &stats.FollowerStats{}
+	s := newSender(tr, "http://10.0.0.1", types.ID(1), fs)
+	// wait for handle goroutines start
+	// TODO: wait for goroutines ready before return newSender
+	time.Sleep(10 * time.Millisecond)
+	if err := s.send([]byte("some data")); err != nil {
+		t.Fatalf("unexpect send error: %v", err)
+	}
+	s.stop()
+	// wait for goroutines end
+	// TODO: elegant stop
+	time.Sleep(10 * time.Millisecond)
+
+	if tr.Request() == nil {
+		t.Errorf("sender fails to post the data")
+	}
+	fs.Lock()
+	defer fs.Unlock()
+	if fs.Counts.Success != 1 {
+		t.Errorf("success = %d, want 1", fs.Counts.Success)
+	}
+}
+
+func TestSenderExceedMaximalServing(t *testing.T) {
+	tr := newRoundTripperBlocker()
+	fs := &stats.FollowerStats{}
+	s := newSender(tr, "http://10.0.0.1", types.ID(1), fs)
+	// wait for handle goroutines start
+	// TODO: wait for goroutines ready before return newSender
+	time.Sleep(10 * time.Millisecond)
+	// It could handle that many requests at the same time.
+	for i := 0; i < connPerSender; i++ {
+		if err := s.send([]byte("some data")); err != nil {
+			t.Errorf("send err = %v, want nil", err)
+		}
+	}
+	// This one exceeds its maximal serving ability
+	if err := s.send([]byte("some data")); err == nil {
+		t.Errorf("unexpect send success")
+	}
+	tr.unblock()
+	// Make handles finish their post
+	testutil.ForceGosched()
+	// It could send new data after previous ones succeed
+	if err := s.send([]byte("some data")); err != nil {
+		t.Errorf("send err = %v, want nil", err)
+	}
+	s.stop()
+}
+
+// TestSenderSendFailed tests that when send func meets the post error,
+// it increases fail count in stats.
+func TestSenderSendFailed(t *testing.T) {
+	fs := &stats.FollowerStats{}
+	s := newSender(newRespRoundTripper(0, errors.New("blah")), "http://10.0.0.1", types.ID(1), fs)
+	// wait for handle goroutines start
+	// TODO: wait for goroutines ready before return newSender
+	time.Sleep(10 * time.Millisecond)
+	if err := s.send([]byte("some data")); err != nil {
+		t.Fatalf("unexpect send error: %v", err)
+	}
+	s.stop()
+	// wait for goroutines end
+	// TODO: elegant stop
+	time.Sleep(10 * time.Millisecond)
+
+	fs.Lock()
+	defer fs.Unlock()
+	if fs.Counts.Fail != 1 {
+		t.Errorf("fail = %d, want 1", fs.Counts.Fail)
+	}
+}
+
+func TestSenderPost(t *testing.T) {
+	tr := &roundTripperRecorder{}
+	s := newSender(tr, "http://10.0.0.1", types.ID(1), nil)
+	if err := s.post([]byte("some data")); err != nil {
+		t.Fatalf("unexpect post error: %v", err)
+	}
+	s.stop()
+
+	if g := tr.Request().Method; g != "POST" {
+		t.Errorf("method = %s, want %s", g, "POST")
+	}
+	if g := tr.Request().URL.String(); g != "http://10.0.0.1" {
+		t.Errorf("url = %s, want %s", g, "http://10.0.0.1")
+	}
+	if g := tr.Request().Header.Get("Content-Type"); g != "application/protobuf" {
+		t.Errorf("content type = %s, want %s", g, "application/protobuf")
+	}
+	if g := tr.Request().Header.Get("X-Etcd-Cluster-ID"); g != "1" {
+		t.Errorf("cluster id = %s, want %s", g, "1")
+	}
+	b, err := ioutil.ReadAll(tr.Request().Body)
+	if err != nil {
+		t.Fatalf("unexpected ReadAll error: %v", err)
+	}
+	if string(b) != "some data" {
+		t.Errorf("body = %s, want %s", b, "some data")
+	}
+}
+
+func TestSenderPostBad(t *testing.T) {
+	tests := []struct {
+		u    string
+		code int
+		err  error
+	}{
+		// bad url
+		{":bad url", http.StatusNoContent, nil},
+		// RoundTrip returns error
+		{"http://10.0.0.1", 0, errors.New("blah")},
+		// unexpected response status code
+		{"http://10.0.0.1", http.StatusOK, nil},
+		{"http://10.0.0.1", http.StatusCreated, nil},
+	}
+	for i, tt := range tests {
+		s := newSender(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), nil)
+		err := s.post([]byte("some data"))
+		s.stop()
+
+		if err == nil {
+			t.Errorf("#%d: err = nil, want not nil", i)
+		}
+	}
+}
+
+type roundTripperBlocker struct {
+	c chan struct{}
+}
+
+func newRoundTripperBlocker() *roundTripperBlocker {
+	return &roundTripperBlocker{c: make(chan struct{})}
+}
+func (t *roundTripperBlocker) RoundTrip(req *http.Request) (*http.Response, error) {
+	<-t.c
+	return &http.Response{StatusCode: http.StatusNoContent, Body: &nopReadCloser{}}, nil
+}
+func (t *roundTripperBlocker) unblock() {
+	close(t.c)
+}
+
+type respRoundTripper struct {
+	code int
+	err  error
+}
+
+func newRespRoundTripper(code int, err error) *respRoundTripper {
+	return &respRoundTripper{code: code, err: err}
+}
+func (t *respRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	return &http.Response{StatusCode: t.code, Body: &nopReadCloser{}}, t.err
+}
+
+type roundTripperRecorder struct {
+	req *http.Request
+	sync.Mutex
+}
+
+func (t *roundTripperRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
+	t.Lock()
+	defer t.Unlock()
+	t.req = req
+	return &http.Response{StatusCode: http.StatusNoContent, Body: &nopReadCloser{}}, nil
+}
+func (t *roundTripperRecorder) Request() *http.Request {
+	t.Lock()
+	defer t.Unlock()
+	return t.req
+}
+
+type nopReadCloser struct{}
+
+func (n *nopReadCloser) Read(p []byte) (int, error) { return 0, nil }
+func (n *nopReadCloser) Close() error               { return nil }