Browse Source

Merge pull request #1669 from yichengq/215

*: add rafthttp as a separate package
Yicheng Qin 11 years ago
parent
commit
a2c568a144

+ 2 - 59
etcdserver/etcdhttp/peer.go

@@ -18,14 +18,11 @@ package etcdhttp
 
 
 import (
 import (
 	"encoding/json"
 	"encoding/json"
-	"io/ioutil"
 	"log"
 	"log"
 	"net/http"
 	"net/http"
 
 
-	"github.com/coreos/etcd/Godeps/_workspace/src/code.google.com/p/go.net/context"
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/etcdserver"
-	"github.com/coreos/etcd/pkg/types"
-	"github.com/coreos/etcd/raft/raftpb"
+	"github.com/coreos/etcd/rafthttp"
 )
 )
 
 
 const (
 const (
@@ -35,12 +32,7 @@ const (
 
 
 // NewPeerHandler generates an http.Handler to handle etcd peer (raft) requests.
 // NewPeerHandler generates an http.Handler to handle etcd peer (raft) requests.
 func NewPeerHandler(server *etcdserver.EtcdServer) http.Handler {
 func NewPeerHandler(server *etcdserver.EtcdServer) http.Handler {
-	rh := &raftHandler{
-		stats:       server,
-		server:      server,
-		clusterInfo: server.Cluster,
-	}
-
+	rh := rafthttp.NewHandler(server, server.Cluster.ID())
 	mh := &peerMembersHandler{
 	mh := &peerMembersHandler{
 		clusterInfo: server.Cluster,
 		clusterInfo: server.Cluster,
 	}
 	}
@@ -52,55 +44,6 @@ func NewPeerHandler(server *etcdserver.EtcdServer) http.Handler {
 	return mux
 	return mux
 }
 }
 
 
-type raftHandler struct {
-	stats       etcdserver.Stats
-	server      etcdserver.Server
-	clusterInfo etcdserver.ClusterInfo
-}
-
-func (h *raftHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	if !allowMethod(w, r.Method, "POST") {
-		return
-	}
-
-	wcid := h.clusterInfo.ID().String()
-	w.Header().Set("X-Etcd-Cluster-ID", wcid)
-
-	gcid := r.Header.Get("X-Etcd-Cluster-ID")
-	if gcid != wcid {
-		log.Printf("etcdhttp: request ignored due to cluster ID mismatch got %s want %s", gcid, wcid)
-		http.Error(w, "clusterID mismatch", http.StatusPreconditionFailed)
-		return
-	}
-
-	b, err := ioutil.ReadAll(r.Body)
-	if err != nil {
-		log.Println("etcdhttp: error reading raft message:", err)
-		http.Error(w, "error reading raft message", http.StatusBadRequest)
-		return
-	}
-	var m raftpb.Message
-	if err := m.Unmarshal(b); err != nil {
-		log.Println("etcdhttp: error unmarshaling raft message:", err)
-		http.Error(w, "error unmarshaling raft message", http.StatusBadRequest)
-		return
-	}
-	if err := h.server.Process(context.TODO(), m); err != nil {
-		switch err {
-		case etcdserver.ErrRemoved:
-			log.Printf("etcdhttp: reject message from removed member %s", types.ID(m.From).String())
-			http.Error(w, "cannot process message from removed member", http.StatusForbidden)
-		default:
-			writeError(w, err)
-		}
-		return
-	}
-	if m.Type == raftpb.MsgApp {
-		h.stats.UpdateRecvApp(types.ID(m.From), r.ContentLength)
-	}
-	w.WriteHeader(http.StatusNoContent)
-}
-
 type peerMembersHandler struct {
 type peerMembersHandler struct {
 	clusterInfo etcdserver.ClusterInfo
 	clusterInfo etcdserver.ClusterInfo
 }
 }

+ 0 - 150
etcdserver/etcdhttp/peer_test.go

@@ -17,165 +17,15 @@
 package etcdhttp
 package etcdhttp
 
 
 import (
 import (
-	"bytes"
 	"encoding/json"
 	"encoding/json"
-	"errors"
-	"io"
 	"net/http"
 	"net/http"
 	"net/http/httptest"
 	"net/http/httptest"
 	"path"
 	"path"
-	"strings"
 	"testing"
 	"testing"
 
 
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/etcdserver"
-	"github.com/coreos/etcd/raft/raftpb"
 )
 )
 
 
-func mustMarshalMsg(t *testing.T, m raftpb.Message) []byte {
-	json, err := m.Marshal()
-	if err != nil {
-		t.Fatalf("error marshalling raft Message: %#v", err)
-	}
-	return json
-}
-
-// errReader implements io.Reader to facilitate a broken request.
-type errReader struct{}
-
-func (er *errReader) Read(_ []byte) (int, error) { return 0, errors.New("some error") }
-
-func TestServeRaft(t *testing.T) {
-	testCases := []struct {
-		method    string
-		body      io.Reader
-		serverErr error
-		clusterID string
-
-		wcode int
-	}{
-		{
-			// bad method
-			"GET",
-			bytes.NewReader(
-				mustMarshalMsg(
-					t,
-					raftpb.Message{},
-				),
-			),
-			nil,
-			"0",
-			http.StatusMethodNotAllowed,
-		},
-		{
-			// bad method
-			"PUT",
-			bytes.NewReader(
-				mustMarshalMsg(
-					t,
-					raftpb.Message{},
-				),
-			),
-			nil,
-			"0",
-			http.StatusMethodNotAllowed,
-		},
-		{
-			// bad method
-			"DELETE",
-			bytes.NewReader(
-				mustMarshalMsg(
-					t,
-					raftpb.Message{},
-				),
-			),
-			nil,
-			"0",
-			http.StatusMethodNotAllowed,
-		},
-		{
-			// bad request body
-			"POST",
-			&errReader{},
-			nil,
-			"0",
-			http.StatusBadRequest,
-		},
-		{
-			// bad request protobuf
-			"POST",
-			strings.NewReader("malformed garbage"),
-			nil,
-			"0",
-			http.StatusBadRequest,
-		},
-		{
-			// good request, etcdserver.Server internal error
-			"POST",
-			bytes.NewReader(
-				mustMarshalMsg(
-					t,
-					raftpb.Message{},
-				),
-			),
-			errors.New("some error"),
-			"0",
-			http.StatusInternalServerError,
-		},
-		{
-			// good request from removed member
-			"POST",
-			bytes.NewReader(
-				mustMarshalMsg(
-					t,
-					raftpb.Message{},
-				),
-			),
-			etcdserver.ErrRemoved,
-			"0",
-			http.StatusForbidden,
-		},
-		{
-			// good request
-			"POST",
-			bytes.NewReader(
-				mustMarshalMsg(
-					t,
-					raftpb.Message{},
-				),
-			),
-			nil,
-			"1",
-			http.StatusPreconditionFailed,
-		},
-		{
-			// good request
-			"POST",
-			bytes.NewReader(
-				mustMarshalMsg(
-					t,
-					raftpb.Message{},
-				),
-			),
-			nil,
-			"0",
-			http.StatusNoContent,
-		},
-	}
-	for i, tt := range testCases {
-		req, err := http.NewRequest(tt.method, "foo", tt.body)
-		if err != nil {
-			t.Fatalf("#%d: could not create request: %#v", i, err)
-		}
-		req.Header.Set("X-Etcd-Cluster-ID", tt.clusterID)
-		rw := httptest.NewRecorder()
-		h := &raftHandler{stats: nil, server: &errServer{tt.serverErr}, clusterInfo: &fakeCluster{id: 0}}
-		h.ServeHTTP(rw, req)
-		if rw.Code != tt.wcode {
-			t.Errorf("#%d: got code=%d, want %d", i, rw.Code, tt.wcode)
-		}
-	}
-}
-
 func TestServeMembersFails(t *testing.T) {
 func TestServeMembersFails(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
 		method string
 		method string

+ 131 - 0
etcdserver/sendhub.go

@@ -0,0 +1,131 @@
+/*
+   Copyright 2014 CoreOS, Inc.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package etcdserver
+
+import (
+	"log"
+	"net/http"
+	"net/url"
+	"path"
+
+	"github.com/coreos/etcd/etcdserver/stats"
+	"github.com/coreos/etcd/pkg/types"
+	"github.com/coreos/etcd/raft/raftpb"
+	"github.com/coreos/etcd/rafthttp"
+)
+
+const (
+	raftPrefix = "/raft"
+)
+
+type sendHub struct {
+	tr         http.RoundTripper
+	cl         ClusterInfo
+	ss         *stats.ServerStats
+	ls         *stats.LeaderStats
+	senders    map[types.ID]rafthttp.Sender
+	shouldstop chan struct{}
+}
+
+// newSendHub creates the default send hub used to transport raft messages
+// to other members. The returned sendHub will update the given ServerStats and
+// LeaderStats appropriately.
+func newSendHub(t http.RoundTripper, cl ClusterInfo, ss *stats.ServerStats, ls *stats.LeaderStats) *sendHub {
+	h := &sendHub{
+		tr:         t,
+		cl:         cl,
+		ss:         ss,
+		ls:         ls,
+		senders:    make(map[types.ID]rafthttp.Sender),
+		shouldstop: make(chan struct{}, 1),
+	}
+	for _, m := range cl.Members() {
+		h.Add(m)
+	}
+	return h
+}
+
+func (h *sendHub) Send(msgs []raftpb.Message) {
+	for _, m := range msgs {
+		to := types.ID(m.To)
+		s, ok := h.senders[to]
+		if !ok {
+			if !h.cl.IsIDRemoved(to) {
+				log.Printf("etcdserver: send message to unknown receiver %s", to)
+			}
+			continue
+		}
+
+		// TODO: don't block. we should be able to have 1000s
+		// of messages out at a time.
+		data, err := m.Marshal()
+		if err != nil {
+			log.Println("sender: dropping message:", err)
+			return // drop bad message
+		}
+		if m.Type == raftpb.MsgApp {
+			h.ss.SendAppendReq(len(data))
+		}
+
+		s.Send(data)
+	}
+}
+
+func (h *sendHub) Stop() {
+	for _, s := range h.senders {
+		s.Stop()
+	}
+}
+
+func (h *sendHub) ShouldStopNotify() <-chan struct{} {
+	return h.shouldstop
+}
+
+func (h *sendHub) Add(m *Member) {
+	if _, ok := h.senders[m.ID]; ok {
+		return
+	}
+	// TODO: considering how to switch between all available peer urls
+	peerURL := m.PickPeerURL()
+	u, err := url.Parse(peerURL)
+	if err != nil {
+		log.Panicf("unexpect peer url %s", peerURL)
+	}
+	u.Path = path.Join(u.Path, raftPrefix)
+	fs := h.ls.Follower(m.ID.String())
+	s := rafthttp.NewSender(h.tr, u.String(), h.cl.ID(), fs, h.shouldstop)
+	h.senders[m.ID] = s
+}
+
+func (h *sendHub) Remove(id types.ID) {
+	h.senders[id].Stop()
+	delete(h.senders, id)
+}
+
+func (h *sendHub) Update(m *Member) {
+	// TODO: return error or just panic?
+	if _, ok := h.senders[m.ID]; !ok {
+		return
+	}
+	peerURL := m.PickPeerURL()
+	u, err := url.Parse(peerURL)
+	if err != nil {
+		log.Panicf("unexpect peer url %s", peerURL)
+	}
+	u.Path = path.Join(u.Path, raftPrefix)
+	h.senders[m.ID].Update(u.String())
+}

+ 126 - 0
etcdserver/sendhub_test.go

@@ -0,0 +1,126 @@
+/*
+   Copyright 2014 CoreOS, Inc.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package etcdserver
+
+import (
+	"net/http"
+	"testing"
+	"time"
+
+	"github.com/coreos/etcd/etcdserver/stats"
+	"github.com/coreos/etcd/pkg/testutil"
+	"github.com/coreos/etcd/pkg/types"
+)
+
+func TestSendHubInitSenders(t *testing.T) {
+	membs := []*Member{
+		newTestMember(1, []string{"http://a"}, "", nil),
+		newTestMember(2, []string{"http://b"}, "", nil),
+		newTestMember(3, []string{"http://c"}, "", nil),
+	}
+	cl := newTestCluster(membs)
+	ls := stats.NewLeaderStats("")
+	h := newSendHub(nil, cl, nil, ls)
+
+	ids := cl.MemberIDs()
+	if len(h.senders) != len(ids) {
+		t.Errorf("len(ids) = %d, want %d", len(h.senders), len(ids))
+	}
+	for _, id := range ids {
+		if _, ok := h.senders[id]; !ok {
+			t.Errorf("senders[%s] is nil, want exists", id)
+		}
+	}
+}
+
+func TestSendHubAdd(t *testing.T) {
+	cl := newTestCluster(nil)
+	ls := stats.NewLeaderStats("")
+	h := newSendHub(nil, cl, nil, ls)
+	m := newTestMember(1, []string{"http://a"}, "", nil)
+	h.Add(m)
+
+	if _, ok := ls.Followers["1"]; !ok {
+		t.Errorf("FollowerStats[1] is nil, want exists")
+	}
+	s, ok := h.senders[types.ID(1)]
+	if !ok {
+		t.Fatalf("senders[1] is nil, want exists")
+	}
+
+	h.Add(m)
+	ns := h.senders[types.ID(1)]
+	if s != ns {
+		t.Errorf("sender = %p, want %p", ns, s)
+	}
+}
+
+func TestSendHubRemove(t *testing.T) {
+	membs := []*Member{
+		newTestMember(1, []string{"http://a"}, "", nil),
+	}
+	cl := newTestCluster(membs)
+	ls := stats.NewLeaderStats("")
+	h := newSendHub(nil, cl, nil, ls)
+	h.Remove(types.ID(1))
+
+	if _, ok := h.senders[types.ID(1)]; ok {
+		t.Fatalf("senders[1] exists, want removed")
+	}
+}
+
+func TestSendHubShouldStop(t *testing.T) {
+	membs := []*Member{
+		newTestMember(1, []string{"http://a"}, "", nil),
+	}
+	tr := newRespRoundTripper(http.StatusForbidden, nil)
+	cl := newTestCluster(membs)
+	ls := stats.NewLeaderStats("")
+	h := newSendHub(tr, cl, nil, ls)
+
+	shouldstop := h.ShouldStopNotify()
+	select {
+	case <-shouldstop:
+		t.Fatalf("received unexpected shouldstop notification")
+	case <-time.After(10 * time.Millisecond):
+	}
+	h.senders[1].Send([]byte("somedata"))
+
+	testutil.ForceGosched()
+	select {
+	case <-shouldstop:
+	default:
+		t.Fatalf("cannot receive stop notification")
+	}
+}
+
+type respRoundTripper struct {
+	code int
+	err  error
+}
+
+func newRespRoundTripper(code int, err error) *respRoundTripper {
+	return &respRoundTripper{code: code, err: err}
+}
+func (t *respRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	return &http.Response{StatusCode: t.code, Body: &nopReadCloser{}}, t.err
+}
+
+type nopReadCloser struct{}
+
+func (n *nopReadCloser) Read(p []byte) (int, error) { return 0, nil }
+func (n *nopReadCloser) Close() error               { return nil }

+ 6 - 8
etcdserver/server.go

@@ -33,6 +33,7 @@ import (
 
 
 	"github.com/coreos/etcd/Godeps/_workspace/src/code.google.com/p/go.net/context"
 	"github.com/coreos/etcd/Godeps/_workspace/src/code.google.com/p/go.net/context"
 	"github.com/coreos/etcd/discovery"
 	"github.com/coreos/etcd/discovery"
+	"github.com/coreos/etcd/etcdserver/etcdhttp/httptypes"
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
 	"github.com/coreos/etcd/etcdserver/stats"
 	"github.com/coreos/etcd/etcdserver/stats"
 	"github.com/coreos/etcd/pkg/pbutil"
 	"github.com/coreos/etcd/pkg/pbutil"
@@ -61,7 +62,6 @@ const (
 var (
 var (
 	ErrUnknownMethod = errors.New("etcdserver: unknown method")
 	ErrUnknownMethod = errors.New("etcdserver: unknown method")
 	ErrStopped       = errors.New("etcdserver: server stopped")
 	ErrStopped       = errors.New("etcdserver: server stopped")
-	ErrRemoved       = errors.New("etcdserver: server removed")
 	ErrIDRemoved     = errors.New("etcdserver: ID removed")
 	ErrIDRemoved     = errors.New("etcdserver: ID removed")
 	ErrIDExists      = errors.New("etcdserver: ID exists")
 	ErrIDExists      = errors.New("etcdserver: ID exists")
 	ErrIDNotFound    = errors.New("etcdserver: ID not found")
 	ErrIDNotFound    = errors.New("etcdserver: ID not found")
@@ -145,8 +145,6 @@ type Stats interface {
 	LeaderStats() []byte
 	LeaderStats() []byte
 	// StoreStats returns statistics of the store backing this EtcdServer
 	// StoreStats returns statistics of the store backing this EtcdServer
 	StoreStats() []byte
 	StoreStats() []byte
-	// UpdateRecvApp updates the underlying statistics in response to a receiving an Append request
-	UpdateRecvApp(from types.ID, length int64)
 }
 }
 
 
 type RaftTimer interface {
 type RaftTimer interface {
@@ -320,7 +318,11 @@ func (s *EtcdServer) ID() types.ID { return s.id }
 
 
 func (s *EtcdServer) Process(ctx context.Context, m raftpb.Message) error {
 func (s *EtcdServer) Process(ctx context.Context, m raftpb.Message) error {
 	if s.Cluster.IsIDRemoved(types.ID(m.From)) {
 	if s.Cluster.IsIDRemoved(types.ID(m.From)) {
-		return ErrRemoved
+		log.Printf("etcdserver: reject message from removed member %s", types.ID(m.From).String())
+		return httptypes.NewHTTPError(http.StatusForbidden, "cannot process message from removed member")
+	}
+	if m.Type == raftpb.MsgApp {
+		s.stats.RecvAppendReq(types.ID(m.From).String(), m.Size())
 	}
 	}
 	return s.node.Step(ctx, m)
 	return s.node.Step(ctx, m)
 }
 }
@@ -488,10 +490,6 @@ func (s *EtcdServer) LeaderStats() []byte {
 
 
 func (s *EtcdServer) StoreStats() []byte { return s.store.JsonStats() }
 func (s *EtcdServer) StoreStats() []byte { return s.store.JsonStats() }
 
 
-func (s *EtcdServer) UpdateRecvApp(from types.ID, length int64) {
-	s.stats.RecvAppendReq(from.String(), int(length))
-}
-
 func (s *EtcdServer) AddMember(ctx context.Context, memb Member) error {
 func (s *EtcdServer) AddMember(ctx context.Context, memb Member) error {
 	// TODO: move Member to protobuf type
 	// TODO: move Member to protobuf type
 	b, err := json.Marshal(memb)
 	b, err := json.Marshal(memb)

+ 90 - 0
rafthttp/http.go

@@ -0,0 +1,90 @@
+/*
+   Copyright 2014 CoreOS, Inc.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package rafthttp
+
+import (
+	"io/ioutil"
+	"log"
+	"net/http"
+
+	"github.com/coreos/etcd/pkg/types"
+	"github.com/coreos/etcd/raft/raftpb"
+
+	"github.com/coreos/etcd/Godeps/_workspace/src/code.google.com/p/go.net/context"
+)
+
+type Processor interface {
+	Process(ctx context.Context, m raftpb.Message) error
+}
+
+func NewHandler(p Processor, cid types.ID) http.Handler {
+	return &handler{
+		p:   p,
+		cid: cid,
+	}
+}
+
+type handler struct {
+	p   Processor
+	cid types.ID
+}
+
+func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	if r.Method != "POST" {
+		w.Header().Set("Allow", "POST")
+		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
+		return
+	}
+
+	wcid := h.cid.String()
+	w.Header().Set("X-Etcd-Cluster-ID", wcid)
+
+	gcid := r.Header.Get("X-Etcd-Cluster-ID")
+	if gcid != wcid {
+		log.Printf("rafthttp: request ignored due to cluster ID mismatch got %s want %s", gcid, wcid)
+		http.Error(w, "clusterID mismatch", http.StatusPreconditionFailed)
+		return
+	}
+
+	b, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		log.Println("rafthttp: error reading raft message:", err)
+		http.Error(w, "error reading raft message", http.StatusBadRequest)
+		return
+	}
+	var m raftpb.Message
+	if err := m.Unmarshal(b); err != nil {
+		log.Println("rafthttp: error unmarshaling raft message:", err)
+		http.Error(w, "error unmarshaling raft message", http.StatusBadRequest)
+		return
+	}
+	if err := h.p.Process(context.TODO(), m); err != nil {
+		switch v := err.(type) {
+		case writerToResponse:
+			v.WriteTo(w)
+		default:
+			log.Printf("rafthttp: error processing raft message: %v", err)
+			http.Error(w, "error processing raft message", http.StatusInternalServerError)
+		}
+		return
+	}
+	w.WriteHeader(http.StatusNoContent)
+}
+
+type writerToResponse interface {
+	WriteTo(w http.ResponseWriter)
+}

+ 184 - 0
rafthttp/http_test.go

@@ -0,0 +1,184 @@
+/*
+   Copyright 2014 CoreOS, Inc.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package rafthttp
+
+import (
+	"bytes"
+	"errors"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+
+	"github.com/coreos/etcd/pkg/pbutil"
+	"github.com/coreos/etcd/pkg/types"
+	"github.com/coreos/etcd/raft/raftpb"
+
+	"github.com/coreos/etcd/Godeps/_workspace/src/code.google.com/p/go.net/context"
+)
+
+func TestServeRaft(t *testing.T) {
+	testCases := []struct {
+		method    string
+		body      io.Reader
+		p         Processor
+		clusterID string
+
+		wcode int
+	}{
+		{
+			// bad method
+			"GET",
+			bytes.NewReader(
+				pbutil.MustMarshal(&raftpb.Message{}),
+			),
+			&nopProcessor{},
+			"0",
+			http.StatusMethodNotAllowed,
+		},
+		{
+			// bad method
+			"PUT",
+			bytes.NewReader(
+				pbutil.MustMarshal(&raftpb.Message{}),
+			),
+			&nopProcessor{},
+			"0",
+			http.StatusMethodNotAllowed,
+		},
+		{
+			// bad method
+			"DELETE",
+			bytes.NewReader(
+				pbutil.MustMarshal(&raftpb.Message{}),
+			),
+			&nopProcessor{},
+			"0",
+			http.StatusMethodNotAllowed,
+		},
+		{
+			// bad request body
+			"POST",
+			&errReader{},
+			&nopProcessor{},
+			"0",
+			http.StatusBadRequest,
+		},
+		{
+			// bad request protobuf
+			"POST",
+			strings.NewReader("malformed garbage"),
+			&nopProcessor{},
+			"0",
+			http.StatusBadRequest,
+		},
+		{
+			// good request, wrong cluster ID
+			"POST",
+			bytes.NewReader(
+				pbutil.MustMarshal(&raftpb.Message{}),
+			),
+			&nopProcessor{},
+			"1",
+			http.StatusPreconditionFailed,
+		},
+		{
+			// good request, Processor failure
+			"POST",
+			bytes.NewReader(
+				pbutil.MustMarshal(&raftpb.Message{}),
+			),
+			&errProcessor{
+				err: &resWriterToError{code: http.StatusForbidden},
+			},
+			"0",
+			http.StatusForbidden,
+		},
+		{
+			// good request, Processor failure
+			"POST",
+			bytes.NewReader(
+				pbutil.MustMarshal(&raftpb.Message{}),
+			),
+			&errProcessor{
+				err: &resWriterToError{code: http.StatusInternalServerError},
+			},
+			"0",
+			http.StatusInternalServerError,
+		},
+		{
+			// good request, Processor failure
+			"POST",
+			bytes.NewReader(
+				pbutil.MustMarshal(&raftpb.Message{}),
+			),
+			&errProcessor{err: errors.New("blah")},
+			"0",
+			http.StatusInternalServerError,
+		},
+		{
+			// good request
+			"POST",
+			bytes.NewReader(
+				pbutil.MustMarshal(&raftpb.Message{}),
+			),
+			&nopProcessor{},
+			"0",
+			http.StatusNoContent,
+		},
+	}
+	for i, tt := range testCases {
+		req, err := http.NewRequest(tt.method, "foo", tt.body)
+		if err != nil {
+			t.Fatalf("#%d: could not create request: %#v", i, err)
+		}
+		req.Header.Set("X-Etcd-Cluster-ID", tt.clusterID)
+		rw := httptest.NewRecorder()
+		h := NewHandler(tt.p, types.ID(0), &nopStats{})
+		h.ServeHTTP(rw, req)
+		if rw.Code != tt.wcode {
+			t.Errorf("#%d: got code=%d, want %d", i, rw.Code, tt.wcode)
+		}
+	}
+}
+
+// errReader implements io.Reader to facilitate a broken request.
+type errReader struct{}
+
+func (er *errReader) Read(_ []byte) (int, error) { return 0, errors.New("some error") }
+
+type nopProcessor struct{}
+
+func (p *nopProcessor) Process(ctx context.Context, m raftpb.Message) error { return nil }
+
+type errProcessor struct {
+	err error
+}
+
+func (p *errProcessor) Process(ctx context.Context, m raftpb.Message) error { return p.err }
+
+type nopStats struct{}
+
+func (s *nopStats) UpdateRecvApp(from types.ID, length int64) {}
+
+type resWriterToError struct {
+	code int
+}
+
+func (e *resWriterToError) Error() string                 { return "" }
+func (e *resWriterToError) WriteTo(w http.ResponseWriter) { w.WriteHeader(e.code) }

+ 28 - 112
etcdserver/sender.go → rafthttp/sender.go

@@ -14,124 +14,49 @@
    limitations under the License.
    limitations under the License.
 */
 */
 
 
-package etcdserver
+package rafthttp
 
 
 import (
 import (
 	"bytes"
 	"bytes"
 	"fmt"
 	"fmt"
 	"log"
 	"log"
 	"net/http"
 	"net/http"
-	"net/url"
-	"path"
 	"sync"
 	"sync"
 	"time"
 	"time"
 
 
 	"github.com/coreos/etcd/etcdserver/stats"
 	"github.com/coreos/etcd/etcdserver/stats"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/pkg/types"
-	"github.com/coreos/etcd/raft/raftpb"
 )
 )
 
 
 const (
 const (
-	raftPrefix    = "/raft"
 	connPerSender = 4
 	connPerSender = 4
 	senderBufSize = connPerSender * 4
 	senderBufSize = connPerSender * 4
 )
 )
 
 
-type sendHub struct {
-	tr         http.RoundTripper
-	cl         ClusterInfo
-	ss         *stats.ServerStats
-	ls         *stats.LeaderStats
-	senders    map[types.ID]*sender
-	shouldstop chan struct{}
-}
-
-// newSendHub creates the default send hub used to transport raft messages
-// to other members. The returned sendHub will update the given ServerStats and
-// LeaderStats appropriately.
-func newSendHub(t http.RoundTripper, cl ClusterInfo, ss *stats.ServerStats, ls *stats.LeaderStats) *sendHub {
-	h := &sendHub{
-		tr:         t,
-		cl:         cl,
-		ss:         ss,
-		ls:         ls,
-		senders:    make(map[types.ID]*sender),
-		shouldstop: make(chan struct{}, 1),
-	}
-	for _, m := range cl.Members() {
-		h.Add(m)
-	}
-	return h
-}
-
-func (h *sendHub) Send(msgs []raftpb.Message) {
-	for _, m := range msgs {
-		to := types.ID(m.To)
-		s, ok := h.senders[to]
-		if !ok {
-			if !h.cl.IsIDRemoved(to) {
-				log.Printf("etcdserver: send message to unknown receiver %s", to)
-			}
-			continue
-		}
-
-		// TODO: don't block. we should be able to have 1000s
-		// of messages out at a time.
-		data, err := m.Marshal()
-		if err != nil {
-			log.Println("sender: dropping message:", err)
-			return // drop bad message
-		}
-		if m.Type == raftpb.MsgApp {
-			h.ss.SendAppendReq(len(data))
-		}
-
-		// TODO (xiangli): reasonable retry logic
-		s.send(data)
-	}
-}
-
-func (h *sendHub) Stop() {
-	for _, s := range h.senders {
-		s.stop()
-	}
-}
-
-func (h *sendHub) ShouldStopNotify() <-chan struct{} {
-	return h.shouldstop
+type Sender interface {
+	Update(u string)
+	// Send sends the data to the remote node. It is always non-blocking.
+	// It may be fail to send data if it returns nil error.
+	Send(data []byte) error
+	// Stop performs any necessary finalization and terminates the Sender
+	// elegantly.
+	Stop()
 }
 }
 
 
-func (h *sendHub) Add(m *Member) {
-	if _, ok := h.senders[m.ID]; ok {
-		return
-	}
-	// TODO: considering how to switch between all available peer urls
-	u := fmt.Sprintf("%s%s", m.PickPeerURL(), raftPrefix)
-	fs := h.ls.Follower(m.ID.String())
-	s := newSender(h.tr, u, h.cl.ID(), fs, h.shouldstop)
-	h.senders[m.ID] = s
-}
-
-func (h *sendHub) Remove(id types.ID) {
-	h.senders[id].stop()
-	delete(h.senders, id)
-}
-
-func (h *sendHub) Update(m *Member) {
-	// TODO: return error or just panic?
-	if _, ok := h.senders[m.ID]; !ok {
-		return
+func NewSender(tr http.RoundTripper, u string, cid types.ID, fs *stats.FollowerStats, shouldstop chan struct{}) *sender {
+	s := &sender{
+		tr:         tr,
+		u:          u,
+		cid:        cid,
+		fs:         fs,
+		q:          make(chan []byte, senderBufSize),
+		shouldstop: shouldstop,
 	}
 	}
-	peerURL := m.PickPeerURL()
-	u, err := url.Parse(peerURL)
-	if err != nil {
-		log.Panicf("unexpect peer url %s", peerURL)
+	s.wg.Add(connPerSender)
+	for i := 0; i < connPerSender; i++ {
+		go s.handle()
 	}
 	}
-	u.Path = path.Join(u.Path, raftPrefix)
-	s := h.senders[m.ID]
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	s.u = u.String()
+	return s
 }
 }
 
 
 type sender struct {
 type sender struct {
@@ -145,23 +70,14 @@ type sender struct {
 	shouldstop chan struct{}
 	shouldstop chan struct{}
 }
 }
 
 
-func newSender(tr http.RoundTripper, u string, cid types.ID, fs *stats.FollowerStats, shouldstop chan struct{}) *sender {
-	s := &sender{
-		tr:         tr,
-		u:          u,
-		cid:        cid,
-		fs:         fs,
-		q:          make(chan []byte, senderBufSize),
-		shouldstop: shouldstop,
-	}
-	s.wg.Add(connPerSender)
-	for i := 0; i < connPerSender; i++ {
-		go s.handle()
-	}
-	return s
+func (s *sender) Update(u string) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.u = u
 }
 }
 
 
-func (s *sender) send(data []byte) error {
+// TODO (xiangli): reasonable retry logic
+func (s *sender) Send(data []byte) error {
 	select {
 	select {
 	case s.q <- data:
 	case s.q <- data:
 		return nil
 		return nil
@@ -171,7 +87,7 @@ func (s *sender) send(data []byte) error {
 	}
 	}
 }
 }
 
 
-func (s *sender) stop() {
+func (s *sender) Stop() {
 	close(s.q)
 	close(s.q)
 	s.wg.Wait()
 	s.wg.Wait()
 }
 }

+ 19 - 105
etcdserver/sender_test.go → rafthttp/sender_test.go

@@ -14,7 +14,7 @@
    limitations under the License.
    limitations under the License.
 */
 */
 
 
-package etcdserver
+package rafthttp
 
 
 import (
 import (
 	"errors"
 	"errors"
@@ -22,109 +22,23 @@ import (
 	"net/http"
 	"net/http"
 	"sync"
 	"sync"
 	"testing"
 	"testing"
-	"time"
 
 
 	"github.com/coreos/etcd/etcdserver/stats"
 	"github.com/coreos/etcd/etcdserver/stats"
 	"github.com/coreos/etcd/pkg/testutil"
 	"github.com/coreos/etcd/pkg/testutil"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/pkg/types"
 )
 )
 
 
-func TestSendHubInitSenders(t *testing.T) {
-	membs := []*Member{
-		newTestMember(1, []string{"http://a"}, "", nil),
-		newTestMember(2, []string{"http://b"}, "", nil),
-		newTestMember(3, []string{"http://c"}, "", nil),
-	}
-	cl := newTestCluster(membs)
-	ls := stats.NewLeaderStats("")
-	h := newSendHub(nil, cl, nil, ls)
-
-	ids := cl.MemberIDs()
-	if len(h.senders) != len(ids) {
-		t.Errorf("len(ids) = %d, want %d", len(h.senders), len(ids))
-	}
-	for _, id := range ids {
-		if _, ok := h.senders[id]; !ok {
-			t.Errorf("senders[%s] is nil, want exists", id)
-		}
-	}
-}
-
-func TestSendHubAdd(t *testing.T) {
-	cl := newTestCluster(nil)
-	ls := stats.NewLeaderStats("")
-	h := newSendHub(nil, cl, nil, ls)
-	m := newTestMember(1, []string{"http://a"}, "", nil)
-	h.Add(m)
-
-	if _, ok := ls.Followers["1"]; !ok {
-		t.Errorf("FollowerStats[1] is nil, want exists")
-	}
-	s, ok := h.senders[types.ID(1)]
-	if !ok {
-		t.Fatalf("senders[1] is nil, want exists")
-	}
-	if s.u != "http://a/raft" {
-		t.Errorf("url = %s, want %s", s.u, "http://a/raft")
-	}
-
-	h.Add(m)
-	ns := h.senders[types.ID(1)]
-	if s != ns {
-		t.Errorf("sender = %p, want %p", ns, s)
-	}
-}
-
-func TestSendHubRemove(t *testing.T) {
-	membs := []*Member{
-		newTestMember(1, []string{"http://a"}, "", nil),
-	}
-	cl := newTestCluster(membs)
-	ls := stats.NewLeaderStats("")
-	h := newSendHub(nil, cl, nil, ls)
-	h.Remove(types.ID(1))
-
-	if _, ok := h.senders[types.ID(1)]; ok {
-		t.Fatalf("senders[1] exists, want removed")
-	}
-}
-
-func TestSendHubShouldStop(t *testing.T) {
-	membs := []*Member{
-		newTestMember(1, []string{"http://a"}, "", nil),
-	}
-	tr := newRespRoundTripper(http.StatusForbidden, nil)
-	cl := newTestCluster(membs)
-	ls := stats.NewLeaderStats("")
-	h := newSendHub(tr, cl, nil, ls)
-
-	shouldstop := h.ShouldStopNotify()
-	select {
-	case <-shouldstop:
-		t.Fatalf("received unexpected shouldstop notification")
-	case <-time.After(10 * time.Millisecond):
-	}
-	h.senders[1].send([]byte("somedata"))
-
-	testutil.ForceGosched()
-	select {
-	case <-shouldstop:
-	default:
-		t.Fatalf("cannot receive stop notification")
-	}
-}
-
 // TestSenderSend tests that send func could post data using roundtripper
 // TestSenderSend tests that send func could post data using roundtripper
 // and increase success count in stats.
 // and increase success count in stats.
 func TestSenderSend(t *testing.T) {
 func TestSenderSend(t *testing.T) {
 	tr := &roundTripperRecorder{}
 	tr := &roundTripperRecorder{}
 	fs := &stats.FollowerStats{}
 	fs := &stats.FollowerStats{}
-	s := newSender(tr, "http://10.0.0.1", types.ID(1), fs, nil)
+	s := NewSender(tr, "http://10.0.0.1", types.ID(1), fs, nil)
 
 
-	if err := s.send([]byte("some data")); err != nil {
+	if err := s.Send([]byte("some data")); err != nil {
 		t.Fatalf("unexpect send error: %v", err)
 		t.Fatalf("unexpect send error: %v", err)
 	}
 	}
-	s.stop()
+	s.Stop()
 
 
 	if tr.Request() == nil {
 	if tr.Request() == nil {
 		t.Errorf("sender fails to post the data")
 		t.Errorf("sender fails to post the data")
@@ -139,12 +53,12 @@ func TestSenderSend(t *testing.T) {
 func TestSenderExceedMaximalServing(t *testing.T) {
 func TestSenderExceedMaximalServing(t *testing.T) {
 	tr := newRoundTripperBlocker()
 	tr := newRoundTripperBlocker()
 	fs := &stats.FollowerStats{}
 	fs := &stats.FollowerStats{}
-	s := newSender(tr, "http://10.0.0.1", types.ID(1), fs, nil)
+	s := NewSender(tr, "http://10.0.0.1", types.ID(1), fs, nil)
 
 
 	// keep the sender busy and make the buffer full
 	// keep the sender busy and make the buffer full
 	// nothing can go out as we block the sender
 	// nothing can go out as we block the sender
 	for i := 0; i < connPerSender+senderBufSize; i++ {
 	for i := 0; i < connPerSender+senderBufSize; i++ {
-		if err := s.send([]byte("some data")); err != nil {
+		if err := s.Send([]byte("some data")); err != nil {
 			t.Errorf("send err = %v, want nil", err)
 			t.Errorf("send err = %v, want nil", err)
 		}
 		}
 		// force the sender to grab data
 		// force the sender to grab data
@@ -152,7 +66,7 @@ func TestSenderExceedMaximalServing(t *testing.T) {
 	}
 	}
 
 
 	// try to send a data when we are sure the buffer is full
 	// try to send a data when we are sure the buffer is full
-	if err := s.send([]byte("some data")); err == nil {
+	if err := s.Send([]byte("some data")); err == nil {
 		t.Errorf("unexpect send success")
 		t.Errorf("unexpect send success")
 	}
 	}
 
 
@@ -161,22 +75,22 @@ func TestSenderExceedMaximalServing(t *testing.T) {
 	testutil.ForceGosched()
 	testutil.ForceGosched()
 
 
 	// It could send new data after previous ones succeed
 	// It could send new data after previous ones succeed
-	if err := s.send([]byte("some data")); err != nil {
+	if err := s.Send([]byte("some data")); err != nil {
 		t.Errorf("send err = %v, want nil", err)
 		t.Errorf("send err = %v, want nil", err)
 	}
 	}
-	s.stop()
+	s.Stop()
 }
 }
 
 
 // TestSenderSendFailed tests that when send func meets the post error,
 // TestSenderSendFailed tests that when send func meets the post error,
 // it increases fail count in stats.
 // it increases fail count in stats.
 func TestSenderSendFailed(t *testing.T) {
 func TestSenderSendFailed(t *testing.T) {
 	fs := &stats.FollowerStats{}
 	fs := &stats.FollowerStats{}
-	s := newSender(newRespRoundTripper(0, errors.New("blah")), "http://10.0.0.1", types.ID(1), fs, nil)
+	s := NewSender(newRespRoundTripper(0, errors.New("blah")), "http://10.0.0.1", types.ID(1), fs, nil)
 
 
-	if err := s.send([]byte("some data")); err != nil {
-		t.Fatalf("unexpect send error: %v", err)
+	if err := s.Send([]byte("some data")); err != nil {
+		t.Fatalf("unexpect Send error: %v", err)
 	}
 	}
-	s.stop()
+	s.Stop()
 
 
 	fs.Lock()
 	fs.Lock()
 	defer fs.Unlock()
 	defer fs.Unlock()
@@ -187,11 +101,11 @@ func TestSenderSendFailed(t *testing.T) {
 
 
 func TestSenderPost(t *testing.T) {
 func TestSenderPost(t *testing.T) {
 	tr := &roundTripperRecorder{}
 	tr := &roundTripperRecorder{}
-	s := newSender(tr, "http://10.0.0.1", types.ID(1), nil, nil)
+	s := NewSender(tr, "http://10.0.0.1", types.ID(1), nil, nil)
 	if err := s.post([]byte("some data")); err != nil {
 	if err := s.post([]byte("some data")); err != nil {
 		t.Fatalf("unexpect post error: %v", err)
 		t.Fatalf("unexpect post error: %v", err)
 	}
 	}
-	s.stop()
+	s.Stop()
 
 
 	if g := tr.Request().Method; g != "POST" {
 	if g := tr.Request().Method; g != "POST" {
 		t.Errorf("method = %s, want %s", g, "POST")
 		t.Errorf("method = %s, want %s", g, "POST")
@@ -230,9 +144,9 @@ func TestSenderPostBad(t *testing.T) {
 	}
 	}
 	for i, tt := range tests {
 	for i, tt := range tests {
 		shouldstop := make(chan struct{})
 		shouldstop := make(chan struct{})
-		s := newSender(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), nil, shouldstop)
+		s := NewSender(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), nil, shouldstop)
 		err := s.post([]byte("some data"))
 		err := s.post([]byte("some data"))
-		s.stop()
+		s.Stop()
 
 
 		if err == nil {
 		if err == nil {
 			t.Errorf("#%d: err = nil, want not nil", i)
 			t.Errorf("#%d: err = nil, want not nil", i)
@@ -251,9 +165,9 @@ func TestSenderPostShouldStop(t *testing.T) {
 	}
 	}
 	for i, tt := range tests {
 	for i, tt := range tests {
 		shouldstop := make(chan struct{}, 1)
 		shouldstop := make(chan struct{}, 1)
-		s := newSender(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), nil, shouldstop)
+		s := NewSender(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), nil, shouldstop)
 		s.post([]byte("some data"))
 		s.post([]byte("some data"))
-		s.stop()
+		s.Stop()
 		select {
 		select {
 		case <-shouldstop:
 		case <-shouldstop:
 		default:
 		default: