Browse Source

Merge pull request #1138 from jonboulle/1138_timeout

etcdserver: handle watch timeouts and streaming
Jonathan Boulle 11 years ago
parent
commit
ec1df42d04

+ 65 - 31
etcdserver/etcdhttp/http.go

@@ -12,7 +12,6 @@ import (
 	"strings"
 	"time"
 
-	"github.com/coreos/etcd/elog"
 	etcdErr "github.com/coreos/etcd/error"
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/etcdserver/etcdserverpb"
@@ -26,7 +25,11 @@ const (
 	machinesPrefix = "/v2/machines"
 	raftPrefix     = "/raft"
 
-	DefaultTimeout = 500 * time.Millisecond
+	// time to wait for response from EtcdServer requests
+	defaultServerTimeout = 500 * time.Millisecond
+
+	// time to wait for a Watch request
+	defaultWatchTimeout = 5 * time.Minute
 )
 
 var errClosed = errors.New("etcdhttp: client closed connection")
@@ -39,7 +42,7 @@ func NewClientHandler(server etcdserver.Server, peers Peers, timeout time.Durati
 		timeout: timeout,
 	}
 	if sh.timeout == 0 {
-		sh.timeout = DefaultTimeout
+		sh.timeout = defaultServerTimeout
 	}
 	mux := http.NewServeMux()
 	mux.HandleFunc(keysPrefix, sh.serveKeys)
@@ -89,23 +92,18 @@ func (h serverHandler) serveKeys(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	var ev *store.Event
 	switch {
 	case resp.Event != nil:
-		ev = resp.Event
-	case resp.Watcher != nil:
-		if ev, err = waitForEvent(ctx, w, resp.Watcher); err != nil {
-			http.Error(w, err.Error(), http.StatusGatewayTimeout)
-			return
+		if err := writeEvent(w, resp.Event); err != nil {
+			// Should never be reached
+			log.Println("error writing event: %v", err)
 		}
+	case resp.Watcher != nil:
+		ctx, cancel := context.WithTimeout(context.Background(), defaultWatchTimeout)
+		defer cancel()
+		handleWatch(ctx, w, resp.Watcher, rr.Stream)
 	default:
 		writeError(w, errors.New("received response with no Event/Watcher!"))
-		return
-	}
-
-	if err = writeEvent(w, ev); err != nil {
-		// Should never be reached
-		log.Println("error writing event: %v", err)
 	}
 }
 
@@ -187,7 +185,7 @@ func parseRequest(r *http.Request, id int64) (etcdserverpb.Request, error) {
 		)
 	}
 
-	var rec, sort, wait bool
+	var rec, sort, wait, stream bool
 	if rec, err = getBool(r.Form, "recursive"); err != nil {
 		return emptyReq, etcdErr.NewRequestError(
 			etcdErr.EcodeInvalidField,
@@ -206,6 +204,19 @@ func parseRequest(r *http.Request, id int64) (etcdserverpb.Request, error) {
 			`invalid value for "wait"`,
 		)
 	}
+	if stream, err = getBool(r.Form, "stream"); err != nil {
+		return emptyReq, etcdErr.NewRequestError(
+			etcdErr.EcodeInvalidField,
+			`invalid value for "stream"`,
+		)
+	}
+
+	if wait && r.Method != "GET" {
+		return emptyReq, etcdErr.NewRequestError(
+			etcdErr.EcodeInvalidField,
+			`"wait" can only be used with GET requests`,
+		)
+	}
 
 	// prevExist is nullable, so leave it null if not specified
 	var pe *bool
@@ -231,6 +242,7 @@ func parseRequest(r *http.Request, id int64) (etcdserverpb.Request, error) {
 		Recursive: rec,
 		Since:     wIdx,
 		Sorted:    sort,
+		Stream:    stream,
 		Wait:      wait,
 	}
 
@@ -285,8 +297,9 @@ func writeError(w http.ResponseWriter, err error) {
 	}
 }
 
-// writeEvent serializes the given Event and writes the resulting JSON to the
-// given ResponseWriter
+// writeEvent serializes a single Event and writes the resulting
+// JSON to the given ResponseWriter, along with the appropriate
+// headers
 func writeEvent(w http.ResponseWriter, ev *store.Event) error {
 	if ev == nil {
 		return errors.New("cannot write empty Event!")
@@ -301,24 +314,45 @@ func writeEvent(w http.ResponseWriter, ev *store.Event) error {
 	return json.NewEncoder(w).Encode(ev)
 }
 
-// waitForEvent waits for a given Watcher to return its associated
-// event. It returns a non-nil error if the given Context times out
-// or the given ResponseWriter triggers a CloseNotify.
-func waitForEvent(ctx context.Context, w http.ResponseWriter, wa store.Watcher) (*store.Event, error) {
-	// TODO(bmizerany): support streaming?
+func handleWatch(ctx context.Context, w http.ResponseWriter, wa store.Watcher, stream bool) {
 	defer wa.Remove()
+	ech := wa.EventChan()
 	var nch <-chan bool
 	if x, ok := w.(http.CloseNotifier); ok {
 		nch = x.CloseNotify()
 	}
-	select {
-	case ev := <-wa.EventChan():
-		return ev, nil
-	case <-nch:
-		elog.TODO()
-		return nil, errClosed
-	case <-ctx.Done():
-		return nil, ctx.Err()
+
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(http.StatusOK)
+
+	// Ensure headers are flushed early, in case of long polling
+	w.(http.Flusher).Flush()
+
+	for {
+		select {
+		case <-nch:
+			// Client closed connection. Nothing to do.
+			return
+		case <-ctx.Done():
+			// Timed out. net/http will close the connection for us, so nothing to do.
+			return
+		case ev, ok := <-ech:
+			if !ok {
+				// If the channel is closed this may be an indication of
+				// that notifications are much more than we are able to
+				// send to the client in time. Then we simply end streaming.
+				return
+			}
+			if err := json.NewEncoder(w).Encode(ev); err != nil {
+				// Should never be reached
+				log.Println("error writing event: %v", err)
+				return
+			}
+			if !stream {
+				return
+			}
+			w.(http.Flusher).Flush()
+		}
 	}
 }
 

+ 259 - 114
etcdserver/etcdhttp/http_test.go

@@ -11,7 +11,6 @@ import (
 	"path"
 	"reflect"
 	"strings"
-	"sync"
 	"testing"
 	"time"
 
@@ -36,8 +35,12 @@ func mustNewURL(t *testing.T, s string) *url.URL {
 // mustNewRequest takes a path, appends it to the standard keysPrefix, and constructs
 // a GET *http.Request referencing the resulting URL
 func mustNewRequest(t *testing.T, p string) *http.Request {
+	return mustNewMethodRequest(t, "GET", p)
+}
+
+func mustNewMethodRequest(t *testing.T, m, p string) *http.Request {
 	return &http.Request{
-		Method: "GET",
+		Method: m,
 		URL:    mustNewURL(t, path.Join(keysPrefix, p)),
 	}
 }
@@ -99,7 +102,7 @@ func TestBadParseRequest(t *testing.T) {
 			mustNewForm(t, "foo", url.Values{"ttl": []string{"-1"}}),
 			etcdErr.EcodeTTLNaN,
 		},
-		// bad values for recursive, sorted, wait, prevExist
+		// bad values for recursive, sorted, wait, prevExist, stream
 		{
 			mustNewForm(t, "foo", url.Values{"recursive": []string{"hahaha"}}),
 			etcdErr.EcodeInvalidField,
@@ -136,6 +139,19 @@ func TestBadParseRequest(t *testing.T) {
 			mustNewForm(t, "foo", url.Values{"prevExist": []string{"#2"}}),
 			etcdErr.EcodeInvalidField,
 		},
+		{
+			mustNewForm(t, "foo", url.Values{"stream": []string{"zzz"}}),
+			etcdErr.EcodeInvalidField,
+		},
+		{
+			mustNewForm(t, "foo", url.Values{"stream": []string{"something"}}),
+			etcdErr.EcodeInvalidField,
+		},
+		// wait is only valid with GET requests
+		{
+			mustNewMethodRequest(t, "HEAD", "foo?wait=true"),
+			etcdErr.EcodeInvalidField,
+		},
 		// query values are considered
 		{
 			mustNewRequest(t, "foo?prevExist=wrong"),
@@ -256,14 +272,10 @@ func TestGoodParseRequest(t *testing.T) {
 		},
 		{
 			// wait specified
-			mustNewForm(
-				t,
-				"foo",
-				url.Values{"wait": []string{"true"}},
-			),
+			mustNewRequest(t, "foo?wait=true"),
 			etcdserverpb.Request{
 				Id:     1234,
-				Method: "PUT",
+				Method: "GET",
 				Wait:   true,
 				Path:   "/foo",
 			},
@@ -492,100 +504,6 @@ func (w *dummyWatcher) EventChan() chan *store.Event {
 }
 func (w *dummyWatcher) Remove() {}
 
-type dummyResponseWriter struct {
-	cnchan chan bool
-	http.ResponseWriter
-}
-
-func (rw *dummyResponseWriter) CloseNotify() <-chan bool {
-	return rw.cnchan
-}
-
-func TestWaitForEventChan(t *testing.T) {
-	ctx := context.Background()
-	ec := make(chan *store.Event)
-	dw := &dummyWatcher{
-		echan: ec,
-	}
-	w := httptest.NewRecorder()
-	var wg sync.WaitGroup
-	var ev *store.Event
-	var err error
-	wg.Add(1)
-	go func() {
-		ev, err = waitForEvent(ctx, w, dw)
-		wg.Done()
-	}()
-	ec <- &store.Event{
-		Action: store.Get,
-		Node: &store.NodeExtern{
-			Key:           "/foo/bar",
-			ModifiedIndex: 12345,
-		},
-	}
-	wg.Wait()
-	want := &store.Event{
-		Action: store.Get,
-		Node: &store.NodeExtern{
-			Key:           "/foo/bar",
-			ModifiedIndex: 12345,
-		},
-	}
-	if !reflect.DeepEqual(ev, want) {
-		t.Fatalf("bad event: got %#v, want %#v", ev, want)
-	}
-	if err != nil {
-		t.Fatalf("unexpected error: %v", err)
-	}
-}
-
-func TestWaitForEventCloseNotify(t *testing.T) {
-	ctx := context.Background()
-	dw := &dummyWatcher{}
-	cnchan := make(chan bool)
-	w := &dummyResponseWriter{
-		cnchan: cnchan,
-	}
-	var wg sync.WaitGroup
-	var ev *store.Event
-	var err error
-	wg.Add(1)
-	go func() {
-		ev, err = waitForEvent(ctx, w, dw)
-		wg.Done()
-	}()
-	close(cnchan)
-	wg.Wait()
-	if ev != nil {
-		t.Fatalf("non-nil Event returned with CloseNotifier: %v", ev)
-	}
-	if err == nil {
-		t.Fatalf("nil err returned with CloseNotifier!")
-	}
-}
-
-func TestWaitForEventCancelledContext(t *testing.T) {
-	cctx, cancel := context.WithCancel(context.Background())
-	dw := &dummyWatcher{}
-	w := httptest.NewRecorder()
-	var wg sync.WaitGroup
-	var ev *store.Event
-	var err error
-	wg.Add(1)
-	go func() {
-		ev, err = waitForEvent(cctx, w, dw)
-		wg.Done()
-	}()
-	cancel()
-	wg.Wait()
-	if ev != nil {
-		t.Fatalf("non-nil Event returned with cancelled context: %v", ev)
-	}
-	if err == nil {
-		t.Fatalf("nil err returned with cancelled context!")
-	}
-}
-
 func TestV2MachinesEndpoint(t *testing.T) {
 	tests := []struct {
 		method string
@@ -950,17 +868,6 @@ func TestBadServeKeys(t *testing.T) {
 
 			http.StatusInternalServerError,
 		},
-		{
-			// timeout waiting for event (watcher never returns)
-			mustNewRequest(t, "foo"),
-			&resServer{
-				etcdserver.Response{
-					Watcher: &dummyWatcher{},
-				},
-			},
-
-			http.StatusGatewayTimeout,
-		},
 		{
 			// non-event/watcher response from etcdserver.Server
 			mustNewRequest(t, "foo"),
@@ -1065,3 +972,241 @@ func TestServeKeysWatch(t *testing.T) {
 		t.Errorf("got body=%#v, want %#v", g, wbody)
 	}
 }
+
+func TestHandleWatch(t *testing.T) {
+	rw := httptest.NewRecorder()
+	wa := &dummyWatcher{
+		echan: make(chan *store.Event, 1),
+	}
+	wa.echan <- &store.Event{
+		Action: store.Get,
+		Node:   &store.NodeExtern{},
+	}
+
+	handleWatch(context.Background(), rw, wa, false)
+
+	wcode := http.StatusOK
+	wct := "application/json"
+	wbody := mustMarshalEvent(
+		t,
+		&store.Event{
+			Action: store.Get,
+			Node:   &store.NodeExtern{},
+		},
+	)
+
+	if rw.Code != wcode {
+		t.Errorf("got code=%d, want %d", rw.Code, wcode)
+	}
+	h := rw.Header()
+	if ct := h.Get("Content-Type"); ct != wct {
+		t.Errorf("Content-Type=%q, want %q", ct, wct)
+	}
+	g := rw.Body.String()
+	if g != wbody {
+		t.Errorf("got body=%#v, want %#v", g, wbody)
+	}
+}
+
+func TestHandleWatchNoEvent(t *testing.T) {
+	rw := httptest.NewRecorder()
+	wa := &dummyWatcher{
+		echan: make(chan *store.Event, 1),
+	}
+	close(wa.echan)
+
+	handleWatch(context.Background(), rw, wa, false)
+
+	wcode := http.StatusOK
+	wct := "application/json"
+	wbody := ""
+
+	if rw.Code != wcode {
+		t.Errorf("got code=%d, want %d", rw.Code, wcode)
+	}
+	h := rw.Header()
+	if ct := h.Get("Content-Type"); ct != wct {
+		t.Errorf("Content-Type=%q, want %q", ct, wct)
+	}
+	g := rw.Body.String()
+	if g != wbody {
+		t.Errorf("got body=%#v, want %#v", g, wbody)
+	}
+}
+
+type recordingCloseNotifier struct {
+	*httptest.ResponseRecorder
+	cn chan bool
+}
+
+func (rcn *recordingCloseNotifier) CloseNotify() <-chan bool {
+	return rcn.cn
+}
+
+func TestHandleWatchCloseNotified(t *testing.T) {
+	rw := &recordingCloseNotifier{
+		ResponseRecorder: httptest.NewRecorder(),
+		cn:               make(chan bool, 1),
+	}
+	rw.cn <- true
+	wa := &dummyWatcher{}
+
+	handleWatch(context.Background(), rw, wa, false)
+
+	wcode := http.StatusOK
+	wct := "application/json"
+	wbody := ""
+
+	if rw.Code != wcode {
+		t.Errorf("got code=%d, want %d", rw.Code, wcode)
+	}
+	h := rw.Header()
+	if ct := h.Get("Content-Type"); ct != wct {
+		t.Errorf("Content-Type=%q, want %q", ct, wct)
+	}
+	g := rw.Body.String()
+	if g != wbody {
+		t.Errorf("got body=%#v, want %#v", g, wbody)
+	}
+}
+
+func TestHandleWatchTimeout(t *testing.T) {
+	rw := httptest.NewRecorder()
+	wa := &dummyWatcher{}
+	// Simulate a timed-out context
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	handleWatch(ctx, rw, wa, false)
+
+	wcode := http.StatusOK
+	wct := "application/json"
+	wbody := ""
+
+	if rw.Code != wcode {
+		t.Errorf("got code=%d, want %d", rw.Code, wcode)
+	}
+	h := rw.Header()
+	if ct := h.Get("Content-Type"); ct != wct {
+		t.Errorf("Content-Type=%q, want %q", ct, wct)
+	}
+	g := rw.Body.String()
+	if g != wbody {
+		t.Errorf("got body=%#v, want %#v", g, wbody)
+	}
+}
+
+// flushingRecorder provides a channel to allow users to block until the Recorder is Flushed()
+type flushingRecorder struct {
+	*httptest.ResponseRecorder
+	ch chan struct{}
+}
+
+func (fr *flushingRecorder) Flush() {
+	fr.ResponseRecorder.Flush()
+	fr.ch <- struct{}{}
+}
+
+func TestHandleWatchStreaming(t *testing.T) {
+	rw := &flushingRecorder{
+		httptest.NewRecorder(),
+		make(chan struct{}, 1),
+	}
+	wa := &dummyWatcher{
+		echan: make(chan *store.Event),
+	}
+
+	// Launch the streaming handler in the background with a cancellable context
+	ctx, cancel := context.WithCancel(context.Background())
+	done := make(chan struct{})
+	go func() {
+		handleWatch(ctx, rw, wa, true)
+		close(done)
+	}()
+
+	// Expect one Flush for the headers etc.
+	select {
+	case <-rw.ch:
+	case <-time.After(time.Second):
+		t.Fatalf("timed out waiting for flush")
+	}
+
+	// Expect headers but no body
+	wcode := http.StatusOK
+	wct := "application/json"
+	wbody := ""
+
+	if rw.Code != wcode {
+		t.Errorf("got code=%d, want %d", rw.Code, wcode)
+	}
+	h := rw.Header()
+	if ct := h.Get("Content-Type"); ct != wct {
+		t.Errorf("Content-Type=%q, want %q", ct, wct)
+	}
+	g := rw.Body.String()
+	if g != wbody {
+		t.Errorf("got body=%#v, want %#v", g, wbody)
+	}
+
+	// Now send the first event
+	select {
+	case wa.echan <- &store.Event{
+		Action: store.Get,
+		Node:   &store.NodeExtern{},
+	}:
+	case <-time.After(time.Second):
+		t.Fatal("timed out waiting for send")
+	}
+
+	// Wait for it to be flushed...
+	select {
+	case <-rw.ch:
+	case <-time.After(time.Second):
+		t.Fatalf("timed out waiting for flush")
+	}
+
+	// And check the body is as expected
+	wbody = mustMarshalEvent(
+		t,
+		&store.Event{
+			Action: store.Get,
+			Node:   &store.NodeExtern{},
+		},
+	)
+	g = rw.Body.String()
+	if g != wbody {
+		t.Errorf("got body=%#v, want %#v", g, wbody)
+	}
+
+	// Rinse and repeat
+	select {
+	case wa.echan <- &store.Event{
+		Action: store.Get,
+		Node:   &store.NodeExtern{},
+	}:
+	case <-time.After(time.Second):
+		t.Fatal("timed out waiting for send")
+	}
+
+	select {
+	case <-rw.ch:
+	case <-time.After(time.Second):
+		t.Fatalf("timed out waiting for flush")
+	}
+
+	// This time, we expect to see both events
+	wbody = wbody + wbody
+	g = rw.Body.String()
+	if g != wbody {
+		t.Errorf("got body=%#v, want %#v", g, wbody)
+	}
+
+	// Finally, time out the connection and ensure the serving goroutine returns
+	cancel()
+
+	select {
+	case <-done:
+	case <-time.After(time.Second):
+		t.Fatalf("timed out waiting for done")
+	}
+}

+ 29 - 0
etcdserver/etcdserverpb/etcdserver.pb.go

@@ -43,6 +43,7 @@ type Request struct {
 	Sorted           bool   `protobuf:"varint,13,req,name=sorted" json:"sorted"`
 	Quorum           bool   `protobuf:"varint,14,req,name=quorum" json:"quorum"`
 	Time             int64  `protobuf:"varint,15,req,name=time" json:"time"`
+	Stream           bool   `protobuf:"varint,16,req,name=stream" json:"stream"`
 	XXX_unrecognized []byte `json:"-"`
 }
 
@@ -337,6 +338,23 @@ func (m *Request) Unmarshal(data []byte) error {
 					break
 				}
 			}
+		case 16:
+			if wireType != 0 {
+				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+			}
+			var v int
+			for shift := uint(0); ; shift += 7 {
+				if index >= l {
+					return io.ErrUnexpectedEOF
+				}
+				b := data[index]
+				index++
+				v |= (int(b) & 0x7F) << shift
+				if b < 0x80 {
+					break
+				}
+			}
+			m.Stream = bool(v != 0)
 		default:
 			var sizeOfWire int
 			for {
@@ -384,6 +402,7 @@ func (m *Request) Size() (n int) {
 	n += 2
 	n += 2
 	n += 1 + sovEtcdserver(uint64(m.Time))
+	n += 3
 	if m.XXX_unrecognized != nil {
 		n += len(m.XXX_unrecognized)
 	}
@@ -499,6 +518,16 @@ func (m *Request) MarshalTo(data []byte) (n int, err error) {
 	data[i] = 0x78
 	i++
 	i = encodeVarintEtcdserver(data, i, uint64(m.Time))
+	data[i] = 0x80
+	i++
+	data[i] = 0x1
+	i++
+	if m.Stream {
+		data[i] = 1
+	} else {
+		data[i] = 0
+	}
+	i++
 	if m.XXX_unrecognized != nil {
 		i += copy(data[i:], m.XXX_unrecognized)
 	}

+ 1 - 0
etcdserver/etcdserverpb/etcdserver.proto

@@ -23,4 +23,5 @@ message Request {
 	required bool   sorted     = 13 [(gogoproto.nullable) = false];
 	required bool   quorum     = 14 [(gogoproto.nullable) = false];
 	required int64  time       = 15 [(gogoproto.nullable) = false];
+	required bool   stream     = 16 [(gogoproto.nullable) = false];
 }

+ 1 - 1
etcdserver/server.go

@@ -213,7 +213,7 @@ func (s *EtcdServer) Do(ctx context.Context, r pb.Request) (Response, error) {
 	case "GET":
 		switch {
 		case r.Wait:
-			wc, err := s.Store.Watch(r.Path, r.Recursive, false, r.Since)
+			wc, err := s.Store.Watch(r.Path, r.Recursive, r.Stream, r.Since)
 			if err != nil {
 				return Response{}, err
 			}