Browse Source

etcdhttp: perform validation of query parameters

Add basic input validation of all query parameters supported by
serveKeys. Also restructures etcdhttp a bit to better facilitate
testing.

Test coverage is slightly improved.
Jonathan Boulle 11 years ago
parent
commit
e736a11ac4
2 changed files with 216 additions and 113 deletions
  1. 90 56
      etcdserver/etcdhttp/http.go
  2. 126 57
      etcdserver/etcdhttp/http_test.go

+ 90 - 56
etcdserver/etcdhttp/http.go

@@ -20,7 +20,7 @@ import (
 	"math/rand"
 
 	"github.com/coreos/etcd/elog"
-	etcderrors "github.com/coreos/etcd/error"
+	etcdErr "github.com/coreos/etcd/error"
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/etcdserver/etcdserverpb"
 	"github.com/coreos/etcd/raft/raftpb"
@@ -33,6 +33,8 @@ const (
 	machinesPrefix = "/v2/machines"
 )
 
+var emptyReq = etcdserverpb.Request{}
+
 type Peers map[int64][]string
 
 func (ps Peers) Pick(id int64) string {
@@ -178,28 +180,32 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 func (h Handler) serveKeys(ctx context.Context, w http.ResponseWriter, r *http.Request) {
 	rr, err := parseRequest(r, genId())
 	if err != nil {
-		log.Println(err) // reading of body failed
+		http.Error(w, err.Error(), http.StatusBadRequest)
 		return
 	}
 
 	resp, err := h.Server.Do(ctx, rr)
-	switch e := err.(type) {
-	case nil:
-	case *etcderrors.Error:
-		// TODO: gross. this should be handled in encodeResponse
-		log.Println(err)
-		e.Write(w)
-		return
-	default:
-		log.Println(err)
-		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+	if err != nil {
+		writeInternalError(w, err)
 		return
 	}
 
-	if err := encodeResponse(ctx, w, resp); err != nil {
-		http.Error(w, "Timeout while waiting for response", http.StatusGatewayTimeout)
+	var ev *store.Event
+	switch {
+	case resp.Event != nil:
+		ev = resp.Event
+	case resp.Watcher != nil:
+		ev, err = waitForEvent(ctx, w, resp.Watcher)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusGatewayTimeout)
+			return
+		}
+	default:
+		writeInternalError(w, errors.New("received response with no Event/Watcher!"))
 		return
 	}
+
+	writeEvent(w, ev)
 }
 
 // serveMachines responds address list in the format '0.0.0.0, 1.1.1.1'.
@@ -249,38 +255,60 @@ func genId() int64 {
 }
 
 func parseRequest(r *http.Request, id int64) (etcdserverpb.Request, error) {
-	if err := r.ParseForm(); err != nil {
-		return etcdserverpb.Request{}, err
+	var err error
+
+	if err = r.ParseForm(); err != nil {
+		return emptyReq, err
 	}
+
 	if !strings.HasPrefix(r.URL.Path, keysPrefix) {
-		return etcdserverpb.Request{}, errors.New("unexpected key prefix!")
+		return emptyReq, errors.New("unexpected key prefix!")
 	}
+	path := r.URL.Path[len(keysPrefix):]
 
 	q := r.URL.Query()
-	// TODO(jonboulle): perform strict validation of all parameters
-	// https://github.com/coreos/etcd/issues/1011
+
+	var pIdx, wIdx, ttl uint64
+	if pIdx, err = parseUint64(q.Get("prevIndex")); err != nil {
+		return emptyReq, errors.New("invalid value for prevIndex")
+	}
+	if wIdx, err = parseUint64(q.Get("waitIndex")); err != nil {
+		return emptyReq, errors.New("invalid value for waitIndex")
+	}
+	if ttl, err = parseUint64(q.Get("ttl")); err != nil {
+		return emptyReq, errors.New("invalid value for ttl")
+	}
+
+	var rec, sort, wait bool
+	if rec, err = parseBool(q.Get("recursive")); err != nil {
+		return emptyReq, errors.New("invalid value for recursive")
+	}
+	if sort, err = parseBool(q.Get("sorted")); err != nil {
+		return emptyReq, errors.New("invalid value for sorted")
+	}
+	if wait, err = parseBool(q.Get("wait")); err != nil {
+		return emptyReq, errors.New("invalid value for wait")
+	}
+
 	rr := etcdserverpb.Request{
 		Id:        id,
 		Method:    r.Method,
 		Val:       r.FormValue("value"),
-		Path:      r.URL.Path[len(keysPrefix):],
+		Path:      path,
 		PrevValue: q.Get("prevValue"),
-		PrevIndex: parseUint64(q.Get("prevIndex")),
-		Recursive: parseBool(q.Get("recursive")),
-		Since:     parseUint64(q.Get("waitIndex")),
-		Sorted:    parseBool(q.Get("sorted")),
-		Wait:      parseBool(q.Get("wait")),
-	}
-
-	// PrevExists is nullable, so we leave it null if prevExist wasn't
-	// specified.
-	_, ok := q["prevExists"]
-	if ok {
-		bv := parseBool(q.Get("prevExists"))
+		PrevIndex: pIdx,
+		Recursive: rec,
+		Since:     wIdx,
+		Sorted:    sort,
+		Wait:      wait,
+	}
+
+	// prevExists is nullable, so leave it null if not specified
+	if _, ok := q["prevExists"]; ok {
+		bv, _ := parseBool(q.Get("prevExists"))
 		rr.PrevExists = &bv
 	}
 
-	ttl := parseUint64(q.Get("ttl"))
 	if ttl > 0 {
 		expr := time.Duration(ttl) * time.Second
 		// TODO(jonboulle): use fake clock instead of time module
@@ -291,32 +319,40 @@ func parseRequest(r *http.Request, id int64) (etcdserverpb.Request, error) {
 	return rr, nil
 }
 
-func parseBool(s string) bool {
-	v, _ := strconv.ParseBool(s)
-	return v
+func parseBool(s string) (bool, error) {
+	if s == "" {
+		return false, nil
+	}
+	return strconv.ParseBool(s)
 }
 
-func parseUint64(s string) uint64 {
-	v, _ := strconv.ParseUint(s, 10, 64)
-	return v
+func parseUint64(s string) (uint64, error) {
+	if s == "" {
+		return 0, nil
+	}
+	return strconv.ParseUint(s, 10, 64)
 }
 
-// encodeResponse serializes the given etcdserver Response and writes the
-// resulting JSON to the given ResponseWriter, utilizing the provided context
-func encodeResponse(ctx context.Context, w http.ResponseWriter, resp etcdserver.Response) (err error) {
-	var ev *store.Event
-	switch {
-	case resp.Event != nil:
-		ev = resp.Event
-	case resp.Watcher != nil:
-		ev, err = waitForEvent(ctx, w, resp.Watcher)
-		if err != nil {
-			return err
-		}
-	default:
-		panic("should not be reachable")
+// writeInternalError logs and writes the given Error to the ResponseWriter
+// If Error is an etcdErr, it is rendered to the ResponseWriter
+func writeInternalError(w http.ResponseWriter, err error) {
+	if err == nil {
+		return
+	}
+	log.Println(err)
+	if e, ok := err.(*etcdErr.Error); ok {
+		e.Write(w)
+	} else {
+		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
 	}
+}
 
+// writeEvent serializes the given Event and writes the resulting JSON to the
+// given ResponseWriter
+func writeEvent(w http.ResponseWriter, ev *store.Event) {
+	if ev == nil {
+		return
+	}
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Add("X-Etcd-Index", fmt.Sprint(ev.Index()))
 
@@ -327,10 +363,9 @@ func encodeResponse(ctx context.Context, w http.ResponseWriter, resp etcdserver.
 	if err := json.NewEncoder(w).Encode(ev); err != nil {
 		panic(err) // should never be reached
 	}
-	return nil
 }
 
-// waitForEvent waits for a given watcher to return its associated
+// waitForEvent waits for a given Watcher to return its associated
 // event. It returns a non-nil error if the given Context times out
 // or the given ResponseWriter triggers a CloseNotify.
 func waitForEvent(ctx context.Context, w http.ResponseWriter, wa store.Watcher) (*store.Event, error) {
@@ -340,7 +375,6 @@ func waitForEvent(ctx context.Context, w http.ResponseWriter, wa store.Watcher)
 	if x, ok := w.(http.CloseNotifier); ok {
 		nch = x.CloseNotify()
 	}
-
 	select {
 	case ev := <-wa.EventChan():
 		return ev, nil

+ 126 - 57
etcdserver/etcdhttp/http_test.go

@@ -1,6 +1,7 @@
 package etcdhttp
 
 import (
+	"errors"
 	"net/http"
 	"net/http/httptest"
 	"net/url"
@@ -9,7 +10,7 @@ import (
 	"sync"
 	"testing"
 
-	"github.com/coreos/etcd/etcdserver"
+	etcdErr "github.com/coreos/etcd/error"
 	"github.com/coreos/etcd/etcdserver/etcdserverpb"
 	"github.com/coreos/etcd/store"
 	"github.com/coreos/etcd/third_party/code.google.com/p/go.net/context"
@@ -25,6 +26,12 @@ func mustNewURL(t *testing.T, s string) *url.URL {
 	return u
 }
 
+func mustNewRequest(t *testing.T, p string) *http.Request {
+	return &http.Request{
+		URL: mustNewURL(t, path.Join(keysPrefix, p)),
+	}
+}
+
 func TestBadParseRequest(t *testing.T) {
 	tests := []struct {
 		in *http.Request
@@ -42,6 +49,47 @@ func TestBadParseRequest(t *testing.T) {
 				URL: mustNewURL(t, "/badprefix/"),
 			},
 		},
+		// bad values for prevIndex, waitIndex, ttl
+		{
+			mustNewRequest(t, "?prevIndex=foo"),
+		},
+		{
+			mustNewRequest(t, "?prevIndex=1.5"),
+		},
+		{
+			mustNewRequest(t, "?prevIndex=-1"),
+		},
+		{
+			mustNewRequest(t, "?waitIndex=garbage"),
+		},
+		{
+			mustNewRequest(t, "?waitIndex=??"),
+		},
+		{
+			mustNewRequest(t, "?ttl=-1"),
+		},
+		// bad values for recursive, sorted, wait
+		{
+			mustNewRequest(t, "?recursive=hahaha"),
+		},
+		{
+			mustNewRequest(t, "?recursive=1234"),
+		},
+		{
+			mustNewRequest(t, "?recursive=?"),
+		},
+		{
+			mustNewRequest(t, "?sorted=hahaha"),
+		},
+		{
+			mustNewRequest(t, "?sorted=!!"),
+		},
+		{
+			mustNewRequest(t, "?wait=notreally"),
+		},
+		{
+			mustNewRequest(t, "?wait=what!"),
+		},
 	}
 	for i, tt := range tests {
 		got, err := parseRequest(tt.in, 1234)
@@ -61,9 +109,7 @@ func TestGoodParseRequest(t *testing.T) {
 	}{
 		{
 			// good prefix, all other values default
-			&http.Request{
-				URL: mustNewURL(t, path.Join(keysPrefix, "foo")),
-			},
+			mustNewRequest(t, "foo"),
 			etcdserverpb.Request{
 				Id:   1234,
 				Path: "/foo",
@@ -71,9 +117,7 @@ func TestGoodParseRequest(t *testing.T) {
 		},
 		{
 			// value specified
-			&http.Request{
-				URL: mustNewURL(t, path.Join(keysPrefix, "foo?value=some_value")),
-			},
+			mustNewRequest(t, "foo?value=some_value"),
 			etcdserverpb.Request{
 				Id:   1234,
 				Val:  "some_value",
@@ -82,9 +126,7 @@ func TestGoodParseRequest(t *testing.T) {
 		},
 		{
 			// prevIndex specified
-			&http.Request{
-				URL: mustNewURL(t, path.Join(keysPrefix, "foo?prevIndex=98765")),
-			},
+			mustNewRequest(t, "foo?prevIndex=98765"),
 			etcdserverpb.Request{
 				Id:        1234,
 				PrevIndex: 98765,
@@ -93,9 +135,7 @@ func TestGoodParseRequest(t *testing.T) {
 		},
 		{
 			// recursive specified
-			&http.Request{
-				URL: mustNewURL(t, path.Join(keysPrefix, "foo?recursive=true")),
-			},
+			mustNewRequest(t, "foo?recursive=true"),
 			etcdserverpb.Request{
 				Id:        1234,
 				Recursive: true,
@@ -104,9 +144,7 @@ func TestGoodParseRequest(t *testing.T) {
 		},
 		{
 			// sorted specified
-			&http.Request{
-				URL: mustNewURL(t, path.Join(keysPrefix, "foo?sorted=true")),
-			},
+			mustNewRequest(t, "foo?sorted=true"),
 			etcdserverpb.Request{
 				Id:     1234,
 				Sorted: true,
@@ -115,9 +153,7 @@ func TestGoodParseRequest(t *testing.T) {
 		},
 		{
 			// wait specified
-			&http.Request{
-				URL: mustNewURL(t, path.Join(keysPrefix, "foo?wait=true")),
-			},
+			mustNewRequest(t, "foo?wait=true"),
 			etcdserverpb.Request{
 				Id:   1234,
 				Wait: true,
@@ -126,9 +162,7 @@ func TestGoodParseRequest(t *testing.T) {
 		},
 		{
 			// prevExists should be non-null if specified
-			&http.Request{
-				URL: mustNewURL(t, path.Join(keysPrefix, "foo?prevExists=true")),
-			},
+			mustNewRequest(t, "foo?prevExists=true"),
 			etcdserverpb.Request{
 				Id:         1234,
 				PrevExists: boolp(true),
@@ -137,9 +171,7 @@ func TestGoodParseRequest(t *testing.T) {
 		},
 		{
 			// prevExists should be non-null if specified
-			&http.Request{
-				URL: mustNewURL(t, path.Join(keysPrefix, "foo?prevExists=false")),
-			},
+			mustNewRequest(t, "foo?prevExists=false"),
 			etcdserverpb.Request{
 				Id:         1234,
 				PrevExists: boolp(false),
@@ -177,22 +209,77 @@ func (w *eventingWatcher) EventChan() chan *store.Event {
 
 func (w *eventingWatcher) Remove() {}
 
-func TestEncodeResponse(t *testing.T) {
+func TestWriteInternalError(t *testing.T) {
+	// nil error should not panic
+	rw := httptest.NewRecorder()
+	writeInternalError(rw, nil)
+	h := rw.Header()
+	if len(h) > 0 {
+		t.Fatalf("unexpected non-empty headers: %#v", h)
+	}
+	b := rw.Body.String()
+	if len(b) > 0 {
+		t.Fatalf("unexpected non-empty body: %q", b)
+	}
+
+	tests := []struct {
+		err  error
+		code int
+		idx  string
+	}{
+		{
+			etcdErr.NewError(etcdErr.EcodeKeyNotFound, "/foo/bar", 123),
+			http.StatusNotFound,
+			"123",
+		},
+		{
+			etcdErr.NewError(etcdErr.EcodeTestFailed, "/foo/bar", 456),
+			http.StatusPreconditionFailed,
+			"456",
+		},
+		{
+			err:  errors.New("something went wrong"),
+			code: http.StatusInternalServerError,
+		},
+	}
+
+	for i, tt := range tests {
+		rw := httptest.NewRecorder()
+		writeInternalError(rw, tt.err)
+		if code := rw.Code; code != tt.code {
+			t.Errorf("#%d: got %d, want %d", i, code, tt.code)
+		}
+		if idx := rw.Header().Get("X-Etcd-Index"); idx != tt.idx {
+			t.Errorf("#%d: got %q, want %q", i, idx, tt.idx)
+		}
+	}
+}
+
+func TestWriteEvent(t *testing.T) {
+	// nil event should not panic
+	rw := httptest.NewRecorder()
+	writeEvent(rw, nil)
+	h := rw.Header()
+	if len(h) > 0 {
+		t.Fatalf("unexpected non-empty headers: %#v", h)
+	}
+	b := rw.Body.String()
+	if len(b) > 0 {
+		t.Fatalf("unexpected non-empty body: %q", b)
+	}
+
 	tests := []struct {
-		resp etcdserver.Response
+		ev   *store.Event
 		idx  string
 		code int
 		err  error
 	}{
 		// standard case, standard 200 response
 		{
-			etcdserver.Response{
-				Event: &store.Event{
-					Action:   store.Get,
-					Node:     &store.NodeExtern{},
-					PrevNode: &store.NodeExtern{},
-				},
-				Watcher: nil,
+			&store.Event{
+				Action:   store.Get,
+				Node:     &store.NodeExtern{},
+				PrevNode: &store.NodeExtern{},
 			},
 			"0",
 			http.StatusOK,
@@ -200,21 +287,10 @@ func TestEncodeResponse(t *testing.T) {
 		},
 		// check new nodes return StatusCreated
 		{
-			etcdserver.Response{
-				Event: &store.Event{
-					Action:   store.Create,
-					Node:     &store.NodeExtern{},
-					PrevNode: &store.NodeExtern{},
-				},
-				Watcher: nil,
-			},
-			"0",
-			http.StatusCreated,
-			nil,
-		},
-		{
-			etcdserver.Response{
-				Watcher: &eventingWatcher{store.Create},
+			&store.Event{
+				Action:   store.Create,
+				Node:     &store.NodeExtern{},
+				PrevNode: &store.NodeExtern{},
 			},
 			"0",
 			http.StatusCreated,
@@ -224,20 +300,13 @@ func TestEncodeResponse(t *testing.T) {
 
 	for i, tt := range tests {
 		rw := httptest.NewRecorder()
-		err := encodeResponse(context.Background(), rw, tt.resp)
-		if err != tt.err {
-			t.Errorf("case %d: unexpected err: got %v, want %v", i, err, tt.err)
-			continue
-		}
-
+		writeEvent(rw, tt.ev)
 		if gct := rw.Header().Get("Content-Type"); gct != "application/json" {
 			t.Errorf("case %d: bad Content-Type: got %q, want application/json", i, gct)
 		}
-
 		if gei := rw.Header().Get("X-Etcd-Index"); gei != tt.idx {
 			t.Errorf("case %d: bad X-Etcd-Index header: got %s, want %s", i, gei, tt.idx)
 		}
-
 		if rw.Code != tt.code {
 			t.Errorf("case %d: bad response code: got %d, want %v", i, rw.Code, tt.code)
 		}