Browse Source

Merge pull request #1594 from unihorn/201

etcdhttp/etcdserver: support HEAD on /v2/keys/ namespace
Yicheng Qin 11 years ago
parent
commit
ac49e1d50f

+ 1 - 1
etcdserver/etcdhttp/client.go

@@ -93,7 +93,7 @@ type keysHandler struct {
 }
 
 func (h *keysHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	if !allowMethod(w, r.Method, "GET", "PUT", "POST", "DELETE") {
+	if !allowMethod(w, r.Method, "HEAD", "GET", "PUT", "POST", "DELETE") {
 		return
 	}
 	w.Header().Set("X-Etcd-Cluster-ID", h.clusterInfo.ID().String())

+ 61 - 0
etcdserver/etcdhttp/client_test.go

@@ -63,6 +63,18 @@ func mustNewForm(t *testing.T, p string, vals url.Values) *http.Request {
 	return req
 }
 
+// mustNewPostForm takes a set of Values and constructs a POST *http.Request,
+// with a URL constructed from appending the given path to the standard keysPrefix
+func mustNewPostForm(t *testing.T, p string, vals url.Values) *http.Request {
+	u := mustNewURL(t, path.Join(keysPrefix, p))
+	req, err := http.NewRequest("POST", u.String(), strings.NewReader(vals.Encode()))
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	if err != nil {
+		t.Fatalf("error creating new request: %v", err)
+	}
+	return req
+}
+
 // mustNewRequest takes a path, appends it to the standard keysPrefix, and constructs
 // a GET *http.Request referencing the resulting URL
 func mustNewRequest(t *testing.T, p string) *http.Request {
@@ -1171,6 +1183,55 @@ func TestBadServeKeys(t *testing.T) {
 	}
 }
 
+func TestServeKeysGood(t *testing.T) {
+	tests := []struct {
+		req   *http.Request
+		wcode int
+	}{
+		{
+			mustNewMethodRequest(t, "HEAD", "foo"),
+			http.StatusOK,
+		},
+		{
+			mustNewMethodRequest(t, "GET", "foo"),
+			http.StatusOK,
+		},
+		{
+			mustNewForm(t, "foo", url.Values{"value": []string{"bar"}}),
+			http.StatusOK,
+		},
+		{
+			mustNewMethodRequest(t, "DELETE", "foo"),
+			http.StatusOK,
+		},
+		{
+			mustNewPostForm(t, "foo", url.Values{"value": []string{"bar"}}),
+			http.StatusOK,
+		},
+	}
+	server := &resServer{
+		etcdserver.Response{
+			Event: &store.Event{
+				Action: store.Get,
+				Node:   &store.NodeExtern{},
+			},
+		},
+	}
+	for i, tt := range tests {
+		h := &keysHandler{
+			timeout:     time.Hour,
+			server:      server,
+			timer:       &dummyRaftTimer{},
+			clusterInfo: &fakeCluster{id: 1},
+		}
+		rw := httptest.NewRecorder()
+		h.ServeHTTP(rw, tt.req)
+		if rw.Code != tt.wcode {
+			t.Errorf("#%d: got code=%d, want %d", i, rw.Code, tt.wcode)
+		}
+	}
+}
+
 func TestServeKeysEvent(t *testing.T) {
 	req := mustNewRequest(t, "foo")
 	server := &resServer{

+ 6 - 0
etcdserver/server.go

@@ -409,6 +409,12 @@ func (s *EtcdServer) Do(ctx context.Context, r pb.Request) (Response, error) {
 			}
 			return Response{Event: ev}, nil
 		}
+	case "HEAD":
+		ev, err := s.store.Get(r.Path, r.Recursive, r.Sorted)
+		if err != nil {
+			return Response{}, err
+		}
+		return Response{Event: ev}, nil
 	default:
 		return Response{}, ErrUnknownMethod
 	}

+ 14 - 0
etcdserver/server_test.go

@@ -87,6 +87,16 @@ func TestDoLocalAction(t *testing.T) {
 				},
 			},
 		},
+		{
+			pb.Request{Method: "HEAD", ID: 1},
+			Response{Event: &store.Event{}}, nil,
+			[]action{
+				action{
+					name:   "Get",
+					params: []interface{}{"", false, false},
+				},
+			},
+		},
 		{
 			pb.Request{Method: "BADMETHOD", ID: 1},
 			Response{}, ErrUnknownMethod, []action{},
@@ -127,6 +137,10 @@ func TestDoBadLocalAction(t *testing.T) {
 			pb.Request{Method: "GET", ID: 1},
 			[]action{action{name: "Get"}},
 		},
+		{
+			pb.Request{Method: "HEAD", ID: 1},
+			[]action{action{name: "Get"}},
+		},
 	}
 	for i, tt := range tests {
 		st := &errStoreRecorder{err: storeErr}

+ 4 - 7
integration/v2_http_kv_test.go

@@ -912,8 +912,6 @@ func TestV2WatchKeyInDir(t *testing.T) {
 	}
 }
 
-// TODO(jonboulle): enable once #1590 is fixed
-/*
 func TestV2Head(t *testing.T) {
 	cl := cluster{Size: 1}
 	cl.Launch(t)
@@ -930,8 +928,8 @@ func TestV2Head(t *testing.T) {
 	if resp.StatusCode != http.StatusNotFound {
 		t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound)
 	}
-	if resp.ContentLength != -1 {
-		t.Errorf("ContentLength = %d, want -1", resp.ContentLength)
+	if resp.ContentLength <= 0 {
+		t.Errorf("ContentLength = %d, want > 0", resp.ContentLength)
 	}
 
 	resp, _ = tc.PutForm(fullURL, v)
@@ -942,11 +940,10 @@ func TestV2Head(t *testing.T) {
 	if resp.StatusCode != http.StatusOK {
 		t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
 	}
-	if resp.ContentLength != -1 {
-		t.Errorf("ContentLength = %d, want -1", resp.ContentLength)
+	if resp.ContentLength <= 0 {
+		t.Errorf("ContentLength = %d, want > 0", resp.ContentLength)
 	}
 }
-*/
 
 func checkBody(body map[string]interface{}, w map[string]interface{}) error {
 	if body["node"] != nil {