Browse Source

etcdhttp: add test for streaming watches

Jonathan Boulle 11 years ago
parent
commit
a9caa24f8a
2 changed files with 115 additions and 30 deletions
  1. 2 9
      etcdserver/etcdhttp/http.go
  2. 113 21
      etcdserver/etcdhttp/http_test.go

+ 2 - 9
etcdserver/etcdhttp/http.go

@@ -7,7 +7,6 @@ import (
 	"io/ioutil"
 	"log"
 	"net/http"
-	"net/http/httputil"
 	"net/url"
 	"strconv"
 	"strings"
@@ -324,23 +323,18 @@ func handleWatch(ctx context.Context, w http.ResponseWriter, wa store.Watcher, s
 	}
 
 	w.Header().Set("Content-Type", "application/json")
-	// WriteHeader will do this implicitly, but best to be explicit.
-	w.Header().Set("Transfer-Encoding", "chunked")
 	w.WriteHeader(http.StatusOK)
 
 	// Ensure headers are flushed early, in case of long polling
 	w.(http.Flusher).Flush()
 
-	cw := httputil.NewChunkedWriter(w)
-
 	for {
 		select {
 		case <-nch:
 			// Client closed connection. Nothing to do.
 			return
 		case <-ctx.Done():
-			// Timed out. Close the connection gracefully.
-			cw.Close()
+			// Timed out. net/http will close the connection for us, so nothing to do.
 			return
 		case ev, ok := <-ech:
 			if !ok {
@@ -349,7 +343,7 @@ func handleWatch(ctx context.Context, w http.ResponseWriter, wa store.Watcher, s
 				// send to the client in time. Then we simply end streaming.
 				return
 			}
-			if err := json.NewEncoder(cw).Encode(ev); err != nil {
+			if err := json.NewEncoder(w).Encode(ev); err != nil {
 				// Should never be reached
 				log.Println("error writing event: %v", err)
 				return
@@ -360,7 +354,6 @@ func handleWatch(ctx context.Context, w http.ResponseWriter, wa store.Watcher, s
 			w.(http.Flusher).Flush()
 		}
 	}
-
 }
 
 // allowMethod verifies that the given method is one of the allowed methods,

+ 113 - 21
etcdserver/etcdhttp/http_test.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"encoding/json"
 	"errors"
-	"fmt"
 	"io"
 	"net/http"
 	"net/http/httptest"
@@ -964,7 +963,6 @@ func TestServeKeysWatch(t *testing.T) {
 			Node:   &store.NodeExtern{},
 		},
 	)
-	wbody = fmt.Sprintf("%x\r\n%s\r\n", len(wbody), wbody)
 
 	if rw.Code != wcode {
 		t.Errorf("got code=%d, want %d", rw.Code, wcode)
@@ -988,7 +986,6 @@ func TestHandleWatch(t *testing.T) {
 	handleWatch(context.Background(), rw, wa, false)
 
 	wcode := http.StatusOK
-	wte := "chunked"
 	wct := "application/json"
 	wbody := mustMarshalEvent(
 		t,
@@ -997,7 +994,6 @@ func TestHandleWatch(t *testing.T) {
 			Node:   &store.NodeExtern{},
 		},
 	)
-	wbody = fmt.Sprintf("%x\r\n%s\r\n", len(wbody), wbody)
 
 	if rw.Code != wcode {
 		t.Errorf("got code=%d, want %d", rw.Code, wcode)
@@ -1006,9 +1002,6 @@ func TestHandleWatch(t *testing.T) {
 	if ct := h.Get("Content-Type"); ct != wct {
 		t.Errorf("Content-Type=%q, want %q", ct, wct)
 	}
-	if te := h.Get("Transfer-Encoding"); te != wte {
-		t.Errorf("Transfer-Encoding=%q, want %q", te, wte)
-	}
 	g := rw.Body.String()
 	if g != wbody {
 		t.Errorf("got body=%#v, want %#v", g, wbody)
@@ -1025,7 +1018,6 @@ func TestHandleWatchNoEvent(t *testing.T) {
 	handleWatch(context.Background(), rw, wa, false)
 
 	wcode := http.StatusOK
-	wte := "chunked"
 	wct := "application/json"
 	wbody := ""
 
@@ -1036,9 +1028,6 @@ func TestHandleWatchNoEvent(t *testing.T) {
 	if ct := h.Get("Content-Type"); ct != wct {
 		t.Errorf("Content-Type=%q, want %q", ct, wct)
 	}
-	if te := h.Get("Transfer-Encoding"); te != wte {
-		t.Errorf("Transfer-Encoding=%q, want %q", te, wte)
-	}
 	g := rw.Body.String()
 	if g != wbody {
 		t.Errorf("got body=%#v, want %#v", g, wbody)
@@ -1065,7 +1054,6 @@ func TestHandleWatchCloseNotified(t *testing.T) {
 	handleWatch(context.Background(), rw, wa, false)
 
 	wcode := http.StatusOK
-	wte := "chunked"
 	wct := "application/json"
 	wbody := ""
 
@@ -1076,9 +1064,6 @@ func TestHandleWatchCloseNotified(t *testing.T) {
 	if ct := h.Get("Content-Type"); ct != wct {
 		t.Errorf("Content-Type=%q, want %q", ct, wct)
 	}
-	if te := h.Get("Transfer-Encoding"); te != wte {
-		t.Errorf("Transfer-Encoding=%q, want %q", te, wte)
-	}
 	g := rw.Body.String()
 	if g != wbody {
 		t.Errorf("got body=%#v, want %#v", g, wbody)
@@ -1095,9 +1080,8 @@ func TestHandleWatchTimeout(t *testing.T) {
 	handleWatch(ctx, rw, wa, false)
 
 	wcode := http.StatusOK
-	wte := "chunked"
 	wct := "application/json"
-	wbody := "0\r\n"
+	wbody := ""
 
 	if rw.Code != wcode {
 		t.Errorf("got code=%d, want %d", rw.Code, wcode)
@@ -1106,15 +1090,123 @@ func TestHandleWatchTimeout(t *testing.T) {
 	if ct := h.Get("Content-Type"); ct != wct {
 		t.Errorf("Content-Type=%q, want %q", ct, wct)
 	}
-	if te := h.Get("Transfer-Encoding"); te != wte {
-		t.Errorf("Transfer-Encoding=%q, want %q", te, wte)
-	}
 	g := rw.Body.String()
 	if g != wbody {
 		t.Errorf("got body=%#v, want %#v", g, wbody)
 	}
 }
 
+// flushingRecorder provides a channel to allow users to block until the Recorder is Flushed()
+type flushingRecorder struct {
+	*httptest.ResponseRecorder
+	ch chan struct{}
+}
+
+func (fr *flushingRecorder) Flush() {
+	fr.ResponseRecorder.Flush()
+	fr.ch <- struct{}{}
+}
+
 func TestHandleWatchStreaming(t *testing.T) {
-	// TODO(jonboulle): me
+	rw := &flushingRecorder{
+		httptest.NewRecorder(),
+		make(chan struct{}, 1),
+	}
+	wa := &dummyWatcher{
+		echan: make(chan *store.Event),
+	}
+
+	// Launch the streaming handler in the background with a cancellable context
+	ctx, cancel := context.WithCancel(context.Background())
+	done := make(chan struct{})
+	go func() {
+		handleWatch(ctx, rw, wa, true)
+		close(done)
+	}()
+
+	// Expect one Flush for the headers etc.
+	select {
+	case <-rw.ch:
+	case <-time.After(time.Second):
+		t.Fatalf("timed out waiting for flush")
+	}
+
+	// Expect headers but no body
+	wcode := http.StatusOK
+	wct := "application/json"
+	wbody := ""
+
+	if rw.Code != wcode {
+		t.Errorf("got code=%d, want %d", rw.Code, wcode)
+	}
+	h := rw.Header()
+	if ct := h.Get("Content-Type"); ct != wct {
+		t.Errorf("Content-Type=%q, want %q", ct, wct)
+	}
+	g := rw.Body.String()
+	if g != wbody {
+		t.Errorf("got body=%#v, want %#v", g, wbody)
+	}
+
+	// Now send the first event
+	select {
+	case wa.echan <- &store.Event{
+		Action: store.Get,
+		Node:   &store.NodeExtern{},
+	}:
+	case <-time.After(time.Second):
+		t.Fatal("timed out waiting for send")
+	}
+
+	// Wait for it to be flushed...
+	select {
+	case <-rw.ch:
+	case <-time.After(time.Second):
+		t.Fatalf("timed out waiting for flush")
+	}
+
+	// And check the body is as expected
+	wbody = mustMarshalEvent(
+		t,
+		&store.Event{
+			Action: store.Get,
+			Node:   &store.NodeExtern{},
+		},
+	)
+	g = rw.Body.String()
+	if g != wbody {
+		t.Errorf("got body=%#v, want %#v", g, wbody)
+	}
+
+	// Rinse and repeat
+	select {
+	case wa.echan <- &store.Event{
+		Action: store.Get,
+		Node:   &store.NodeExtern{},
+	}:
+	case <-time.After(time.Second):
+		t.Fatal("timed out waiting for send")
+	}
+
+	select {
+	case <-rw.ch:
+	case <-time.After(time.Second):
+		t.Fatalf("timed out waiting for flush")
+	}
+
+	// This time, we expect to see both events
+	wbody = wbody + wbody
+	g = rw.Body.String()
+	if g != wbody {
+		t.Errorf("got body=%#v, want %#v", g, wbody)
+	}
+
+	// Finally, time out the connection and ensure the serving goroutine returns
+	cancel()
+
+	select {
+	case <-done:
+	case <-time.After(time.Second):
+		t.Fatalf("timed out waiting for done")
+	}
 }