Merge pull request #4528 from heyitsanthony/fix-watchcurrev

fix several watcher races
Anthony Romano
commit ef2d3feca6
3 changed files with 216 additions and 100 deletions:

  1. etcdserver/api/v3rpc/watch.go (+48 −10)
  2. integration/v3_watch_test.go (+100 −32)
  3. storage/watchable_store.go (+68 −58)

etcdserver/api/v3rpc/watch.go (+48 −10)

@@ -102,9 +102,16 @@ func (sws *serverWatchStream) recvLoop() error {
 					toWatch = creq.Prefix
 					prefix = true
 				}
-				id := sws.watchStream.Watch(toWatch, prefix, creq.StartRevision)
+
+				rev := creq.StartRevision
+				wsrev := sws.watchStream.Rev()
+				if rev == 0 {
+					// rev 0 watches past the current revision
+					rev = wsrev + 1
+				}
+				id := sws.watchStream.Watch(toWatch, prefix, rev)
 				sws.ctrlStream <- &pb.WatchResponse{
-					Header:  sws.newResponseHeader(sws.watchStream.Rev()),
+					Header:  sws.newResponseHeader(wsrev),
 					WatchId: int64(id),
 					Created: true,
 				}
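
Note: the hunk above fixes the create path. The stream revision is read once into
wsrev, a StartRevision of 0 defaults to wsrev+1, and the creation header reports
the same wsrev, so no write landing mid-request can slip between the two reads. A
minimal sketch of the defaulting rule, with a hypothetical helper name:

	// defaultStartRev mirrors the rev-defaulting logic above: a zero start
	// revision becomes currentRev+1, so a new watcher sees only events
	// that happen strictly after its creation.
	func defaultStartRev(requested, current int64) int64 {
		if requested == 0 {
			return current + 1
		}
		return requested
	}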
@@ -129,6 +136,11 @@ func (sws *serverWatchStream) recvLoop() error {
 }
 
 func (sws *serverWatchStream) sendLoop() {
+	// watch ids that are currently active
+	ids := make(map[storage.WatchID]struct{})
+	// watch responses pending on a watch id creation message
+	pending := make(map[storage.WatchID][]*pb.WatchResponse)
+
 	for {
 		select {
 		case wresp, ok := <-sws.watchStream.Chan():
@@ -145,14 +157,22 @@ func (sws *serverWatchStream) sendLoop() {
 				events[i] = &evs[i]
 			}
 
-			err := sws.gRPCStream.Send(&pb.WatchResponse{
+			wr := &pb.WatchResponse{
 				Header:          sws.newResponseHeader(wresp.Revision),
 				WatchId:         int64(wresp.WatchID),
 				Events:          events,
 				CompactRevision: wresp.CompactRevision,
-			})
+			}
+
+			if _, hasId := ids[wresp.WatchID]; !hasId {
+				// buffer if id not yet announced
+				wrs := append(pending[wresp.WatchID], wr)
+				pending[wresp.WatchID] = wrs
+				continue
+			}
+
 			storage.ReportEventReceived()
-			if err != nil {
+			if err := sws.gRPCStream.Send(wr); err != nil {
 				return
 			}
 
@@ -165,15 +185,33 @@ func (sws *serverWatchStream) sendLoop() {
 				return
 			}
 
+			// track id creation
+			wid := storage.WatchID(c.WatchId)
+			if c.Canceled {
+				delete(ids, wid)
+				continue
+			}
+			if c.Created {
+				// flush buffered events
+				ids[wid] = struct{}{}
+				for _, v := range pending[wid] {
+					storage.ReportEventReceived()
+					if err := sws.gRPCStream.Send(v); err != nil {
+						return
+					}
+				}
+				delete(pending, wid)
+			}
 		case <-sws.closec:
 			// drain the chan to clean up pending events
-			for {
-				_, ok := <-sws.watchStream.Chan()
-				if !ok {
-					return
-				}
+			for range sws.watchStream.Chan() {
 				storage.ReportEventReceived()
 			}
+			for _, wrs := range pending {
+				for range wrs {
+					storage.ReportEventReceived()
+				}
+			}
 		}
 	}
 }
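
Note: sendLoop now guarantees that the Created response for a watch id reaches the
client before any events for that id. Events that arrive early are parked in
pending and flushed once the creation control message goes out. A distilled sketch
of the pattern, using illustrative types rather than the etcd API:

	// orderedSender buffers per-id messages until the id is announced.
	type orderedSender struct {
		ids     map[int64]bool     // ids already announced to the client
		pending map[int64][]string // messages waiting on an announcement
		send    func(string)
	}

	func (o *orderedSender) event(id int64, msg string) {
		if !o.ids[id] {
			o.pending[id] = append(o.pending[id], msg) // park until Created is sent
			return
		}
		o.send(msg)
	}

	func (o *orderedSender) created(id int64) {
		o.ids[id] = true
		for _, msg := range o.pending[id] { // flush in arrival order
			o.send(msg)
		}
		delete(o.pending, id)
	}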

integration/v3_watch_test.go (+100 −32)

@@ -45,10 +45,6 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
 					Key: []byte("foo")}}},
 
 			[]*pb.WatchResponse{
-				{
-					Header:  &pb.ResponseHeader{Revision: 1},
-					Created: true,
-				},
 				{
 					Header:  &pb.ResponseHeader{Revision: 2},
 					Created: false,
@@ -68,12 +64,7 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
 				CreateRequest: &pb.WatchCreateRequest{
 					Key: []byte("helloworld")}}},
 
-			[]*pb.WatchResponse{
-				{
-					Header:  &pb.ResponseHeader{Revision: 1},
-					Created: true,
-				},
-			},
+			[]*pb.WatchResponse{},
 		},
 		// watch the prefix, matching
 		{
@@ -83,10 +74,6 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
 					Prefix: []byte("foo")}}},
 
 			[]*pb.WatchResponse{
-				{
-					Header:  &pb.ResponseHeader{Revision: 1},
-					Created: true,
-				},
 				{
 					Header:  &pb.ResponseHeader{Revision: 2},
 					Created: false,
@@ -106,12 +93,7 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
 				CreateRequest: &pb.WatchCreateRequest{
 					Prefix: []byte("helloworld")}}},
 
-			[]*pb.WatchResponse{
-				{
-					Header:  &pb.ResponseHeader{Revision: 1},
-					Created: true,
-				},
-			},
+			[]*pb.WatchResponse{},
 		},
 		// multiple puts, one watcher with matching key
 		{
@@ -121,10 +103,6 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
 					Key: []byte("foo")}}},
 
 			[]*pb.WatchResponse{
-				{
-					Header:  &pb.ResponseHeader{Revision: 1},
-					Created: true,
-				},
 				{
 					Header:  &pb.ResponseHeader{Revision: 2},
 					Created: false,
@@ -165,10 +143,6 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
 					Prefix: []byte("foo")}}},
 
 			[]*pb.WatchResponse{
-				{
-					Header:  &pb.ResponseHeader{Revision: 1},
-					Created: true,
-				},
 				{
 					Header:  &pb.ResponseHeader{Revision: 2},
 					Created: false,
@@ -218,6 +192,23 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
 			t.Fatalf("#%d: wStream.Send error: %v", i, err)
 		}
 
+		// ensure watcher request created a new watcher
+		cresp, err := wStream.Recv()
+		if err != nil {
+			t.Errorf("#%d: wStream.Recv error: %v", i, err)
+			continue
+		}
+		if !cresp.Created {
+			t.Errorf("#%d: did not create watch id, got %+v", i, cresp)
+			continue
+		}
+		createdWatchId := cresp.WatchId
+		if cresp.Header == nil || cresp.Header.Revision != 1 {
+			t.Errorf("#%d: header revision got %+v, wanted revision 1", i, cresp)
+			continue
+		}
+
+		// asynchronously create keys
 		go func() {
 			for _, k := range tt.putKeys {
 				kvc := clus.RandClient().KV
@@ -228,7 +219,7 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
 			}
 		}()
 
-		var createdWatchId int64
+		// check stream results
 		for j, wresp := range tt.wresps {
 			resp, err := wStream.Recv()
 			if err != nil {
@@ -245,9 +236,6 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
 			if wresp.Created != resp.Created {
 				t.Errorf("#%d.%d: resp.Created got = %v, want = %v", i, j, resp.Created, wresp.Created)
 			}
-			if resp.Created {
-				createdWatchId = resp.WatchId
-			}
 			if resp.WatchId != createdWatchId {
 				t.Errorf("#%d.%d: resp.WatchId got = %d, want = %d", i, j, resp.WatchId, createdWatchId)
 			}
@@ -333,6 +321,86 @@ func testV3WatchCancel(t *testing.T, startRev int64) {
 	clus.Terminate(t)
 }
 
+// TestV3WatchCurrentPutOverlap ensures that watchers created on the current
+// revision receive every event when puts overlap watcher creation.
+func TestV3WatchCurrentPutOverlap(t *testing.T) {
+	defer testutil.AfterTest(t)
+	clus := NewClusterV3(t, &ClusterConfig{Size: 3})
+	defer clus.Terminate(t)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+	wStream, wErr := clus.RandClient().Watch.Watch(ctx)
+	if wErr != nil {
+		t.Fatalf("wAPI.Watch error: %v", wErr)
+	}
+
+	// last mod_revision that will be observed
+	nrRevisions := 32
+	// first revision already allocated as empty revision
+	for i := 1; i < nrRevisions; i++ {
+		go func() {
+			kvc := clus.RandClient().KV
+			req := &pb.PutRequest{Key: []byte("foo"), Value: []byte("bar")}
+			if _, err := kvc.Put(context.TODO(), req); err != nil {
+				t.Fatalf("couldn't put key (%v)", err)
+			}
+		}()
+	}
+
+	// maps watcher to current expected revision
+	progress := make(map[int64]int64)
+
+	wreq := &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
+		CreateRequest: &pb.WatchCreateRequest{Prefix: []byte("foo")}}}
+	if err := wStream.Send(wreq); err != nil {
+		t.Fatalf("first watch request failed (%v)", err)
+	}
+
+	more := true
+	progress[-1] = 0 // watcher creation pending
+	for more {
+		resp, err := wStream.Recv()
+		if err != nil {
+			t.Fatalf("wStream.Recv error: %v", err)
+		}
+
+		if resp.Created {
+			// accept events > header revision
+			progress[resp.WatchId] = resp.Header.Revision + 1
+			if resp.Header.Revision == int64(nrRevisions) {
+				// covered all revisions; create no more watchers
+				progress[-1] = int64(nrRevisions) + 1
+			} else if err := wStream.Send(wreq); err != nil {
+				t.Fatalf("watch request failed (%v)", err)
+			}
+		} else if len(resp.Events) == 0 {
+			t.Fatalf("got events %v, want non-empty", resp.Events)
+		} else {
+			wRev, ok := progress[resp.WatchId]
+			if !ok {
+				t.Fatalf("got %+v, but watch id shouldn't exist ", resp)
+			}
+			if resp.Events[0].Kv.ModRevision != wRev {
+				t.Fatalf("got %+v, wanted first revision %d", resp, wRev)
+			}
+			lastRev := resp.Events[len(resp.Events)-1].Kv.ModRevision
+			progress[resp.WatchId] = lastRev + 1
+		}
+		more = false
+		for _, v := range progress {
+			if v <= int64(nrRevisions) {
+				more = true
+				break
+			}
+		}
+	}
+
+	if rok, nr := waitResponse(wStream, time.Second); !rok {
+		t.Errorf("unexpected pb.WatchResponse is received %+v", nr)
+	}
+}
+
 func TestV3WatchMultipleWatchersSynced(t *testing.T) {
 	defer testutil.AfterTest(t)
 	testV3WatchMultipleWatchers(t, 0)
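
Note: the new TestV3WatchCurrentPutOverlap above stresses the fixed race directly:
watchers are created while puts are still in flight, and each must pick up exactly
the events past its creation header. The bookkeeping reduces to one invariant per
watcher, sketched here with a hypothetical helper:

	// advance checks one watch response against the watcher's expected next
	// revision: the first event must match exactly (no skips, no replays),
	// and the cursor then moves past the response's last event.
	func advance(expected, firstModRev, lastModRev int64) (next int64, ok bool) {
		if firstModRev != expected {
			return expected, false
		}
		return lastModRev + 1, true
	}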

storage/watchable_store.go (+68 −58)

@@ -155,6 +155,7 @@ func (s *watchableStore) DeleteRange(key, end []byte) (n, rev int64) {
 		evs[i] = storagepb.Event{
 			Type: storagepb.DELETE,
 			Kv:   &change}
+		evs[i].Kv.ModRevision = rev
 	}
 	s.notify(rev, evs)
 	return n, rev
@@ -177,6 +178,7 @@ func (s *watchableStore) TxnEnd(txnID int64) error {
 		return nil
 	}
 
+	rev := s.store.Rev()
 	evs := make([]storagepb.Event, len(changes))
 	for i, change := range changes {
 		switch change.Value {
@@ -184,6 +186,7 @@ func (s *watchableStore) TxnEnd(txnID int64) error {
 			evs[i] = storagepb.Event{
 				Type: storagepb.DELETE,
 				Kv:   &changes[i]}
+			evs[i].Kv.ModRevision = rev
 		default:
 			evs[i] = storagepb.Event{
 				Type: storagepb.PUT,
@@ -191,7 +194,7 @@ func (s *watchableStore) TxnEnd(txnID int64) error {
 		}
 	}
 
-	s.notify(s.store.Rev(), evs)
+	s.notify(rev, evs)
 	s.mu.Unlock()
 
 	return nil
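
Note: both DeleteRange and TxnEnd now stamp the deletion revision onto delete
events. The serialized KeyValue for a tombstone still carries the ModRevision of
the last put, so without the stamp the new double-notify guard further down would
treat the delete as an old, already-delivered event. A hedged sketch of the fix in
isolation (helper name is illustrative):

	// deleteEvent patches the event's ModRevision to the revision at which
	// the delete happened, not the revision of the last put that the stored
	// KeyValue still carries.
	func deleteEvent(kv *storagepb.KeyValue, deleteRev int64) storagepb.Event {
		ev := storagepb.Event{Type: storagepb.DELETE, Kv: kv}
		ev.Kv.ModRevision = deleteRev
		return ev
	}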
@@ -224,7 +227,16 @@ func (s *watchableStore) watch(key []byte, prefix bool, startRev int64, id Watch
 		ch:     ch,
 	}
 
-	if startRev == 0 {
+	s.store.mu.Lock()
+	synced := startRev > s.store.currentRev.main || startRev == 0
+	if synced {
+		wa.cur = s.store.currentRev.main + 1
+	}
+	s.store.mu.Unlock()
+	if synced {
+		if startRev > wa.cur {
+			panic("can't watch past sync revision")
+		}
 		s.synced.add(wa)
 	} else {
 		slowWatcherGauge.Inc()
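
Note: watch() now decides synced versus unsynced under the store lock and records
the watcher's cursor in the same critical section, so no revision can slip in
between the check and the registration. The rule it applies, sketched with a
hypothetical helper:

	// classify returns the watcher's starting cursor and whether it can join
	// the synced set: a watcher with nothing historical to catch up on
	// starts just past the current revision.
	func classify(startRev, currentRev int64) (cur int64, synced bool) {
		if startRev == 0 || startRev > currentRev {
			return currentRev + 1, true
		}
		return startRev, false
	}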
@@ -284,12 +296,45 @@ func (s *watchableStore) syncWatchers() {
 	// in order to find key-value pairs from unsynced watchers, we need to
 	// find min revision index, and these revisions can be used to
 	// query the backend store of key-value pairs
-	minRev := int64(math.MaxInt64)
+	prefixes, minRev := s.scanUnsync()
+	curRev := s.store.currentRev.main
+	minBytes, maxBytes := newRevBytes(), newRevBytes()
+	revToBytes(revision{main: minRev}, minBytes)
+	revToBytes(revision{main: curRev + 1}, maxBytes)
+
+	// UnsafeRange returns keys and values; in boltdb, keys are revisions and
+	// values are the actual key-value pairs in the backend.
+	tx := s.store.b.BatchTx()
+	tx.Lock()
+	revs, vs := tx.UnsafeRange(keyBucketName, minBytes, maxBytes, 0)
+	evs := kvsToEvents(revs, vs, s.unsynced, prefixes)
+	tx.Unlock()
+
+	for w, es := range newWatcherToEventMap(s.unsynced, evs) {
+		select {
+		// s.store.Rev also acquires the store lock, so read currentRev.main directly
+		case w.ch <- WatchResponse{WatchID: w.id, Events: es, Revision: s.store.currentRev.main}:
+			pendingEventsGauge.Add(float64(len(es)))
+		default:
+			// TODO: handle the full unsynced watchers.
+			// continue to process other watchers for now, the full ones
+			// will be processed next time and hopefully it will not be full.
+			continue
+		}
+		w.cur = curRev
+		s.synced.add(w)
+		s.unsynced.delete(w)
+	}
 
+	slowWatcherGauge.Set(float64(len(s.unsynced)))
+}
+
+func (s *watchableStore) scanUnsync() (prefixes map[string]struct{}, minRev int64) {
 	curRev := s.store.currentRev.main
 	compactionRev := s.store.compactMainRev
 
-	prefixes := make(map[string]struct{})
+	prefixes = make(map[string]struct{})
+	minRev = int64(math.MaxInt64)
 	for _, set := range s.unsynced {
 		for w := range set {
 			k := string(w.key)
@@ -308,7 +353,7 @@ func (s *watchableStore) syncWatchers() {
 				continue
 			}
 
-			if minRev >= w.cur {
+			if minRev > w.cur {
 				minRev = w.cur
 			}
 
@@ -318,60 +363,31 @@ func (s *watchableStore) syncWatchers() {
 		}
 	}
 
-	minBytes, maxBytes := newRevBytes(), newRevBytes()
-	revToBytes(revision{main: minRev}, minBytes)
-	revToBytes(revision{main: curRev + 1}, maxBytes)
-
-	// UnsafeRange returns keys and values. And in boltdb, keys are revisions.
-	// values are actual key-value pairs in backend.
-	tx := s.store.b.BatchTx()
-	tx.Lock()
-	ks, vs := tx.UnsafeRange(keyBucketName, minBytes, maxBytes, 0)
-
-	evs := []storagepb.Event{}
+	return prefixes, minRev
+}
 
-	// get the list of all events from all key-value pairs
-	for i, v := range vs {
+// kvsToEvents gets all events for the watchers from all key-value pairs
+func kvsToEvents(revs, vals [][]byte, wsk watcherSetByKey, pfxs map[string]struct{}) (evs []storagepb.Event) {
+	for i, v := range vals {
 		var kv storagepb.KeyValue
 		if err := kv.Unmarshal(v); err != nil {
 			log.Panicf("storage: cannot unmarshal event: %v", err)
 		}
 
 		k := string(kv.Key)
-		if _, ok := s.unsynced.getSetByKey(k); !ok && !matchPrefix(k, prefixes) {
+		if _, ok := wsk.getSetByKey(k); !ok && !matchPrefix(k, pfxs) {
 			continue
 		}
 
-		var ev storagepb.Event
-		switch {
-		case isTombstone(ks[i]):
-			ev.Type = storagepb.DELETE
-		default:
-			ev.Type = storagepb.PUT
-		}
-		ev.Kv = &kv
-
-		evs = append(evs, ev)
-	}
-	tx.Unlock()
-
-	for w, es := range newWatcherToEventMap(s.unsynced, evs) {
-		select {
-		// s.store.Rev also uses Lock, so just return directly
-		case w.ch <- WatchResponse{WatchID: w.id, Events: es, Revision: s.store.currentRev.main}:
-			pendingEventsGauge.Add(float64(len(es)))
-		default:
-			// TODO: handle the full unsynced watchers.
-			// continue to process other watchers for now, the full ones
-			// will be processed next time and hopefully it will not be full.
-			continue
+		ty := storagepb.PUT
+		if isTombstone(revs[i]) {
+			ty = storagepb.DELETE
+			// patch in mod revision so watchers won't skip
+			kv.ModRevision = bytesToRev(revs[i]).main
 		}
-		w.cur = curRev
-		s.synced.add(w)
-		s.unsynced.delete(w)
+		evs = append(evs, storagepb.Event{Kv: &kv, Type: ty})
 	}
-
-	slowWatcherGauge.Set(float64(len(s.unsynced)))
+	return evs
 }
 
 // notify notifies the fact that given event at the given rev just happened to
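
Note: the refactor splits syncWatchers into a scan phase (scanUnsync), a single
backend range read, and a decode/filter phase (kvsToEvents). Condensed into one
hypothetical function, assuming the identifiers shown in the diff and the
backend.BatchTx interface the store already uses:

	// syncOnce is a hypothetical condensation of the refactored pipeline:
	// scan what the unsynced watchers need, range the backend once under the
	// batch-tx lock, then decode and filter into events.
	func (s *watchableStore) syncOnce() []storagepb.Event {
		prefixes, minRev := s.scanUnsync()
		curRev := s.store.currentRev.main
		minBytes, maxBytes := newRevBytes(), newRevBytes()
		revToBytes(revision{main: minRev}, minBytes)
		revToBytes(revision{main: curRev + 1}, maxBytes)
		tx := s.store.b.BatchTx()
		tx.Lock()
		revs, vals := tx.UnsafeRange(keyBucketName, minBytes, maxBytes, 0)
		evs := kvsToEvents(revs, vals, s.unsynced, prefixes) // decode while tx pins the values
		tx.Unlock()
		return evs
	}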
@@ -426,23 +442,17 @@ func newWatcherToEventMap(sm watcherSetByKey, evs []storagepb.Event) map[*watche
 
 		// check all prefixes of the key to notify all corresponded watchers
 		for i := 0; i <= len(key); i++ {
-			k := string(key[:i])
-
-			wm, ok := sm[k]
-			if !ok {
-				continue
-			}
+			for w := range sm[key[:i]] {
+				// don't double notify
+				if ev.Kv.ModRevision < w.cur {
+					continue
+				}
 
-			for w := range wm {
 				// the watcher needs to be notified when either it watches prefix or
 				// the key is exactly matched.
 				if !w.prefix && i != len(ev.Kv.Key) {
 					continue
 				}
-
-				if _, ok := watcherToEvents[w]; !ok {
-					watcherToEvents[w] = []storagepb.Event{}
-				}
 				watcherToEvents[w] = append(watcherToEvents[w], ev)
 			}
 		}
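
Note: the ModRevision guard in newWatcherToEventMap is what ties the pieces
together: a watcher created as synced records cur = currentRev+1, so any event at
or above cur is owed to the client and anything older has already been delivered.
As a one-line predicate over the watcher type from the diff:

	// shouldNotify reports whether an event is new to the watcher: w.cur is
	// the first revision the watcher still owes its client.
	func shouldNotify(w *watcher, ev storagepb.Event) bool {
		return ev.Kv.ModRevision >= w.cur
	}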