Browse Source

Merge pull request #1069 from jonboulle/methods

etcdhttp: check method for every endpoint, add tests
Jonathan Boulle 11 years ago
parent
commit
35ae488120
2 changed files with 85 additions and 5 deletions
  1. 19 5
      etcdserver/etcdhttp/http.go
  2. 66 0
      etcdserver/etcdhttp/http_test.go

+ 19 - 5
etcdserver/etcdhttp/http.go

@@ -68,6 +68,10 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 }
 
 func (h Handler) serveKeys(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+	if !allowMethod(w, r.Method, "GET", "PUT", "POST", "DELETE") {
+		return
+	}
+
 	rr, err := parseRequest(r, genID())
 	if err != nil {
 		writeError(w, err)
@@ -103,8 +107,7 @@ func (h Handler) serveKeys(ctx context.Context, w http.ResponseWriter, r *http.R
 // serveMachines responds address list in the format '0.0.0.0, 1.1.1.1'.
 // TODO: rethink the format of machine list because it is not json format.
 func (h Handler) serveMachines(w http.ResponseWriter, r *http.Request) {
-	if r.Method != "GET" && r.Method != "HEAD" {
-		allow(w, "GET", "HEAD")
+	if !allowMethod(w, r.Method, "GET", "HEAD") {
 		return
 	}
 	endpoints := h.Peers.Endpoints()
@@ -112,6 +115,9 @@ func (h Handler) serveMachines(w http.ResponseWriter, r *http.Request) {
 }
 
 func (h Handler) serveRaft(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+	if !allowMethod(w, r.Method, "POST") {
+		return
+	}
 	b, err := ioutil.ReadAll(r.Body)
 	if err != nil {
 		log.Println("etcdhttp: error reading raft message:", err)
@@ -317,8 +323,16 @@ func waitForEvent(ctx context.Context, w http.ResponseWriter, wa store.Watcher)
 	}
 }
 
-// allow writes response for the case that Method Not Allowed
-func allow(w http.ResponseWriter, m ...string) {
-	w.Header().Set("Allow", strings.Join(m, ","))
+// allowMethod verifies that the given method is one of the allowed methods,
+// and if not, it writes an error to w.  A boolean is returned indicating
+// whether or not the method is allowed.
+func allowMethod(w http.ResponseWriter, m string, ms ...string) bool {
+	for _, meth := range ms {
+		if m == meth {
+			return true
+		}
+	}
+	w.Header().Set("Allow", strings.Join(ms, ","))
 	http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
+	return false
 }

+ 66 - 0
etcdserver/etcdhttp/http_test.go

@@ -680,3 +680,69 @@ func TestPeersEndpoints(t *testing.T) {
 		}
 	}
 }
+
+func TestAllowMethod(t *testing.T) {
+	tests := []struct {
+		m  string
+		ms []string
+		w  bool
+		wh string
+	}{
+		// Accepted methods
+		{
+			m:  "GET",
+			ms: []string{"GET", "POST", "PUT"},
+			w:  true,
+		},
+		{
+			m:  "POST",
+			ms: []string{"POST"},
+			w:  true,
+		},
+		// Made-up methods no good
+		{
+			m:  "FAKE",
+			ms: []string{"GET", "POST", "PUT"},
+			w:  false,
+			wh: "GET,POST,PUT",
+		},
+		// Empty methods no good
+		{
+			m:  "",
+			ms: []string{"GET", "POST"},
+			w:  false,
+			wh: "GET,POST",
+		},
+		// Empty accepted methods no good
+		{
+			m:  "GET",
+			ms: []string{""},
+			w:  false,
+			wh: "",
+		},
+		// No methods accepted
+		{
+			m:  "GET",
+			ms: []string{},
+			w:  false,
+			wh: "",
+		},
+	}
+
+	for i, tt := range tests {
+		rw := httptest.NewRecorder()
+		g := allowMethod(rw, tt.m, tt.ms...)
+		if g != tt.w {
+			t.Errorf("#%d: got allowMethod()=%t, want %t", i, g, tt.w)
+		}
+		if !tt.w {
+			if rw.Code != http.StatusMethodNotAllowed {
+				t.Errorf("#%d: code=%d, want %d", i, rw.Code, http.StatusMethodNotAllowed)
+			}
+			gh := rw.Header().Get("Allow")
+			if gh != tt.wh {
+				t.Errorf("#%d: Allow header=%q, want %q", i, gh, tt.wh)
+			}
+		}
+	}
+}