Browse Source

Merge pull request #1086 from jonboulle/serve_raft_test

etcdserver/etcdhttp: add test for serveRaft
Jonathan Boulle 11 years ago
parent
commit
43acdef660

+ 5 - 1
etcdserver/etcdhttp/http.go

@@ -121,10 +121,14 @@ func (h serverHandler) serveRaft(w http.ResponseWriter, r *http.Request) {
 	b, err := ioutil.ReadAll(r.Body)
 	if err != nil {
 		log.Println("etcdhttp: error reading raft message:", err)
+		http.Error(w, "error reading raft message", http.StatusBadRequest)
+		return
 	}
 	var m raftpb.Message
 	if err := m.Unmarshal(b); err != nil {
 		log.Println("etcdhttp: error unmarshaling raft message:", err)
+		http.Error(w, "error unmarshaling raft message", http.StatusBadRequest)
+		return
 	}
 	log.Printf("etcdhttp: raft recv message from %#x: %+v", m.From, m)
 	if err := h.server.Process(context.TODO(), m); err != nil {
@@ -132,7 +136,7 @@ func (h serverHandler) serveRaft(w http.ResponseWriter, r *http.Request) {
 		writeError(w, err)
 		return
 	}
-	w.WriteHeader(http.StatusOK)
+	w.WriteHeader(http.StatusNoContent)
 }
 
 // genID generates a random id that is: n < 0 < n.

+ 87 - 0
etcdserver/etcdhttp/http_test.go

@@ -1,7 +1,9 @@
 package etcdhttp
 
 import (
+	"bytes"
 	"errors"
+	"io"
 	"net/http"
 	"net/http/httptest"
 	"net/url"
@@ -13,7 +15,9 @@ import (
 	"time"
 
 	etcdErr "github.com/coreos/etcd/error"
+	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/etcdserver/etcdserverpb"
+	"github.com/coreos/etcd/raft/raftpb"
 	"github.com/coreos/etcd/store"
 	"github.com/coreos/etcd/third_party/code.google.com/p/go.net/context"
 )
@@ -747,3 +751,86 @@ func TestAllowMethod(t *testing.T) {
 		}
 	}
 }
+
+// errServer implements the etcd.Server interface for testing.
+// It returns the given error from any Do/Process calls.
+type errServer struct {
+	err error
+}
+
+func (fs *errServer) Do(ctx context.Context, r etcdserverpb.Request) (etcdserver.Response, error) {
+	return etcdserver.Response{}, fs.err
+}
+func (fs *errServer) Process(ctx context.Context, m raftpb.Message) error {
+	return fs.err
+}
+func (fs *errServer) Start() {}
+func (fs *errServer) Stop()  {}
+
+// errReader implements io.Reader to facilitate a broken request.
+type errReader struct{}
+
+func (er *errReader) Read(_ []byte) (int, error) { return 0, errors.New("some error") }
+
+func mustMarshalMsg(t *testing.T, m raftpb.Message) []byte {
+	json, err := m.Marshal()
+	if err != nil {
+		t.Fatalf("error marshalling raft Message: %#v", err)
+	}
+	return json
+}
+
+func TestServeRaft(t *testing.T) {
+	testCases := []struct {
+		reqBody   io.Reader
+		serverErr error
+		wcode     int
+	}{
+		{
+			&errReader{},
+			nil,
+			http.StatusBadRequest,
+		},
+		{
+			strings.NewReader("malformed garbage"),
+			nil,
+			http.StatusBadRequest,
+		},
+		{
+			bytes.NewReader(
+				mustMarshalMsg(
+					t,
+					raftpb.Message{},
+				),
+			),
+			errors.New("some error"),
+			http.StatusInternalServerError,
+		},
+		{
+			bytes.NewReader(
+				mustMarshalMsg(
+					t,
+					raftpb.Message{},
+				),
+			),
+			nil,
+			http.StatusNoContent,
+		},
+	}
+	for i, tt := range testCases {
+		req, err := http.NewRequest("POST", "foo", tt.reqBody)
+		if err != nil {
+			t.Fatalf("#%d: could not create request: %#v", i, err)
+		}
+		h := &serverHandler{
+			timeout: time.Hour,
+			server:  &errServer{tt.serverErr},
+			peers:   nil,
+		}
+		rw := httptest.NewRecorder()
+		h.serveRaft(rw, req)
+		if rw.Code != tt.wcode {
+			t.Errorf("#%d: got code=%d, want %d", i, rw.Code, tt.wcode)
+		}
+	}
+}

+ 1 - 1
etcdserver/etcdhttp/peers.go

@@ -131,7 +131,7 @@ func httpPost(url string, data []byte) bool {
 		return false
 	}
 	resp.Body.Close()
-	if resp.StatusCode != 200 {
+	if resp.StatusCode != http.StatusNoContent {
 		elog.TODO()
 		return false
 	}

+ 9 - 2
etcdserver/etcdhttp/peers_test.go

@@ -124,14 +124,21 @@ func TestHttpPost(t *testing.T) {
 		{
 			http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 				tr = r
-				w.WriteHeader(200)
+				w.WriteHeader(http.StatusNoContent)
 			}),
 			true,
 		},
 		{
 			http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 				tr = r
-				w.WriteHeader(404)
+				w.WriteHeader(http.StatusNotFound)
+			}),
+			false,
+		},
+		{
+			http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				tr = r
+				w.WriteHeader(http.StatusInternalServerError)
 			}),
 			false,
 		},