Browse Source

Merge remote-tracking branch 'coreos/master' into log-storage-interface

* coreos/master:
  etcdserver: add sender tests
  raft: Only call stableTo when we have ready entries or a snapshot.
  etcdserver: add ID() function to the Server interface.
  sender: use RoundTripper instead of Client in sender
Ben Darnell 11 years ago
parent
commit
39eddd8565

+ 6 - 4
etcdserver/etcdhttp/client_test.go

@@ -92,6 +92,9 @@ type serverRecorder struct {
 	actions []action
 }
 
+func (s *serverRecorder) Start()       {}
+func (s *serverRecorder) Stop()        {}
+func (s *serverRecorder) ID() types.ID { return types.ID(1) }
 func (s *serverRecorder) Do(_ context.Context, r etcdserverpb.Request) (etcdserver.Response, error) {
 	s.actions = append(s.actions, action{name: "Do", params: []interface{}{r}})
 	return etcdserver.Response{}, nil
@@ -100,8 +103,6 @@ func (s *serverRecorder) Process(_ context.Context, m raftpb.Message) error {
 	s.actions = append(s.actions, action{name: "Process", params: []interface{}{m}})
 	return nil
 }
-func (s *serverRecorder) Start() {}
-func (s *serverRecorder) Stop()  {}
 func (s *serverRecorder) AddMember(_ context.Context, m etcdserver.Member) error {
 	s.actions = append(s.actions, action{name: "AddMember", params: []interface{}{m}})
 	return nil
@@ -138,12 +139,13 @@ type resServer struct {
 	res etcdserver.Response
 }
 
+func (rs *resServer) Start()       {}
+func (rs *resServer) Stop()        {}
+func (rs *resServer) ID() types.ID { return types.ID(1) }
 func (rs *resServer) Do(_ context.Context, _ etcdserverpb.Request) (etcdserver.Response, error) {
 	return rs.res, nil
 }
 func (rs *resServer) Process(_ context.Context, _ raftpb.Message) error         { return nil }
-func (rs *resServer) Start()                                                    {}
-func (rs *resServer) Stop()                                                     {}
 func (rs *resServer) AddMember(_ context.Context, _ etcdserver.Member) error    { return nil }
 func (rs *resServer) RemoveMember(_ context.Context, _ uint64) error            { return nil }
 func (rs *resServer) UpdateMember(_ context.Context, _ etcdserver.Member) error { return nil }

+ 3 - 2
etcdserver/etcdhttp/http_test.go

@@ -65,14 +65,15 @@ type errServer struct {
 	err error
 }
 
+func (fs *errServer) Start()       {}
+func (fs *errServer) Stop()        {}
+func (fs *errServer) ID() types.ID { return types.ID(1) }
 func (fs *errServer) Do(ctx context.Context, r etcdserverpb.Request) (etcdserver.Response, error) {
 	return etcdserver.Response{}, fs.err
 }
 func (fs *errServer) Process(ctx context.Context, m raftpb.Message) error {
 	return fs.err
 }
-func (fs *errServer) Start() {}
-func (fs *errServer) Stop()  {}
 func (fs *errServer) AddMember(ctx context.Context, m etcdserver.Member) error {
 	return fs.err
 }

+ 8 - 7
etcdserver/sender.go

@@ -100,9 +100,8 @@ func (h *sendHub) Add(m *Member) {
 	}
 	// TODO: considering how to switch between all available peer urls
 	u := fmt.Sprintf("%s%s", m.PickPeerURL(), raftPrefix)
-	c := &http.Client{Transport: h.tr}
 	fs := h.ls.Follower(m.ID.String())
-	s := newSender(u, h.cl.ID(), c, fs)
+	s := newSender(h.tr, u, h.cl.ID(), fs)
 	h.senders[m.ID] = s
 }
 
@@ -129,19 +128,19 @@ func (h *sendHub) Update(m *Member) {
 }
 
 type sender struct {
+	tr  http.RoundTripper
 	u   string
 	cid types.ID
-	c   *http.Client
 	fs  *stats.FollowerStats
 	q   chan []byte
 	mu  sync.RWMutex
 }
 
-func newSender(u string, cid types.ID, c *http.Client, fs *stats.FollowerStats) *sender {
+func newSender(tr http.RoundTripper, u string, cid types.ID, fs *stats.FollowerStats) *sender {
 	s := &sender{
+		tr:  tr,
 		u:   u,
 		cid: cid,
-		c:   c,
 		fs:  fs,
 		q:   make(chan []byte),
 	}
@@ -151,11 +150,13 @@ func newSender(u string, cid types.ID, c *http.Client, fs *stats.FollowerStats)
 	return s
 }
 
-func (s *sender) send(data []byte) {
+func (s *sender) send(data []byte) error {
 	select {
 	case s.q <- data:
+		return nil
 	default:
 		log.Printf("sender: reach the maximal serving to %s", s.u)
+		return fmt.Errorf("reach maximal serving")
 	}
 }
 
@@ -188,7 +189,7 @@ func (s *sender) post(data []byte) error {
 	}
 	req.Header.Set("Content-Type", "application/protobuf")
 	req.Header.Set("X-Etcd-Cluster-ID", s.cid.String())
-	resp, err := s.c.Do(req)
+	resp, err := s.tr.RoundTrip(req)
 	if err != nil {
 		return fmt.Errorf("error posting to %q: %v", req.URL.String(), err)
 	}

+ 186 - 0
etcdserver/sender_test.go

@@ -17,9 +17,15 @@
 package etcdserver
 
 import (
+	"errors"
+	"io/ioutil"
+	"net/http"
+	"sync"
 	"testing"
+	"time"
 
 	"github.com/coreos/etcd/etcdserver/stats"
+	"github.com/coreos/etcd/pkg/testutil"
 	"github.com/coreos/etcd/pkg/types"
 )
 
@@ -82,3 +88,183 @@ func TestSendHubRemove(t *testing.T) {
 		t.Fatalf("senders[1] exists, want removed")
 	}
 }
+
+// TestSenderSend tests that send func could post data using roundtripper
+// and increase success count in stats.
+func TestSenderSend(t *testing.T) {
+	tr := &roundTripperRecorder{}
+	fs := &stats.FollowerStats{}
+	s := newSender(tr, "http://10.0.0.1", types.ID(1), fs)
+	// wait for handle goroutines start
+	// TODO: wait for goroutines ready before return newSender
+	time.Sleep(10 * time.Millisecond)
+	if err := s.send([]byte("some data")); err != nil {
+		t.Fatalf("unexpect send error: %v", err)
+	}
+	s.stop()
+	// wait for goroutines end
+	// TODO: elegant stop
+	time.Sleep(10 * time.Millisecond)
+
+	if tr.Request() == nil {
+		t.Errorf("sender fails to post the data")
+	}
+	fs.Lock()
+	defer fs.Unlock()
+	if fs.Counts.Success != 1 {
+		t.Errorf("success = %d, want 1", fs.Counts.Success)
+	}
+}
+
+func TestSenderExceedMaximalServing(t *testing.T) {
+	tr := newRoundTripperBlocker()
+	fs := &stats.FollowerStats{}
+	s := newSender(tr, "http://10.0.0.1", types.ID(1), fs)
+	// wait for handle goroutines start
+	// TODO: wait for goroutines ready before return newSender
+	time.Sleep(10 * time.Millisecond)
+	// It could handle that many requests at the same time.
+	for i := 0; i < connPerSender; i++ {
+		if err := s.send([]byte("some data")); err != nil {
+			t.Errorf("send err = %v, want nil", err)
+		}
+	}
+	// This one exceeds its maximal serving ability
+	if err := s.send([]byte("some data")); err == nil {
+		t.Errorf("unexpect send success")
+	}
+	tr.unblock()
+	// Make handles finish their post
+	testutil.ForceGosched()
+	// It could send new data after previous ones succeed
+	if err := s.send([]byte("some data")); err != nil {
+		t.Errorf("send err = %v, want nil", err)
+	}
+	s.stop()
+}
+
+// TestSenderSendFailed tests that when send func meets the post error,
+// it increases fail count in stats.
+func TestSenderSendFailed(t *testing.T) {
+	fs := &stats.FollowerStats{}
+	s := newSender(newRespRoundTripper(0, errors.New("blah")), "http://10.0.0.1", types.ID(1), fs)
+	// wait for handle goroutines start
+	// TODO: wait for goroutines ready before return newSender
+	time.Sleep(10 * time.Millisecond)
+	if err := s.send([]byte("some data")); err != nil {
+		t.Fatalf("unexpect send error: %v", err)
+	}
+	s.stop()
+	// wait for goroutines end
+	// TODO: elegant stop
+	time.Sleep(10 * time.Millisecond)
+
+	fs.Lock()
+	defer fs.Unlock()
+	if fs.Counts.Fail != 1 {
+		t.Errorf("fail = %d, want 1", fs.Counts.Fail)
+	}
+}
+
+func TestSenderPost(t *testing.T) {
+	tr := &roundTripperRecorder{}
+	s := newSender(tr, "http://10.0.0.1", types.ID(1), nil)
+	if err := s.post([]byte("some data")); err != nil {
+		t.Fatalf("unexpect post error: %v", err)
+	}
+	s.stop()
+
+	if g := tr.Request().Method; g != "POST" {
+		t.Errorf("method = %s, want %s", g, "POST")
+	}
+	if g := tr.Request().URL.String(); g != "http://10.0.0.1" {
+		t.Errorf("url = %s, want %s", g, "http://10.0.0.1")
+	}
+	if g := tr.Request().Header.Get("Content-Type"); g != "application/protobuf" {
+		t.Errorf("content type = %s, want %s", g, "application/protobuf")
+	}
+	if g := tr.Request().Header.Get("X-Etcd-Cluster-ID"); g != "1" {
+		t.Errorf("cluster id = %s, want %s", g, "1")
+	}
+	b, err := ioutil.ReadAll(tr.Request().Body)
+	if err != nil {
+		t.Fatalf("unexpected ReadAll error: %v", err)
+	}
+	if string(b) != "some data" {
+		t.Errorf("body = %s, want %s", b, "some data")
+	}
+}
+
+func TestSenderPostBad(t *testing.T) {
+	tests := []struct {
+		u    string
+		code int
+		err  error
+	}{
+		// bad url
+		{":bad url", http.StatusNoContent, nil},
+		// RoundTrip returns error
+		{"http://10.0.0.1", 0, errors.New("blah")},
+		// unexpected response status code
+		{"http://10.0.0.1", http.StatusOK, nil},
+		{"http://10.0.0.1", http.StatusCreated, nil},
+	}
+	for i, tt := range tests {
+		s := newSender(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), nil)
+		err := s.post([]byte("some data"))
+		s.stop()
+
+		if err == nil {
+			t.Errorf("#%d: err = nil, want not nil", i)
+		}
+	}
+}
+
+type roundTripperBlocker struct {
+	c chan struct{}
+}
+
+func newRoundTripperBlocker() *roundTripperBlocker {
+	return &roundTripperBlocker{c: make(chan struct{})}
+}
+func (t *roundTripperBlocker) RoundTrip(req *http.Request) (*http.Response, error) {
+	<-t.c
+	return &http.Response{StatusCode: http.StatusNoContent, Body: &nopReadCloser{}}, nil
+}
+func (t *roundTripperBlocker) unblock() {
+	close(t.c)
+}
+
+type respRoundTripper struct {
+	code int
+	err  error
+}
+
+func newRespRoundTripper(code int, err error) *respRoundTripper {
+	return &respRoundTripper{code: code, err: err}
+}
+func (t *respRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	return &http.Response{StatusCode: t.code, Body: &nopReadCloser{}}, t.err
+}
+
+type roundTripperRecorder struct {
+	req *http.Request
+	sync.Mutex
+}
+
+func (t *roundTripperRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
+	t.Lock()
+	defer t.Unlock()
+	t.req = req
+	return &http.Response{StatusCode: http.StatusNoContent, Body: &nopReadCloser{}}, nil
+}
+func (t *roundTripperRecorder) Request() *http.Request {
+	t.Lock()
+	defer t.Unlock()
+	return t.req
+}
+
+type nopReadCloser struct{}
+
+func (n *nopReadCloser) Read(p []byte) (int, error) { return 0, nil }
+func (n *nopReadCloser) Close() error               { return nil }

+ 4 - 0
etcdserver/server.go

@@ -115,6 +115,8 @@ type Server interface {
 	// Stop terminates the Server and performs any necessary finalization.
 	// Do and Process cannot be called after Stop has been invoked.
 	Stop()
+	// ID returns the ID of the Server.
+	ID() types.ID
 	// Do takes a request and attempts to fulfill it, returning a Response.
 	Do(ctx context.Context, r pb.Request) (Response, error)
 	// Process takes a raft message and applies it to the server's raft state
@@ -309,6 +311,8 @@ func (s *EtcdServer) start() {
 	go s.run()
 }
 
+func (s *EtcdServer) ID() types.ID { return s.id }
+
 func (s *EtcdServer) Process(ctx context.Context, m raftpb.Message) error {
 	if s.Cluster.IsIDRemoved(types.ID(m.From)) {
 		return ErrRemoved

+ 7 - 1
raft/node.go

@@ -225,6 +225,7 @@ func (n *node) run(r *raft) {
 	var readyc chan Ready
 	var advancec chan struct{}
 	var prevLastUnstablei uint64
+	var havePrevLastUnstablei bool
 	var rd Ready
 
 	lead := None
@@ -293,6 +294,7 @@ func (n *node) run(r *raft) {
 			}
 			if len(rd.Entries) > 0 {
 				prevLastUnstablei = rd.Entries[len(rd.Entries)-1].Index
+				havePrevLastUnstablei = true
 			}
 			if !IsEmptyHardState(rd.HardState) {
 				prevHardSt = rd.HardState
@@ -301,6 +303,7 @@ func (n *node) run(r *raft) {
 				prevSnapi = rd.Snapshot.Index
 				if prevSnapi > prevLastUnstablei {
 					prevLastUnstablei = prevSnapi
+					havePrevLastUnstablei = true
 				}
 			}
 			r.msgs = nil
@@ -309,7 +312,10 @@ func (n *node) run(r *raft) {
 			if prevHardSt.Commit != 0 {
 				r.raftLog.appliedTo(prevHardSt.Commit)
 			}
-			r.raftLog.stableTo(prevLastUnstablei)
+			if havePrevLastUnstablei {
+				r.raftLog.stableTo(prevLastUnstablei)
+				havePrevLastUnstablei = false
+			}
 			advancec = nil
 		case <-n.stop:
 			close(n.done)

+ 4 - 4
raft/node_test.go

@@ -346,7 +346,7 @@ func TestNodeStart(t *testing.T) {
 	select {
 	case rd := <-n.Ready():
 		t.Errorf("unexpected Ready: %+v", rd)
-	default:
+	case <-time.After(time.Millisecond):
 	}
 }
 
@@ -375,7 +375,7 @@ func TestNodeRestart(t *testing.T) {
 	select {
 	case rd := <-n.Ready():
 		t.Errorf("unexpected Ready: %+v", rd)
-	default:
+	case <-time.After(time.Millisecond):
 	}
 }
 
@@ -448,13 +448,13 @@ func TestNodeAdvance(t *testing.T) {
 	select {
 	case rd = <-n.Ready():
 		t.Fatalf("unexpected Ready before Advance: %+v", rd)
-	default:
+	case <-time.After(time.Millisecond):
 	}
 	storage.Append(rd.Entries)
 	n.Advance()
 	select {
 	case <-n.Ready():
-	default:
+	case <-time.After(time.Millisecond):
 		t.Errorf("expect Ready after Advance, but there is no Ready available")
 	}
 }