Browse Source

etcdserver: add test coverage for parseRequest

Jonathan Boulle 11 years ago
parent
commit
c78239a629
2 changed files with 157 additions and 5 deletions
  1. 14 5
      etcdserver/etcdhttp/http.go
  2. 143 0
      etcdserver/etcdhttp/http_test.go

+ 14 - 5
etcdserver/etcdhttp/http.go

@@ -27,6 +27,8 @@ import (
 	"github.com/coreos/etcd/third_party/code.google.com/p/go.net/context"
 )
 
+const keysPrefix = "/v2/keys"
+
 type Peers map[int64][]string
 
 func (ps Peers) Pick(id int64) string {
@@ -152,7 +154,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	switch {
 	case strings.HasPrefix(r.URL.Path, "/raft"):
 		h.serveRaft(ctx, w, r)
-	case strings.HasPrefix(r.URL.Path, "/v2/keys/"):
+	case strings.HasPrefix(r.URL.Path, keysPrefix):
 		h.serveKeys(ctx, w, r)
 	default:
 		http.NotFound(w, r)
@@ -160,7 +162,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 }
 
 func (h Handler) serveKeys(ctx context.Context, w http.ResponseWriter, r *http.Request) {
-	rr, err := parseRequest(r)
+	rr, err := parseRequest(r, genId())
 	if err != nil {
 		log.Println(err) // reading of body failed
 		return
@@ -215,17 +217,22 @@ func genId() int64 {
 	}
 }
 
-func parseRequest(r *http.Request) (etcdserverpb.Request, error) {
+func parseRequest(r *http.Request, id int64) (etcdserverpb.Request, error) {
 	if err := r.ParseForm(); err != nil {
 		return etcdserverpb.Request{}, err
 	}
+	if !strings.HasPrefix(r.URL.Path, keysPrefix) {
+		return etcdserverpb.Request{}, errors.New("expected key prefix!")
+	}
 
 	q := r.URL.Query()
+	// TODO(jonboulle): perform strict validation of all parameters
+	// https://github.com/coreos/etcd/issues/1011
 	rr := etcdserverpb.Request{
-		Id:        genId(),
+		Id:        id,
 		Method:    r.Method,
 		Val:       r.FormValue("value"),
-		Path:      r.URL.Path[len("/v2/keys"):],
+		Path:      r.URL.Path[len(keysPrefix):],
 		PrevValue: q.Get("prevValue"),
 		PrevIndex: parseUint64(q.Get("prevIndex")),
 		Recursive: parseBool(q.Get("recursive")),
@@ -245,6 +252,8 @@ func parseRequest(r *http.Request) (etcdserverpb.Request, error) {
 	ttl := parseUint64(q.Get("ttl"))
 	if ttl > 0 {
 		expr := time.Duration(ttl) * time.Second
+		// TODO(jonboulle): use fake clock instead of time module
+		// https://github.com/coreos/etcd/issues/1021
 		rr.Expiration = time.Now().Add(expr).UnixNano()
 	}
 

+ 143 - 0
etcdserver/etcdhttp/http_test.go

@@ -5,6 +5,7 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"net/url"
+	"path"
 	"reflect"
 	"strconv"
 	"sync"
@@ -12,6 +13,7 @@ import (
 	"time"
 
 	"github.com/coreos/etcd/etcdserver"
+	"github.com/coreos/etcd/etcdserver/etcdserverpb"
 	"github.com/coreos/etcd/raft"
 	"github.com/coreos/etcd/raft/raftpb"
 	"github.com/coreos/etcd/store"
@@ -73,6 +75,147 @@ func TestSet(t *testing.T) {
 }
 
 func stringp(s string) *string { return &s }
+func boolp(b bool) *bool       { return &b }
+
+func makeURL(t *testing.T, s string) *url.URL {
+	u, err := url.Parse(s)
+	if err != nil {
+		t.Fatalf("error creating URL from %q: %v", s, err)
+	}
+	return u
+}
+
+func TestParseRequest(t *testing.T) {
+	badTestCases := []struct {
+		in *http.Request
+	}{
+		{
+			// parseForm failure
+			&http.Request{
+				Body:   nil,
+				Method: "PUT",
+			},
+		},
+		{
+			// bad key prefix
+			&http.Request{
+				URL: makeURL(t, "/badprefix/"),
+			},
+		},
+	}
+	for i, tt := range badTestCases {
+		got, err := parseRequest(tt.in, 1234)
+		if err == nil {
+			t.Errorf("case %d: unexpected nil error!")
+		}
+		if !reflect.DeepEqual(got, etcdserverpb.Request{}) {
+			t.Errorf("case %d: unexpected non-empty Request: %#v", i, got)
+		}
+	}
+
+	goodTestCases := []struct {
+		in   *http.Request
+		want etcdserverpb.Request
+	}{
+		{
+			// good prefix, all other values default
+			&http.Request{
+				URL: makeURL(t, path.Join(keysPrefix, "foo")),
+			},
+			etcdserverpb.Request{
+				Id:   1234,
+				Path: "/foo",
+			},
+		},
+		{
+			// value specified
+			&http.Request{
+				URL: makeURL(t, path.Join(keysPrefix, "foo?value=some_value")),
+			},
+			etcdserverpb.Request{
+				Id:   1234,
+				Val:  "some_value",
+				Path: "/foo",
+			},
+		},
+		{
+			// prevIndex specified
+			&http.Request{
+				URL: makeURL(t, path.Join(keysPrefix, "foo?prevIndex=98765")),
+			},
+			etcdserverpb.Request{
+				Id:        1234,
+				PrevIndex: 98765,
+				Path:      "/foo",
+			},
+		},
+		{
+			// recursive specified
+			&http.Request{
+				URL: makeURL(t, path.Join(keysPrefix, "foo?recursive=true")),
+			},
+			etcdserverpb.Request{
+				Id:        1234,
+				Recursive: true,
+				Path:      "/foo",
+			},
+		},
+		{
+			// sorted specified
+			&http.Request{
+				URL: makeURL(t, path.Join(keysPrefix, "foo?sorted=true")),
+			},
+			etcdserverpb.Request{
+				Id:     1234,
+				Sorted: true,
+				Path:   "/foo",
+			},
+		},
+		{
+			// wait specified
+			&http.Request{
+				URL: makeURL(t, path.Join(keysPrefix, "foo?wait=true")),
+			},
+			etcdserverpb.Request{
+				Id:   1234,
+				Wait: true,
+				Path: "/foo",
+			},
+		},
+		{
+			// prevExists should be non-null if specified
+			&http.Request{
+				URL: makeURL(t, path.Join(keysPrefix, "foo?prevExists=true")),
+			},
+			etcdserverpb.Request{
+				Id:         1234,
+				PrevExists: boolp(true),
+				Path:       "/foo",
+			},
+		},
+		{
+			// prevExists should be non-null if specified
+			&http.Request{
+				URL: makeURL(t, path.Join(keysPrefix, "foo?prevExists=false")),
+			},
+			etcdserverpb.Request{
+				Id:         1234,
+				PrevExists: boolp(false),
+				Path:       "/foo",
+			},
+		},
+	}
+
+	for i, tt := range goodTestCases {
+		got, err := parseRequest(tt.in, 1234)
+		if err != nil {
+			t.Errorf("case %d: unexpected error: %#v", err)
+		}
+		if !reflect.DeepEqual(got, tt.want) {
+			t.Errorf("case %d: bad request: got %#v, want %#v", i, got, tt.want)
+		}
+	}
+}
 
 // eventingWatcher immediately returns a simple event of the given action on its channel
 type eventingWatcher struct {