Browse Source

Merge pull request #4614 from heyitsanthony/future-watch-rpc

etcdserver, storage, clientv3: watcher ranges
Anthony Romano 10 years ago
parent
commit
3a9d532140

+ 14 - 0
clientv3/integration/watch_test.go

@@ -157,6 +157,20 @@ func testWatchMultiWatcher(t *testing.T, wctx *watchctx) {
 	}
 }
 
+// TestWatchRange tests watchers created over key ranges
+func TestWatchRange(t *testing.T) {
+	runWatchTest(t, testWatchRange)
+}
+
+func testWatchRange(t *testing.T, wctx *watchctx) {
+	if wctx.ch = wctx.w.Watch(context.TODO(), "a", clientv3.WithRange("c")); wctx.ch == nil {
+		t.Fatalf("expected non-nil channel")
+	}
+	putAndWatch(t, wctx, "a", "a")
+	putAndWatch(t, wctx, "b", "b")
+	putAndWatch(t, wctx, "bar", "bar")
+}
+
 // TestWatchReconnRequest tests the send failure path when requesting a watcher.
 func TestWatchReconnRequest(t *testing.T) {
 	runWatchTest(t, testWatchReconnRequest)

+ 0 - 25
clientv3/op.go

@@ -15,8 +15,6 @@
 package clientv3
 
 import (
-	"reflect"
-
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
 	"github.com/coreos/etcd/lease"
 )
@@ -69,27 +67,6 @@ func (op Op) toRequestUnion() *pb.RequestUnion {
 	}
 }
 
-func (op Op) toWatchRequest() *watchRequest {
-	switch op.t {
-	case tRange:
-		key := string(op.key)
-		prefix := ""
-		if op.end != nil {
-			prefix = key
-			key = ""
-		}
-		wr := &watchRequest{
-			key:    key,
-			prefix: prefix,
-			rev:    op.rev,
-		}
-		return wr
-
-	default:
-		panic("Only for tRange")
-	}
-}
-
 func (op Op) isWrite() bool {
 	return op.t != tRange
 }
@@ -140,8 +117,6 @@ func opWatch(key string, opts ...OpOption) Op {
 	ret := Op{t: tRange, key: []byte(key)}
 	ret.applyOpts(opts)
 	switch {
-	case ret.end != nil && !reflect.DeepEqual(ret.end, getPrefix(ret.key)):
-		panic("only supports single keys or prefixes")
 	case ret.leaseID != 0:
 		panic("unexpected lease in watch")
 	case ret.limit != 0:

+ 15 - 13
clientv3/watch.go

@@ -78,10 +78,10 @@ type watcher struct {
 
 // watchRequest is issued by the subscriber to start a new watcher
 type watchRequest struct {
-	ctx    context.Context
-	key    string
-	prefix string
-	rev    int64
+	ctx context.Context
+	key string
+	end string
+	rev int64
 	// retc receives a chan WatchResponse once the watcher is established
 	retc chan chan WatchResponse
 }
@@ -129,11 +129,14 @@ func NewWatcher(c *Client) Watcher {
 func (w *watcher) Watch(ctx context.Context, key string, opts ...OpOption) WatchChan {
 	ow := opWatch(key, opts...)
 
-	wr := ow.toWatchRequest()
-	wr.ctx = ctx
-
 	retc := make(chan chan WatchResponse, 1)
-	wr.retc = retc
+	wr := &watchRequest{
+		ctx:  ctx,
+		key:  string(ow.key),
+		end:  string(ow.end),
+		rev:  ow.rev,
+		retc: retc,
+	}
 
 	ok := false
 
@@ -502,11 +505,10 @@ func (w *watcher) resumeWatchers(wc pb.Watch_WatchClient) error {
 
 // toPB converts an internal watch request structure to its protobuf message
 func (wr *watchRequest) toPB() *pb.WatchRequest {
-	req := &pb.WatchCreateRequest{StartRevision: wr.rev}
-	if wr.key != "" {
-		req.Key = []byte(wr.key)
-	} else {
-		req.Prefix = []byte(wr.prefix)
+	req := &pb.WatchCreateRequest{
+		StartRevision: wr.rev,
+		Key:           []byte(wr.key),
+		RangeEnd:      []byte(wr.end),
 	}
 	cr := &pb.WatchRequest_CreateRequest{CreateRequest: req}
 	return &pb.WatchRequest{RequestUnion: cr}

+ 26 - 28
etcdserver/api/v3rpc/watch.go

@@ -94,35 +94,33 @@ func (sws *serverWatchStream) recvLoop() error {
 
 		switch uv := req.RequestUnion.(type) {
 		case *pb.WatchRequest_CreateRequest:
-			if uv.CreateRequest != nil {
-				creq := uv.CreateRequest
-				var prefix bool
-				toWatch := creq.Key
-				if len(creq.Key) == 0 {
-					toWatch = creq.Prefix
-					prefix = true
-				}
+			if uv.CreateRequest == nil {
+				break
+			}
 
-				rev := creq.StartRevision
-				wsrev := sws.watchStream.Rev()
-				if rev == 0 {
-					// rev 0 watches past the current revision
-					rev = wsrev + 1
-				} else if rev > wsrev { // do not allow watching future revision.
-					sws.ctrlStream <- &pb.WatchResponse{
-						Header:   sws.newResponseHeader(wsrev),
-						WatchId:  -1,
-						Created:  true,
-						Canceled: true,
-					}
-					continue
-				}
-				id := sws.watchStream.Watch(toWatch, prefix, rev)
-				sws.ctrlStream <- &pb.WatchResponse{
-					Header:  sws.newResponseHeader(wsrev),
-					WatchId: int64(id),
-					Created: true,
-				}
+			creq := uv.CreateRequest
+			if len(creq.RangeEnd) == 1 && creq.RangeEnd[0] == 0 {
+				// support >= key queries
+				creq.RangeEnd = []byte{}
+			}
+
+			rev := creq.StartRevision
+			wsrev := sws.watchStream.Rev()
+			futureRev := rev > wsrev
+			if rev == 0 {
+				// rev 0 watches past the current revision
+				rev = wsrev + 1
+			}
+			// do not allow future watch revision
+			id := storage.WatchID(-1)
+			if !futureRev {
+				id = sws.watchStream.Watch(creq.Key, creq.RangeEnd, rev)
+			}
+			sws.ctrlStream <- &pb.WatchResponse{
+				Header:   sws.newResponseHeader(wsrev),
+				WatchId:  int64(id),
+				Created:  true,
+				Canceled: futureRev,
 			}
 		case *pb.WatchRequest_CancelRequest:
 			if uv.CancelRequest != nil {

+ 13 - 12
etcdserver/etcdserverpb/rpc.pb.go

@@ -870,8 +870,9 @@ func _WatchRequest_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.B
 type WatchCreateRequest struct {
 	// the key to be watched
 	Key []byte `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
-	// the prefix to be watched.
-	Prefix []byte `protobuf:"bytes,2,opt,name=prefix,proto3" json:"prefix,omitempty"`
+	// if the range_end is given, keys in [key, range_end) are watched
+	// NOTE: only range_end == prefixEnd(key) is accepted now
+	RangeEnd []byte `protobuf:"bytes,2,opt,name=range_end,proto3" json:"range_end,omitempty"`
 	// start_revision is an optional revision (including) to watch from. No start_revision is "now".
 	StartRevision int64 `protobuf:"varint,3,opt,name=start_revision,proto3" json:"start_revision,omitempty"`
 }
@@ -2588,12 +2589,12 @@ func (m *WatchCreateRequest) MarshalTo(data []byte) (int, error) {
 			i += copy(data[i:], m.Key)
 		}
 	}
-	if m.Prefix != nil {
-		if len(m.Prefix) > 0 {
+	if m.RangeEnd != nil {
+		if len(m.RangeEnd) > 0 {
 			data[i] = 0x12
 			i++
-			i = encodeVarintRpc(data, i, uint64(len(m.Prefix)))
-			i += copy(data[i:], m.Prefix)
+			i = encodeVarintRpc(data, i, uint64(len(m.RangeEnd)))
+			i += copy(data[i:], m.RangeEnd)
 		}
 	}
 	if m.StartRevision != 0 {
@@ -3592,8 +3593,8 @@ func (m *WatchCreateRequest) Size() (n int) {
 			n += 1 + l + sovRpc(uint64(l))
 		}
 	}
-	if m.Prefix != nil {
-		l = len(m.Prefix)
+	if m.RangeEnd != nil {
+		l = len(m.RangeEnd)
 		if l > 0 {
 			n += 1 + l + sovRpc(uint64(l))
 		}
@@ -6004,7 +6005,7 @@ func (m *WatchCreateRequest) Unmarshal(data []byte) error {
 			iNdEx = postIndex
 		case 2:
 			if wireType != 2 {
-				return fmt.Errorf("proto: wrong wireType = %d for field Prefix", wireType)
+				return fmt.Errorf("proto: wrong wireType = %d for field RangeEnd", wireType)
 			}
 			var byteLen int
 			for shift := uint(0); ; shift += 7 {
@@ -6028,9 +6029,9 @@ func (m *WatchCreateRequest) Unmarshal(data []byte) error {
 			if postIndex > l {
 				return io.ErrUnexpectedEOF
 			}
-			m.Prefix = append(m.Prefix[:0], data[iNdEx:postIndex]...)
-			if m.Prefix == nil {
-				m.Prefix = []byte{}
+			m.RangeEnd = append(m.RangeEnd[:0], data[iNdEx:postIndex]...)
+			if m.RangeEnd == nil {
+				m.RangeEnd = []byte{}
 			}
 			iNdEx = postIndex
 		case 3:

+ 3 - 3
etcdserver/etcdserverpb/rpc.proto

@@ -262,11 +262,11 @@ message WatchRequest {
 message WatchCreateRequest {
   // the key to be watched
   bytes key = 1;
-  // the prefix to be watched.
-  bytes prefix = 2;
+  // if the range_end is given, keys in [key, range_end) are watched
+  // NOTE: only range_end == prefixEnd(key) is accepted now
+  bytes range_end = 2;
   // start_revision is an optional revision (including) to watch from. No start_revision is "now".
   int64 start_revision = 3;
-  // TODO: support Range watch?
 }
 
 message WatchCancelRequest {

+ 15 - 7
integration/v3_watch_test.go

@@ -71,7 +71,8 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
 			[]string{"fooLong"},
 			&pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
 				CreateRequest: &pb.WatchCreateRequest{
-					Prefix: []byte("foo")}}},
+					Key:      []byte("foo"),
+					RangeEnd: []byte("fop")}}},
 
 			[]*pb.WatchResponse{
 				{
@@ -91,7 +92,8 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
 			[]string{"foo"},
 			&pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
 				CreateRequest: &pb.WatchCreateRequest{
-					Prefix: []byte("helloworld")}}},
+					Key:      []byte("helloworld"),
+					RangeEnd: []byte("helloworle")}}},
 
 			[]*pb.WatchResponse{},
 		},
@@ -140,7 +142,8 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
 			[]string{"foo", "foo", "foo"},
 			&pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
 				CreateRequest: &pb.WatchCreateRequest{
-					Prefix: []byte("foo")}}},
+					Key:      []byte("foo"),
+					RangeEnd: []byte("fop")}}},
 
 			[]*pb.WatchResponse{
 				{
@@ -203,6 +206,11 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
 			t.Errorf("#%d: did not create watchid, got +%v", i, cresp)
 			continue
 		}
+		if cresp.Canceled {
+			t.Errorf("#%d: canceled watcher on create %+v", i, cresp)
+			continue
+		}
+
 		createdWatchId := cresp.WatchId
 		if cresp.Header == nil || cresp.Header.Revision != 1 {
 			t.Errorf("#%d: header revision got +%v, wanted revison 1", i, cresp)
@@ -353,7 +361,7 @@ func TestV3WatchCurrentPutOverlap(t *testing.T) {
 	progress := make(map[int64]int64)
 
 	wreq := &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
-		CreateRequest: &pb.WatchCreateRequest{Prefix: []byte("foo")}}}
+		CreateRequest: &pb.WatchCreateRequest{Key: []byte("foo"), RangeEnd: []byte("fop")}}}
 	if err := wStream.Send(wreq); err != nil {
 		t.Fatalf("first watch request failed (%v)", err)
 	}
@@ -437,7 +445,7 @@ func testV3WatchMultipleWatchers(t *testing.T, startRev int64) {
 		} else {
 			wreq = &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
 				CreateRequest: &pb.WatchCreateRequest{
-					Prefix: []byte("fo"), StartRevision: startRev}}}
+					Key: []byte("fo"), RangeEnd: []byte("fp"), StartRevision: startRev}}}
 		}
 		if err := wStream.Send(wreq); err != nil {
 			t.Fatalf("wStream.Send error: %v", err)
@@ -530,7 +538,7 @@ func testV3WatchMultipleEventsTxn(t *testing.T, startRev int64) {
 
 	wreq := &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
 		CreateRequest: &pb.WatchCreateRequest{
-			Prefix: []byte("foo"), StartRevision: startRev}}}
+			Key: []byte("foo"), RangeEnd: []byte("fop"), StartRevision: startRev}}}
 	if err := wStream.Send(wreq); err != nil {
 		t.Fatalf("wStream.Send error: %v", err)
 	}
@@ -623,7 +631,7 @@ func TestV3WatchMultipleEventsPutUnsynced(t *testing.T) {
 
 	wreq := &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
 		CreateRequest: &pb.WatchCreateRequest{
-			Prefix: []byte("foo"), StartRevision: 1}}}
+			Key: []byte("foo"), RangeEnd: []byte("fop"), StartRevision: 1}}}
 	if err := wStream.Send(wreq); err != nil {
 		t.Fatalf("wStream.Send error: %v", err)
 	}

+ 526 - 0
pkg/adt/interval_tree.go

@@ -0,0 +1,526 @@
+// Copyright 2016 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package adt
+
+import (
+	"math"
+)
+
+// Comparable is an interface for trichotomic comparisons.
+type Comparable interface {
+	// Compare gives the result of a 3-way comparison
+	// a.Compare(b) = 1 => a > b
+	// a.Compare(b) = 0 => a == b
+	// a.Compare(b) = -1 => a < b
+	Compare(c Comparable) int
+}
+
+type rbcolor bool
+
+const black = true
+const red = false
+
+// Interval implements a Comparable interval [begin, end)
+// TODO: support different sorts of intervals: (a,b), [a,b], (a, b]
+type Interval struct {
+	Begin Comparable
+	End   Comparable
+}
+
+// Compare on an interval gives == if the interval overlaps.
+func (ivl *Interval) Compare(c Comparable) int {
+	ivl2 := c.(*Interval)
+	ivbCmpBegin := ivl.Begin.Compare(ivl2.Begin)
+	ivbCmpEnd := ivl.Begin.Compare(ivl2.End)
+	iveCmpBegin := ivl.End.Compare(ivl2.Begin)
+
+	// ivl is left of ivl2
+	if ivbCmpBegin < 0 && iveCmpBegin <= 0 {
+		return -1
+	}
+
+	// ivl is right of ivl2
+	if ivbCmpEnd >= 0 {
+		return 1
+	}
+
+	return 0
+}
+
+type intervalNode struct {
+	// iv is the interval-value pair entry.
+	iv IntervalValue
+	// max endpoint of all descendent nodes.
+	max Comparable
+	// left and right are sorted by low endpoint of key interval
+	left, right *intervalNode
+	// parent is the direct ancestor of the node
+	parent *intervalNode
+	c      rbcolor
+}
+
+func (x *intervalNode) color() rbcolor {
+	if x == nil {
+		return black
+	}
+	return x.c
+}
+
+func (n *intervalNode) height() int {
+	if n == nil {
+		return 0
+	}
+	ld := n.left.height()
+	rd := n.right.height()
+	if ld < rd {
+		return rd + 1
+	}
+	return ld + 1
+}
+
+func (x *intervalNode) min() *intervalNode {
+	for x.left != nil {
+		x = x.left
+	}
+	return x
+}
+
+// successor is the next in-order node in the tree
+func (x *intervalNode) successor() *intervalNode {
+	if x.right != nil {
+		return x.right.min()
+	}
+	y := x.parent
+	for y != nil && x == y.right {
+		x = y
+		y = y.parent
+	}
+	return y
+}
+
+// updateMax updates the maximum values for a node and its ancestors
+func (x *intervalNode) updateMax() {
+	for x != nil {
+		oldmax := x.max
+		max := x.iv.Ivl.End
+		if x.left != nil && x.left.max.Compare(max) > 0 {
+			max = x.left.max
+		}
+		if x.right != nil && x.right.max.Compare(max) > 0 {
+			max = x.right.max
+		}
+		if oldmax.Compare(max) == 0 {
+			break
+		}
+		x.max = max
+		x = x.parent
+	}
+}
+
+type nodeVisitor func(n *intervalNode) bool
+
+// visit will call a node visitor on each node that overlaps the given interval
+func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) {
+	if x == nil {
+		return
+	}
+	v := iv.Compare(&x.iv.Ivl)
+	switch {
+	case v < 0:
+		x.left.visit(iv, nv)
+	case v > 0:
+		maxiv := Interval{x.iv.Ivl.Begin, x.max}
+		if maxiv.Compare(iv) == 0 {
+			x.left.visit(iv, nv)
+			x.right.visit(iv, nv)
+		}
+	default:
+		nv(x)
+		x.left.visit(iv, nv)
+		x.right.visit(iv, nv)
+	}
+}
+
+type IntervalValue struct {
+	Ivl Interval
+	Val interface{}
+}
+
+// IntervalTree represents a (mostly) textbook implementation of the
+// "Introduction to Algorithms" (Cormen et al, 2nd ed.) chapter 13 red-black tree
+// and chapter 14.3 interval tree with search supporting "stabbing queries".
+type IntervalTree struct {
+	root  *intervalNode
+	count int
+}
+
+// Delete removes the node with the given interval from the tree, returning
+// true if a node is in fact removed.
+func (ivt *IntervalTree) Delete(ivl Interval) bool {
+	z := ivt.find(ivl)
+	if z == nil {
+		return false
+	}
+
+	y := z
+	if z.left != nil && z.right != nil {
+		y = z.successor()
+	}
+
+	x := y.left
+	if x == nil {
+		x = y.right
+	}
+	if x != nil {
+		x.parent = y.parent
+	}
+
+	if y.parent == nil {
+		ivt.root = x
+	} else {
+		if y == y.parent.left {
+			y.parent.left = x
+		} else {
+			y.parent.right = x
+		}
+		y.parent.updateMax()
+	}
+	if y != z {
+		z.iv = y.iv
+		z.updateMax()
+	}
+
+	if y.color() == black && x != nil {
+		ivt.deleteFixup(x)
+	}
+
+	ivt.count--
+	return true
+}
+
+func (ivt *IntervalTree) deleteFixup(x *intervalNode) {
+	for x != ivt.root && x.color() == black && x.parent != nil {
+		if x == x.parent.left {
+			w := x.parent.right
+			if w.color() == red {
+				w.c = black
+				x.parent.c = red
+				ivt.rotateLeft(x.parent)
+				w = x.parent.right
+			}
+			if w == nil {
+				break
+			}
+			if w.left.color() == black && w.right.color() == black {
+				w.c = red
+				x = x.parent
+			} else {
+				if w.right.color() == black {
+					w.left.c = black
+					w.c = red
+					ivt.rotateRight(w)
+					w = x.parent.right
+				}
+				w.c = x.parent.color()
+				x.parent.c = black
+				w.right.c = black
+				ivt.rotateLeft(x.parent)
+				x = ivt.root
+			}
+		} else {
+			// same as above but with left and right exchanged
+			w := x.parent.left
+			if w.color() == red {
+				w.c = black
+				x.parent.c = red
+				ivt.rotateRight(x.parent)
+				w = x.parent.left
+			}
+			if w == nil {
+				break
+			}
+			if w.left.color() == black && w.right.color() == black {
+				w.c = red
+				x = x.parent
+			} else {
+				if w.left.color() == black {
+					w.right.c = black
+					w.c = red
+					ivt.rotateLeft(w)
+					w = x.parent.left
+				}
+				w.c = x.parent.color()
+				x.parent.c = black
+				w.left.c = black
+				ivt.rotateRight(x.parent)
+				x = ivt.root
+			}
+		}
+	}
+	if x != nil {
+		x.c = black
+	}
+}
+
+// Insert adds a node with the given interval into the tree.
+func (ivt *IntervalTree) Insert(ivl Interval, val interface{}) {
+	var y *intervalNode
+	z := &intervalNode{iv: IntervalValue{ivl, val}, max: ivl.End, c: red}
+	x := ivt.root
+	for x != nil {
+		y = x
+		if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 {
+			x = x.left
+		} else {
+			x = x.right
+		}
+	}
+
+	z.parent = y
+	if y == nil {
+		ivt.root = z
+	} else {
+		if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 {
+			y.left = z
+		} else {
+			y.right = z
+		}
+		y.updateMax()
+	}
+	z.c = red
+	ivt.insertFixup(z)
+	ivt.count++
+}
+
+func (ivt *IntervalTree) insertFixup(z *intervalNode) {
+	for z.parent != nil && z.parent.parent != nil && z.parent.color() == red {
+		if z.parent == z.parent.parent.left {
+			y := z.parent.parent.right
+			if y.color() == red {
+				y.c = black
+				z.parent.c = black
+				z.parent.parent.c = red
+				z = z.parent.parent
+			} else {
+				if z == z.parent.right {
+					z = z.parent
+					ivt.rotateLeft(z)
+				}
+				z.parent.c = black
+				z.parent.parent.c = red
+				ivt.rotateRight(z.parent.parent)
+			}
+		} else {
+			// same as then with left/right exchanged
+			y := z.parent.parent.left
+			if y.color() == red {
+				y.c = black
+				z.parent.c = black
+				z.parent.parent.c = red
+				z = z.parent.parent
+			} else {
+				if z == z.parent.left {
+					z = z.parent
+					ivt.rotateRight(z)
+				}
+				z.parent.c = black
+				z.parent.parent.c = red
+				ivt.rotateLeft(z.parent.parent)
+			}
+		}
+	}
+	ivt.root.c = black
+}
+
+// rotateLeft moves x so it is left of its right child
+func (ivt *IntervalTree) rotateLeft(x *intervalNode) {
+	y := x.right
+	x.right = y.left
+	if y.left != nil {
+		y.left.parent = x
+	}
+	x.updateMax()
+	ivt.replaceParent(x, y)
+	y.left = x
+	y.updateMax()
+}
+
+// rotateRight moves x so it is right of its left child
+func (ivt *IntervalTree) rotateRight(x *intervalNode) {
+	if x == nil {
+		return
+	}
+	y := x.left
+	x.left = y.right
+	if y.right != nil {
+		y.right.parent = x
+	}
+	x.updateMax()
+	ivt.replaceParent(x, y)
+	y.right = x
+	y.updateMax()
+}
+
+// replaceParent replaces x's parent with y
+func (ivt *IntervalTree) replaceParent(x *intervalNode, y *intervalNode) {
+	y.parent = x.parent
+	if x.parent == nil {
+		ivt.root = y
+	} else {
+		if x == x.parent.left {
+			x.parent.left = y
+		} else {
+			x.parent.right = y
+		}
+		x.parent.updateMax()
+	}
+	x.parent = y
+}
+
+// Len gives the number of elements in the tree
+func (ivt *IntervalTree) Len() int { return ivt.count }
+
+// Height is the number of levels in the tree; one node has height 1.
+func (ivt *IntervalTree) Height() int { return ivt.root.height() }
+
+// MaxHeight is the expected maximum tree height given the number of nodes
+func (ivt *IntervalTree) MaxHeight() int {
+	return int((2 * math.Log2(float64(ivt.Len()+1))) + 0.5)
+}
+
+// IntervalVisitor is used on tree searches; return false to stop searching.
+type IntervalVisitor func(n *IntervalValue) bool
+
+// Visit calls a visitor function on every tree node intersecting the given interval.
+func (ivt *IntervalTree) Visit(ivl Interval, ivv IntervalVisitor) {
+	ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) })
+}
+
+// find the exact node for a given interval
+func (ivt *IntervalTree) find(ivl Interval) (ret *intervalNode) {
+	f := func(n *intervalNode) bool {
+		if n.iv.Ivl != ivl {
+			return true
+		}
+		ret = n
+		return false
+	}
+	ivt.root.visit(&ivl, f)
+	return ret
+}
+
+// Find gets the IntervalValue for the node matching the given interval
+func (ivt *IntervalTree) Find(ivl Interval) (ret *IntervalValue) {
+	n := ivt.find(ivl)
+	if n == nil {
+		return nil
+	}
+	return &n.iv
+}
+
+// Contains returns true if there is some tree node intersecting the given interval.
+func (ivt *IntervalTree) Contains(iv Interval) bool {
+	x := ivt.root
+	for x != nil && iv.Compare(&x.iv.Ivl) != 0 {
+		if x.left != nil && x.left.max.Compare(iv.Begin) > 0 {
+			x = x.left
+		} else {
+			x = x.right
+		}
+	}
+	return x != nil
+}
+
+// Stab returns a slice with all elements in the tree intersecting the interval.
+func (ivt *IntervalTree) Stab(iv Interval) (ivs []*IntervalValue) {
+	f := func(n *IntervalValue) bool { ivs = append(ivs, n); return true }
+	ivt.Visit(iv, f)
+	return ivs
+}
+
+type StringComparable string
+
+func (s StringComparable) Compare(c Comparable) int {
+	sc := c.(StringComparable)
+	if s < sc {
+		return -1
+	}
+	if s > sc {
+		return 1
+	}
+	return 0
+}
+
+func NewStringInterval(begin, end string) Interval {
+	return Interval{StringComparable(begin), StringComparable(end)}
+}
+
+func NewStringPoint(s string) Interval {
+	return Interval{StringComparable(s), StringComparable(s + "\x00")}
+}
+
+// StringAffineComparable treats "" as > all other strings
+type StringAffineComparable string
+
+func (s StringAffineComparable) Compare(c Comparable) int {
+	sc := c.(StringAffineComparable)
+
+	if len(s) == 0 {
+		if len(sc) == 0 {
+			return 0
+		}
+		return 1
+	}
+	if len(sc) == 0 {
+		return -1
+	}
+
+	if s < sc {
+		return -1
+	}
+	if s > sc {
+		return 1
+	}
+	return 0
+}
+
+func NewStringAffineInterval(begin, end string) Interval {
+	return Interval{StringAffineComparable(begin), StringAffineComparable(end)}
+}
+func NewStringAffinePoint(s string) Interval {
+	return NewStringAffineInterval(s, s+"\x00")
+}
+
+func NewInt64Interval(a int64, b int64) Interval {
+	return Interval{Int64Comparable(a), Int64Comparable(b)}
+}
+
+func NewInt64Point(a int64) Interval {
+	return Interval{Int64Comparable(a), Int64Comparable(a + 1)}
+}
+
+type Int64Comparable int64
+
+func (v Int64Comparable) Compare(c Comparable) int {
+	vc := c.(Int64Comparable)
+	cmp := v - vc
+	if cmp < 0 {
+		return -1
+	}
+	if cmp > 0 {
+		return 1
+	}
+	return 0
+}

+ 138 - 0
pkg/adt/interval_tree_test.go

@@ -0,0 +1,138 @@
+// Copyright 2016 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package adt
+
+import (
+	"math/rand"
+	"testing"
+	"time"
+)
+
+func TestIntervalTreeContains(t *testing.T) {
+	ivt := &IntervalTree{}
+	ivt.Insert(NewStringInterval("1", "3"), 123)
+
+	if ivt.Contains(NewStringPoint("0")) {
+		t.Errorf("contains 0")
+	}
+	if !ivt.Contains(NewStringPoint("1")) {
+		t.Errorf("missing 1")
+	}
+	if !ivt.Contains(NewStringPoint("11")) {
+		t.Errorf("missing 11")
+	}
+	if !ivt.Contains(NewStringPoint("2")) {
+		t.Errorf("missing 2")
+	}
+	if ivt.Contains(NewStringPoint("3")) {
+		t.Errorf("contains 3")
+	}
+}
+
+func TestIntervalTreeStringAffine(t *testing.T) {
+	ivt := &IntervalTree{}
+	ivt.Insert(NewStringAffineInterval("8", ""), 123)
+	if !ivt.Contains(NewStringAffinePoint("9")) {
+		t.Errorf("missing 9")
+	}
+	if ivt.Contains(NewStringAffinePoint("7")) {
+		t.Errorf("contains 7")
+	}
+}
+
+func TestIntervalTreeStab(t *testing.T) {
+	ivt := &IntervalTree{}
+	ivt.Insert(NewStringInterval("0", "1"), 123)
+	ivt.Insert(NewStringInterval("0", "2"), 456)
+	ivt.Insert(NewStringInterval("5", "6"), 789)
+	ivt.Insert(NewStringInterval("6", "8"), 999)
+	ivt.Insert(NewStringInterval("0", "3"), 0)
+
+	if ivt.root.max.Compare(StringComparable("8")) != 0 {
+		t.Fatalf("wrong root max got %v, expected 8", ivt.root.max)
+	}
+	if x := len(ivt.Stab(NewStringPoint("0"))); x != 3 {
+		t.Errorf("got %d, expected 3", x)
+	}
+	if x := len(ivt.Stab(NewStringPoint("1"))); x != 2 {
+		t.Errorf("got %d, expected 2", x)
+	}
+	if x := len(ivt.Stab(NewStringPoint("2"))); x != 1 {
+		t.Errorf("got %d, expected 1", x)
+	}
+	if x := len(ivt.Stab(NewStringPoint("3"))); x != 0 {
+		t.Errorf("got %d, expected 0", x)
+	}
+	if x := len(ivt.Stab(NewStringPoint("5"))); x != 1 {
+		t.Errorf("got %d, expected 1", x)
+	}
+	if x := len(ivt.Stab(NewStringPoint("55"))); x != 1 {
+		t.Errorf("got %d, expected 1", x)
+	}
+	if x := len(ivt.Stab(NewStringPoint("6"))); x != 1 {
+		t.Errorf("got %d, expected 1", x)
+	}
+}
+
+type xy struct {
+	x int64
+	y int64
+}
+
+func TestIntervalTreeRandom(t *testing.T) {
+	// generate unique intervals
+	ivs := make(map[xy]struct{})
+	ivt := &IntervalTree{}
+	maxv := 128
+	rand.Seed(time.Now().UnixNano())
+
+	for i := rand.Intn(maxv) + 1; i != 0; i-- {
+		x, y := int64(rand.Intn(maxv)), int64(rand.Intn(maxv))
+		if x > y {
+			t := x
+			x = y
+			y = t
+		} else if x == y {
+			y++
+		}
+		iv := xy{x, y}
+		if _, ok := ivs[iv]; ok {
+			// don't double insert
+			continue
+		}
+		ivt.Insert(NewInt64Interval(x, y), 123)
+		ivs[iv] = struct{}{}
+	}
+
+	for ab := range ivs {
+		for xy := range ivs {
+			v := xy.x + int64(rand.Intn(int(xy.y-xy.x)))
+			if slen := len(ivt.Stab(NewInt64Point(v))); slen == 0 {
+				t.Fatalf("expected %v stab non-zero for [%+v)", v, xy)
+			}
+			if !ivt.Contains(NewInt64Point(v)) {
+				t.Fatalf("did not get %d as expected for [%+v)", v, xy)
+			}
+		}
+		if !ivt.Delete(NewInt64Interval(ab.x, ab.y)) {
+			t.Errorf("did not delete %v as expected", ab)
+		}
+		delete(ivs, ab)
+	}
+
+	if ivt.Len() != 0 {
+		t.Errorf("got ivt.Len() = %v, expected 0", ivt.Len())
+	}
+}

+ 40 - 49
storage/kv_test.go

@@ -722,13 +722,10 @@ func TestWatchableKVWatch(t *testing.T) {
 	w := s.NewWatchStream()
 	defer w.Close()
 
-	wid := w.Watch([]byte("foo"), true, 0)
+	wid := w.Watch([]byte("foo"), []byte("fop"), 0)
 
-	s.Put([]byte("foo"), []byte("bar"), 1)
-	select {
-	case resp := <-w.Chan():
-		wev := storagepb.Event{
-			Type: storagepb.PUT,
+	wev := []storagepb.Event{
+		{Type: storagepb.PUT,
 			Kv: &storagepb.KeyValue{
 				Key:            []byte("foo"),
 				Value:          []byte("bar"),
@@ -737,13 +734,40 @@ func TestWatchableKVWatch(t *testing.T) {
 				Version:        1,
 				Lease:          1,
 			},
-		}
+		},
+		{
+			Type: storagepb.PUT,
+			Kv: &storagepb.KeyValue{
+				Key:            []byte("foo1"),
+				Value:          []byte("bar1"),
+				CreateRevision: 3,
+				ModRevision:    3,
+				Version:        1,
+				Lease:          2,
+			},
+		},
+		{
+			Type: storagepb.PUT,
+			Kv: &storagepb.KeyValue{
+				Key:            []byte("foo1"),
+				Value:          []byte("bar11"),
+				CreateRevision: 3,
+				ModRevision:    4,
+				Version:        2,
+				Lease:          3,
+			},
+		},
+	}
+
+	s.Put([]byte("foo"), []byte("bar"), 1)
+	select {
+	case resp := <-w.Chan():
 		if resp.WatchID != wid {
 			t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
 		}
 		ev := resp.Events[0]
-		if !reflect.DeepEqual(ev, wev) {
-			t.Errorf("watched event = %+v, want %+v", ev, wev)
+		if !reflect.DeepEqual(ev, wev[0]) {
+			t.Errorf("watched event = %+v, want %+v", ev, wev[0])
 		}
 	case <-time.After(5 * time.Second):
 		// CPU might be too slow, and the routine is not able to switch around
@@ -753,50 +777,28 @@ func TestWatchableKVWatch(t *testing.T) {
 	s.Put([]byte("foo1"), []byte("bar1"), 2)
 	select {
 	case resp := <-w.Chan():
-		wev := storagepb.Event{
-			Type: storagepb.PUT,
-			Kv: &storagepb.KeyValue{
-				Key:            []byte("foo1"),
-				Value:          []byte("bar1"),
-				CreateRevision: 3,
-				ModRevision:    3,
-				Version:        1,
-				Lease:          2,
-			},
-		}
 		if resp.WatchID != wid {
 			t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
 		}
 		ev := resp.Events[0]
-		if !reflect.DeepEqual(ev, wev) {
-			t.Errorf("watched event = %+v, want %+v", ev, wev)
+		if !reflect.DeepEqual(ev, wev[1]) {
+			t.Errorf("watched event = %+v, want %+v", ev, wev[1])
 		}
 	case <-time.After(5 * time.Second):
 		testutil.FatalStack(t, "failed to watch the event")
 	}
 
 	w = s.NewWatchStream()
-	wid = w.Watch([]byte("foo1"), false, 1)
+	wid = w.Watch([]byte("foo1"), []byte("foo2"), 3)
 
 	select {
 	case resp := <-w.Chan():
-		wev := storagepb.Event{
-			Type: storagepb.PUT,
-			Kv: &storagepb.KeyValue{
-				Key:            []byte("foo1"),
-				Value:          []byte("bar1"),
-				CreateRevision: 3,
-				ModRevision:    3,
-				Version:        1,
-				Lease:          2,
-			},
-		}
 		if resp.WatchID != wid {
 			t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
 		}
 		ev := resp.Events[0]
-		if !reflect.DeepEqual(ev, wev) {
-			t.Errorf("watched event = %+v, want %+v", ev, wev)
+		if !reflect.DeepEqual(ev, wev[1]) {
+			t.Errorf("watched event = %+v, want %+v", ev, wev[1])
 		}
 	case <-time.After(5 * time.Second):
 		testutil.FatalStack(t, "failed to watch the event")
@@ -805,23 +807,12 @@ func TestWatchableKVWatch(t *testing.T) {
 	s.Put([]byte("foo1"), []byte("bar11"), 3)
 	select {
 	case resp := <-w.Chan():
-		wev := storagepb.Event{
-			Type: storagepb.PUT,
-			Kv: &storagepb.KeyValue{
-				Key:            []byte("foo1"),
-				Value:          []byte("bar11"),
-				CreateRevision: 3,
-				ModRevision:    4,
-				Version:        2,
-				Lease:          3,
-			},
-		}
 		if resp.WatchID != wid {
 			t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
 		}
 		ev := resp.Events[0]
-		if !reflect.DeepEqual(ev, wev) {
-			t.Errorf("watched event = %+v, want %+v", ev, wev)
+		if !reflect.DeepEqual(ev, wev[2]) {
+			t.Errorf("watched event = %+v, want %+v", ev, wev[2])
 		}
 	case <-time.After(5 * time.Second):
 		testutil.FatalStack(t, "failed to watch the event")

+ 35 - 215
storage/watchable_store.go

@@ -16,8 +16,6 @@ package storage
 
 import (
 	"log"
-	"math"
-	"strings"
 	"sync"
 	"time"
 
@@ -34,103 +32,8 @@ const (
 	chanBufLen = 1024
 )
 
-var (
-	// watchBatchMaxRevs is the maximum distinct revisions that
-	// may be sent to an unsynced watcher at a time. Declared as
-	// var instead of const for testing purposes.
-	watchBatchMaxRevs = 1000
-)
-
-type eventBatch struct {
-	// evs is a batch of revision-ordered events
-	evs []storagepb.Event
-	// revs is the minimum unique revisions observed for this batch
-	revs int
-	// moreRev is first revision with more events following this batch
-	moreRev int64
-}
-
-type (
-	watcherSetByKey map[string]watcherSet
-	watcherSet      map[*watcher]struct{}
-	watcherBatch    map[*watcher]*eventBatch
-)
-
-func (eb *eventBatch) add(ev storagepb.Event) {
-	if eb.revs > watchBatchMaxRevs {
-		// maxed out batch size
-		return
-	}
-
-	if len(eb.evs) == 0 {
-		// base case
-		eb.revs = 1
-		eb.evs = append(eb.evs, ev)
-		return
-	}
-
-	// revision accounting
-	ebRev := eb.evs[len(eb.evs)-1].Kv.ModRevision
-	evRev := ev.Kv.ModRevision
-	if evRev > ebRev {
-		eb.revs++
-		if eb.revs > watchBatchMaxRevs {
-			eb.moreRev = evRev
-			return
-		}
-	}
-
-	eb.evs = append(eb.evs, ev)
-}
-
-func (wb watcherBatch) add(w *watcher, ev storagepb.Event) {
-	eb := wb[w]
-	if eb == nil {
-		eb = &eventBatch{}
-		wb[w] = eb
-	}
-	eb.add(ev)
-}
-
-func (w watcherSet) add(wa *watcher) {
-	if _, ok := w[wa]; ok {
-		panic("add watcher twice!")
-	}
-	w[wa] = struct{}{}
-}
-
-func (w watcherSetByKey) add(wa *watcher) {
-	set := w[string(wa.key)]
-	if set == nil {
-		set = make(watcherSet)
-		w[string(wa.key)] = set
-	}
-	set.add(wa)
-}
-
-func (w watcherSetByKey) getSetByKey(key string) (watcherSet, bool) {
-	set, ok := w[key]
-	return set, ok
-}
-
-func (w watcherSetByKey) delete(wa *watcher) bool {
-	k := string(wa.key)
-	if v, ok := w[k]; ok {
-		if _, ok := v[wa]; ok {
-			delete(v, wa)
-			// if there is nothing in the set,
-			// remove the set
-			if len(v) == 0 {
-				delete(w, k)
-			}
-			return true
-		}
-	}
-	return false
-}
-
 type watchable interface {
-	watch(key []byte, prefix bool, startRev int64, id WatchID, ch chan<- WatchResponse) (*watcher, cancelFunc)
+	watch(key, end []byte, startRev int64, id WatchID, ch chan<- WatchResponse) (*watcher, cancelFunc)
 	rev() int64
 }
 
@@ -140,11 +43,11 @@ type watchableStore struct {
 	*store
 
 	// contains all unsynced watchers that needs to sync with events that have happened
-	unsynced watcherSetByKey
+	unsynced watcherGroup
 
 	// contains all synced watchers that are in sync with the progress of the store.
 	// The key of the map is the key that the watcher watches on.
-	synced watcherSetByKey
+	synced watcherGroup
 
 	stopc chan struct{}
 	wg    sync.WaitGroup
@@ -157,8 +60,8 @@ type cancelFunc func()
 func newWatchableStore(b backend.Backend, le lease.Lessor) *watchableStore {
 	s := &watchableStore{
 		store:    NewStore(b, le),
-		unsynced: make(watcherSetByKey),
-		synced:   make(watcherSetByKey),
+		unsynced: newWatcherGroup(),
+		synced:   newWatcherGroup(),
 		stopc:    make(chan struct{}),
 	}
 	if s.le != nil {
@@ -268,16 +171,16 @@ func (s *watchableStore) NewWatchStream() WatchStream {
 	}
 }
 
-func (s *watchableStore) watch(key []byte, prefix bool, startRev int64, id WatchID, ch chan<- WatchResponse) (*watcher, cancelFunc) {
+func (s *watchableStore) watch(key, end []byte, startRev int64, id WatchID, ch chan<- WatchResponse) (*watcher, cancelFunc) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
 	wa := &watcher{
-		key:    key,
-		prefix: prefix,
-		cur:    startRev,
-		id:     id,
-		ch:     ch,
+		key: key,
+		end: end,
+		cur: startRev,
+		id:  id,
+		ch:  ch,
 	}
 
 	s.store.mu.Lock()
@@ -342,15 +245,16 @@ func (s *watchableStore) syncWatchers() {
 	s.store.mu.Lock()
 	defer s.store.mu.Unlock()
 
-	if len(s.unsynced) == 0 {
+	if s.unsynced.size() == 0 {
 		return
 	}
 
 	// in order to find key-value pairs from unsynced watchers, we need to
 	// find min revision index, and these revisions can be used to
 	// query the backend store of key-value pairs
-	prefixes, minRev := s.scanUnsync()
 	curRev := s.store.currentRev.main
+	compactionRev := s.store.compactMainRev
+	minRev := s.unsynced.scanMinRev(curRev, compactionRev)
 	minBytes, maxBytes := newRevBytes(), newRevBytes()
 	revToBytes(revision{main: minRev}, minBytes)
 	revToBytes(revision{main: curRev + 1}, maxBytes)
@@ -360,10 +264,10 @@ func (s *watchableStore) syncWatchers() {
 	tx := s.store.b.BatchTx()
 	tx.Lock()
 	revs, vs := tx.UnsafeRange(keyBucketName, minBytes, maxBytes, 0)
-	evs := kvsToEvents(revs, vs, s.unsynced, prefixes)
+	evs := kvsToEvents(&s.unsynced, revs, vs)
 	tx.Unlock()
 
-	for w, eb := range newWatcherBatch(s.unsynced, evs) {
+	for w, eb := range newWatcherBatch(&s.unsynced, evs) {
 		select {
 		// s.store.Rev also uses Lock, so just return directly
 		case w.ch <- WatchResponse{WatchID: w.id, Events: eb.evs, Revision: s.store.currentRev.main}:
@@ -383,56 +287,18 @@ func (s *watchableStore) syncWatchers() {
 		s.unsynced.delete(w)
 	}
 
-	slowWatcherGauge.Set(float64(len(s.unsynced)))
-}
-
-func (s *watchableStore) scanUnsync() (prefixes map[string]struct{}, minRev int64) {
-	curRev := s.store.currentRev.main
-	compactionRev := s.store.compactMainRev
-
-	prefixes = make(map[string]struct{})
-	minRev = int64(math.MaxInt64)
-	for _, set := range s.unsynced {
-		for w := range set {
-			k := string(w.key)
-
-			if w.cur > curRev {
-				panic("watcher current revision should not exceed current revision")
-			}
-
-			if w.cur < compactionRev {
-				select {
-				case w.ch <- WatchResponse{WatchID: w.id, CompactRevision: compactionRev}:
-					s.unsynced.delete(w)
-				default:
-					// retry next time
-				}
-				continue
-			}
-
-			if minRev > w.cur {
-				minRev = w.cur
-			}
-
-			if w.prefix {
-				prefixes[k] = struct{}{}
-			}
-		}
-	}
-
-	return prefixes, minRev
+	slowWatcherGauge.Set(float64(s.unsynced.size()))
 }
 
 // kvsToEvents gets all events for the watchers from all key-value pairs
-func kvsToEvents(revs, vals [][]byte, wsk watcherSetByKey, pfxs map[string]struct{}) (evs []storagepb.Event) {
+func kvsToEvents(wg *watcherGroup, revs, vals [][]byte) (evs []storagepb.Event) {
 	for i, v := range vals {
 		var kv storagepb.KeyValue
 		if err := kv.Unmarshal(v); err != nil {
 			log.Panicf("storage: cannot unmarshal event: %v", err)
 		}
 
-		k := string(kv.Key)
-		if _, ok := wsk.getSetByKey(k); !ok && !matchPrefix(k, pfxs) {
+		if !wg.contains(string(kv.Key)) {
 			continue
 		}
 
@@ -450,26 +316,19 @@ func kvsToEvents(revs, vals [][]byte, wsk watcherSetByKey, pfxs map[string]struc
 // notify notifies the fact that given event at the given rev just happened to
 // watchers that watch on the key of the event.
 func (s *watchableStore) notify(rev int64, evs []storagepb.Event) {
-	we := newWatcherBatch(s.synced, evs)
-	for _, wm := range s.synced {
-		for w := range wm {
-			eb, ok := we[w]
-			if !ok {
-				continue
-			}
-			if eb.revs != 1 {
-				panic("unexpected multiple revisions in notification")
-			}
-			select {
-			case w.ch <- WatchResponse{WatchID: w.id, Events: eb.evs, Revision: s.Rev()}:
-				pendingEventsGauge.Add(float64(len(eb.evs)))
-			default:
-				// move slow watcher to unsynced
-				w.cur = rev
-				s.unsynced.add(w)
-				delete(wm, w)
-				slowWatcherGauge.Inc()
-			}
+	for w, eb := range newWatcherBatch(&s.synced, evs) {
+		if eb.revs != 1 {
+			panic("unexpected multiple revisions in notification")
+		}
+		select {
+		case w.ch <- WatchResponse{WatchID: w.id, Events: eb.evs, Revision: s.Rev()}:
+			pendingEventsGauge.Add(float64(len(eb.evs)))
+		default:
+			// move slow watcher to unsynced
+			w.cur = rev
+			s.unsynced.add(w)
+			s.synced.delete(w)
+			slowWatcherGauge.Inc()
 		}
 	}
 }
@@ -479,9 +338,9 @@ func (s *watchableStore) rev() int64 { return s.store.Rev() }
 type watcher struct {
 	// the watcher key
 	key []byte
-	// prefix indicates if watcher is on a key or a prefix.
-	// If prefix is true, the watcher is on a prefix.
-	prefix bool
+	// end indicates the end of the range to watch.
+	// If end is set, the watcher is on a range.
+	end []byte
 	// cur is the current watcher revision.
 	// If cur is behind the current revision of the KV,
 	// watcher is unsynced and needs to catch up.
@@ -492,42 +351,3 @@ type watcher struct {
 	// The chan might be shared with other watchers.
 	ch chan<- WatchResponse
 }
-
-// newWatcherBatch maps watchers to their matched events. It enables quick
-// events look up by watcher.
-func newWatcherBatch(sm watcherSetByKey, evs []storagepb.Event) watcherBatch {
-	wb := make(watcherBatch)
-	for _, ev := range evs {
-		key := string(ev.Kv.Key)
-
-		// check all prefixes of the key to notify all corresponded watchers
-		for i := 0; i <= len(key); i++ {
-			for w := range sm[key[:i]] {
-				// don't double notify
-				if ev.Kv.ModRevision < w.cur {
-					continue
-				}
-
-				// the watcher needs to be notified when either it watches prefix or
-				// the key is exactly matched.
-				if !w.prefix && i != len(ev.Kv.Key) {
-					continue
-				}
-				wb.add(w, ev)
-			}
-		}
-	}
-
-	return wb
-}
-
-// matchPrefix returns true if key has any matching prefix
-// from prefixes map.
-func matchPrefix(key string, prefixes map[string]struct{}) bool {
-	for p := range prefixes {
-		if strings.HasPrefix(key, p) {
-			return true
-		}
-	}
-	return false
-}

+ 4 - 4
storage/watchable_store_bench_test.go

@@ -40,11 +40,11 @@ func BenchmarkWatchableStoreUnsyncedCancel(b *testing.B) {
 	// in unsynced for this benchmark.
 	ws := &watchableStore{
 		store:    s,
-		unsynced: make(watcherSetByKey),
+		unsynced: newWatcherGroup(),
 
 		// to make the test not crash from assigning to nil map.
 		// 'synced' doesn't get populated in this test.
-		synced: make(watcherSetByKey),
+		synced: newWatcherGroup(),
 	}
 
 	defer func() {
@@ -69,7 +69,7 @@ func BenchmarkWatchableStoreUnsyncedCancel(b *testing.B) {
 	watchIDs := make([]WatchID, watcherN)
 	for i := 0; i < watcherN; i++ {
 		// non-0 value to keep watchers in unsynced
-		watchIDs[i] = w.Watch(testKey, true, 1)
+		watchIDs[i] = w.Watch(testKey, nil, 1)
 	}
 
 	// random-cancel N watchers to make it not biased towards
@@ -109,7 +109,7 @@ func BenchmarkWatchableStoreSyncedCancel(b *testing.B) {
 	watchIDs := make([]WatchID, watcherN)
 	for i := 0; i < watcherN; i++ {
 		// 0 for startRev to keep watchers in synced
-		watchIDs[i] = w.Watch(testKey, true, 0)
+		watchIDs[i] = w.Watch(testKey, nil, 0)
 	}
 
 	// randomly cancel watchers to make it not biased towards

+ 34 - 39
storage/watchable_store_test.go

@@ -40,11 +40,11 @@ func TestWatch(t *testing.T) {
 	s.Put(testKey, testValue, lease.NoLease)
 
 	w := s.NewWatchStream()
-	w.Watch(testKey, true, 0)
+	w.Watch(testKey, nil, 0)
 
-	if _, ok := s.synced[string(testKey)]; !ok {
+	if !s.synced.contains(string(testKey)) {
 		// the key must have had an entry in synced
-		t.Errorf("existence = %v, want true", ok)
+		t.Errorf("existence = false, want true")
 	}
 }
 
@@ -61,15 +61,15 @@ func TestNewWatcherCancel(t *testing.T) {
 	s.Put(testKey, testValue, lease.NoLease)
 
 	w := s.NewWatchStream()
-	wt := w.Watch(testKey, true, 0)
+	wt := w.Watch(testKey, nil, 0)
 
 	if err := w.Cancel(wt); err != nil {
 		t.Error(err)
 	}
 
-	if _, ok := s.synced[string(testKey)]; ok {
+	if s.synced.contains(string(testKey)) {
 	// the key should have been deleted
-		t.Errorf("existence = %v, want false", ok)
+		t.Errorf("existence = true, want false")
 	}
 }
 
@@ -83,11 +83,11 @@ func TestCancelUnsynced(t *testing.T) {
 	// in unsynced to test if syncWatchers works as expected.
 	s := &watchableStore{
 		store:    NewStore(b, &lease.FakeLessor{}),
-		unsynced: make(watcherSetByKey),
+		unsynced: newWatcherGroup(),
 
 		// to make the test not crash from assigning to nil map.
 		// 'synced' doesn't get populated in this test.
-		synced: make(watcherSetByKey),
+		synced: newWatcherGroup(),
 	}
 
 	defer func() {
@@ -112,7 +112,7 @@ func TestCancelUnsynced(t *testing.T) {
 	watchIDs := make([]WatchID, watcherN)
 	for i := 0; i < watcherN; i++ {
 		// use 1 to keep watchers in unsynced
-		watchIDs[i] = w.Watch(testKey, true, 1)
+		watchIDs[i] = w.Watch(testKey, nil, 1)
 	}
 
 	for _, idx := range watchIDs {
@@ -125,8 +125,8 @@ func TestCancelUnsynced(t *testing.T) {
 	//
 	// unsynced should be empty
 	// because cancel removes watcher from unsynced
-	if len(s.unsynced) != 0 {
-		t.Errorf("unsynced size = %d, want 0", len(s.unsynced))
+	if size := s.unsynced.size(); size != 0 {
+		t.Errorf("unsynced size = %d, want 0", size)
 	}
 }
 
@@ -138,8 +138,8 @@ func TestSyncWatchers(t *testing.T) {
 
 	s := &watchableStore{
 		store:    NewStore(b, &lease.FakeLessor{}),
-		unsynced: make(watcherSetByKey),
-		synced:   make(watcherSetByKey),
+		unsynced: newWatcherGroup(),
+		synced:   newWatcherGroup(),
 	}
 
 	defer func() {
@@ -158,13 +158,13 @@ func TestSyncWatchers(t *testing.T) {
 
 	for i := 0; i < watcherN; i++ {
 		// specify rev as 1 to keep watchers in unsynced
-		w.Watch(testKey, true, 1)
+		w.Watch(testKey, nil, 1)
 	}
 
 	// Before running s.syncWatchers() synced should be empty because we manually
 	// populate unsynced only
-	sws, _ := s.synced.getSetByKey(string(testKey))
-	uws, _ := s.unsynced.getSetByKey(string(testKey))
+	sws := s.synced.watcherSetByKey(string(testKey))
+	uws := s.unsynced.watcherSetByKey(string(testKey))
 
 	if len(sws) != 0 {
 		t.Fatalf("synced[string(testKey)] size = %d, want 0", len(sws))
@@ -177,8 +177,8 @@ func TestSyncWatchers(t *testing.T) {
 	// this should move all unsynced watchers to synced ones
 	s.syncWatchers()
 
-	sws, _ = s.synced.getSetByKey(string(testKey))
-	uws, _ = s.unsynced.getSetByKey(string(testKey))
+	sws = s.synced.watcherSetByKey(string(testKey))
+	uws = s.unsynced.watcherSetByKey(string(testKey))
 
 	// After running s.syncWatchers(), synced should not be empty because syncwatchers
 	// populates synced in this test case
@@ -240,7 +240,7 @@ func TestWatchCompacted(t *testing.T) {
 	}
 
 	w := s.NewWatchStream()
-	wt := w.Watch(testKey, true, compactRev-1)
+	wt := w.Watch(testKey, nil, compactRev-1)
 
 	select {
 	case resp := <-w.Chan():
@@ -275,7 +275,7 @@ func TestWatchBatchUnsynced(t *testing.T) {
 	}
 
 	w := s.NewWatchStream()
-	w.Watch(v, false, 1)
+	w.Watch(v, nil, 1)
 	for i := 0; i < batches; i++ {
 		if resp := <-w.Chan(); len(resp.Events) != watchBatchMaxRevs {
 			t.Fatalf("len(events) = %d, want %d", len(resp.Events), watchBatchMaxRevs)
@@ -284,8 +284,8 @@ func TestWatchBatchUnsynced(t *testing.T) {
 
 	s.store.mu.Lock()
 	defer s.store.mu.Unlock()
-	if len(s.synced) != 1 {
-		t.Errorf("synced size = %d, want 1", len(s.synced))
+	if size := s.synced.size(); size != 1 {
+		t.Errorf("synced size = %d, want 1", size)
 	}
 }
 
@@ -311,14 +311,14 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
 	}
 
 	tests := []struct {
-		sync watcherSetByKey
+		sync []*watcher
 		evs  []storagepb.Event
 
 		wwe map[*watcher][]storagepb.Event
 	}{
 		// no watcher in sync, some events should return empty wwe
 		{
-			watcherSetByKey{},
+			nil,
 			evs,
 			map[*watcher][]storagepb.Event{},
 		},
@@ -326,9 +326,7 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
 		// one watcher in sync, one event that does not match the key of that
 		// watcher should return empty wwe
 		{
-			watcherSetByKey{
-				string(k2): {ws[2]: struct{}{}},
-			},
+			[]*watcher{ws[2]},
 			evs[:1],
 			map[*watcher][]storagepb.Event{},
 		},
@@ -336,9 +334,7 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
 		// one watcher in sync, one event that matches the key of that
 		// watcher should return wwe with that matching watcher
 		{
-			watcherSetByKey{
-				string(k1): {ws[1]: struct{}{}},
-			},
+			[]*watcher{ws[1]},
 			evs[1:2],
 			map[*watcher][]storagepb.Event{
 				ws[1]: evs[1:2],
@@ -349,10 +345,7 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
 		// that matches the key of only one of the watcher should return wwe
 		// with the matching watcher
 		{
-			watcherSetByKey{
-				string(k0): {ws[0]: struct{}{}},
-				string(k2): {ws[2]: struct{}{}},
-			},
+			[]*watcher{ws[0], ws[2]},
 			evs[2:],
 			map[*watcher][]storagepb.Event{
 				ws[2]: evs[2:],
@@ -362,10 +355,7 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
 		// two watchers in sync that watches the same key, two events that
 		// match the keys should return wwe with those two watchers
 		{
-			watcherSetByKey{
-				string(k0): {ws[0]: struct{}{}},
-				string(k1): {ws[1]: struct{}{}},
-			},
+			[]*watcher{ws[0], ws[1]},
 			evs[:2],
 			map[*watcher][]storagepb.Event{
 				ws[0]: evs[:1],
@@ -375,7 +365,12 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
 	}
 
 	for i, tt := range tests {
-		gwe := newWatcherBatch(tt.sync, tt.evs)
+		wg := newWatcherGroup()
+		for _, w := range tt.sync {
+			wg.add(w)
+		}
+
+		gwe := newWatcherBatch(&wg, tt.evs)
 		if len(gwe) != len(tt.wwe) {
 			t.Errorf("#%d: len(gwe) got = %d, want = %d", i, len(gwe), len(tt.wwe))
 		}

+ 4 - 5
storage/watcher.go

@@ -29,16 +29,15 @@ type WatchID int64
 
 type WatchStream interface {
 	// Watch creates a watcher. The watcher watches the events happening or
-	// happened on the given key or key prefix from the given startRev.
+	// happened on the given key or range [key, end) from the given startRev.
 	//
 	// The whole event history can be watched unless compacted.
-	// If `prefix` is true, watch observes all events whose key prefix could be the given `key`.
 	// If `startRev` <=0, watch observes events after currentRev.
 	//
 	// The returned `id` is the ID of this watcher. It appears as WatchID
 	// in events that are sent to the created watcher through stream channel.
 	//
-	Watch(key []byte, prefix bool, startRev int64) WatchID
+	Watch(key, end []byte, startRev int64) WatchID
 
 	// Chan returns a chan. All watch response will be sent to the returned chan.
 	Chan() <-chan WatchResponse
@@ -87,7 +86,7 @@ type watchStream struct {
 
 // Watch creates a new watcher in the stream and returns its WatchID.
 // TODO: return error if ws is closed?
-func (ws *watchStream) Watch(key []byte, prefix bool, startRev int64) WatchID {
+func (ws *watchStream) Watch(key, end []byte, startRev int64) WatchID {
 	ws.mu.Lock()
 	defer ws.mu.Unlock()
 	if ws.closed {
@@ -97,7 +96,7 @@ func (ws *watchStream) Watch(key []byte, prefix bool, startRev int64) WatchID {
 	id := ws.nextID
 	ws.nextID++
 
-	_, c := ws.watchable.watch(key, prefix, startRev, id, ws.ch)
+	_, c := ws.watchable.watch(key, end, startRev, id, ws.ch)
 
 	ws.cancels[id] = c
 	return id

+ 1 - 1
storage/watcher_bench_test.go

@@ -33,6 +33,6 @@ func BenchmarkKVWatcherMemoryUsage(b *testing.B) {
 	b.ReportAllocs()
 	b.StartTimer()
 	for i := 0; i < b.N; i++ {
-		w.Watch([]byte(fmt.Sprint("foo", i)), false, 0)
+		w.Watch([]byte(fmt.Sprint("foo", i)), nil, 0)
 	}
 }

+ 269 - 0
storage/watcher_group.go

@@ -0,0 +1,269 @@
+// Copyright 2016 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+	"math"
+
+	"github.com/coreos/etcd/pkg/adt"
+	"github.com/coreos/etcd/storage/storagepb"
+)
+
+var (
+	// watchBatchMaxRevs is the maximum distinct revisions that
+	// may be sent to an unsynced watcher at a time. Declared as
+	// var instead of const for testing purposes.
+	watchBatchMaxRevs = 1000
+)
+
+type eventBatch struct {
+	// evs is a batch of revision-ordered events
+	evs []storagepb.Event
+	// revs is the minimum unique revisions observed for this batch
+	revs int
+	// moreRev is first revision with more events following this batch
+	moreRev int64
+}
+
+func (eb *eventBatch) add(ev storagepb.Event) {
+	if eb.revs > watchBatchMaxRevs {
+		// maxed out batch size
+		return
+	}
+
+	if len(eb.evs) == 0 {
+		// base case
+		eb.revs = 1
+		eb.evs = append(eb.evs, ev)
+		return
+	}
+
+	// revision accounting
+	ebRev := eb.evs[len(eb.evs)-1].Kv.ModRevision
+	evRev := ev.Kv.ModRevision
+	if evRev > ebRev {
+		eb.revs++
+		if eb.revs > watchBatchMaxRevs {
+			eb.moreRev = evRev
+			return
+		}
+	}
+
+	eb.evs = append(eb.evs, ev)
+}
+
+type watcherBatch map[*watcher]*eventBatch
+
+func (wb watcherBatch) add(w *watcher, ev storagepb.Event) {
+	eb := wb[w]
+	if eb == nil {
+		eb = &eventBatch{}
+		wb[w] = eb
+	}
+	eb.add(ev)
+}
+
+// newWatcherBatch maps watchers to their matched events. It enables quick
+// events look up by watcher.
+func newWatcherBatch(wg *watcherGroup, evs []storagepb.Event) watcherBatch {
+	wb := make(watcherBatch)
+	for _, ev := range evs {
+		for w := range wg.watcherSetByKey(string(ev.Kv.Key)) {
+			if ev.Kv.ModRevision >= w.cur {
+				// don't double notify
+				wb.add(w, ev)
+			}
+		}
+	}
+	return wb
+}
+
+type watcherSet map[*watcher]struct{}
+
+func (w watcherSet) add(wa *watcher) {
+	if _, ok := w[wa]; ok {
+		panic("add watcher twice!")
+	}
+	w[wa] = struct{}{}
+}
+
+func (w watcherSet) union(ws watcherSet) {
+	for wa := range ws {
+		w.add(wa)
+	}
+}
+
+func (w watcherSet) delete(wa *watcher) {
+	if _, ok := w[wa]; !ok {
+		panic("removing missing watcher!")
+	}
+	delete(w, wa)
+}
+
+type watcherSetByKey map[string]watcherSet
+
+func (w watcherSetByKey) add(wa *watcher) {
+	set := w[string(wa.key)]
+	if set == nil {
+		set = make(watcherSet)
+		w[string(wa.key)] = set
+	}
+	set.add(wa)
+}
+
+func (w watcherSetByKey) delete(wa *watcher) bool {
+	k := string(wa.key)
+	if v, ok := w[k]; ok {
+		if _, ok := v[wa]; ok {
+			delete(v, wa)
+			if len(v) == 0 {
+				// remove the set; nothing left
+				delete(w, k)
+			}
+			return true
+		}
+	}
+	return false
+}
+
+type interval struct {
+	begin string
+	end   string
+}
+
+type watcherSetByInterval map[interval]watcherSet
+
+// watcherGroup is a collection of watchers organized by their ranges
+type watcherGroup struct {
+	// keyWatchers has the watchers that watch on a single key
+	keyWatchers watcherSetByKey
+	// ranges has the watchers that watch a range; it is sorted by interval
+	ranges adt.IntervalTree
+	// watchers is the set of all watchers
+	watchers watcherSet
+}
+
+func newWatcherGroup() watcherGroup {
+	return watcherGroup{
+		keyWatchers: make(watcherSetByKey),
+		watchers:    make(watcherSet),
+	}
+}
+
+// add puts a watcher in the group.
+func (wg *watcherGroup) add(wa *watcher) {
+	wg.watchers.add(wa)
+	if wa.end == nil {
+		wg.keyWatchers.add(wa)
+		return
+	}
+
+	// interval already registered?
+	ivl := adt.NewStringAffineInterval(string(wa.key), string(wa.end))
+	if iv := wg.ranges.Find(ivl); iv != nil {
+		iv.Val.(watcherSet).add(wa)
+		return
+	}
+
+	// not registered, put in interval tree
+	ws := make(watcherSet)
+	ws.add(wa)
+	wg.ranges.Insert(ivl, ws)
+}
+
+// contains is whether the given key has a watcher in the group.
+func (wg *watcherGroup) contains(key string) bool {
+	_, ok := wg.keyWatchers[key]
+	return ok || wg.ranges.Contains(adt.NewStringAffinePoint(key))
+}
+
+// size gives the number of unique watchers in the group.
+func (wg *watcherGroup) size() int { return len(wg.watchers) }
+
+// delete removes a watcher from the group.
+func (wg *watcherGroup) delete(wa *watcher) bool {
+	if _, ok := wg.watchers[wa]; !ok {
+		return false
+	}
+	wg.watchers.delete(wa)
+	if wa.end == nil {
+		wg.keyWatchers.delete(wa)
+		return true
+	}
+
+	ivl := adt.NewStringAffineInterval(string(wa.key), string(wa.end))
+	iv := wg.ranges.Find(ivl)
+	if iv == nil {
+		return false
+	}
+
+	ws := iv.Val.(watcherSet)
+	delete(ws, wa)
+	if len(ws) == 0 {
+		// remove the interval when it has no watchers left
+		if ok := wg.ranges.Delete(ivl); !ok {
+			panic("could not remove watcher from interval tree")
+		}
+	}
+
+	return true
+}
+
+func (wg *watcherGroup) scanMinRev(curRev int64, compactRev int64) int64 {
+	minRev := int64(math.MaxInt64)
+	for w := range wg.watchers {
+		if w.cur > curRev {
+			panic("watcher current revision should not exceed current revision")
+		}
+		if w.cur < compactRev {
+			select {
+			case w.ch <- WatchResponse{WatchID: w.id, CompactRevision: compactRev}:
+				wg.delete(w)
+			default:
+				// retry next time
+			}
+			continue
+		}
+		if minRev > w.cur {
+			minRev = w.cur
+		}
+	}
+	return minRev
+}
+
+// watcherSetByKey gets the set of watchers that receive events on the given key.
+func (wg *watcherGroup) watcherSetByKey(key string) watcherSet {
+	wkeys := wg.keyWatchers[key]
+	wranges := wg.ranges.Stab(adt.NewStringAffinePoint(key))
+
+	// zero-copy cases
+	switch {
+	case len(wranges) == 0 && len(wkeys) == 0:
+		return nil
+	case len(wranges) == 0:
+		// no need to merge ranges or copy; reuse single-key set
+		return wkeys
+	case len(wranges) == 1 && len(wkeys) == 0:
+		return wranges[0].Val.(watcherSet)
+	}
+
+	// copy case
+	ret := make(watcherSet)
+	ret.union(wg.keyWatchers[key])
+	for _, item := range wranges {
+		ret.union(item.Val.(watcherSet))
+	}
+	return ret
+}

+ 7 - 8
storage/watcher_test.go

@@ -35,7 +35,7 @@ func TestWatcherWatchID(t *testing.T) {
 	idm := make(map[WatchID]struct{})
 
 	for i := 0; i < 10; i++ {
-		id := w.Watch([]byte("foo"), false, 0)
+		id := w.Watch([]byte("foo"), nil, 0)
 		if _, ok := idm[id]; ok {
 			t.Errorf("#%d: id %d exists", i, id)
 		}
@@ -57,7 +57,7 @@ func TestWatcherWatchID(t *testing.T) {
 
 	// unsynced watchers
 	for i := 10; i < 20; i++ {
-		id := w.Watch([]byte("foo2"), false, 1)
+		id := w.Watch([]byte("foo2"), nil, 1)
 		if _, ok := idm[id]; ok {
 			t.Errorf("#%d: id %d exists", i, id)
 		}
@@ -86,12 +86,11 @@ func TestWatcherWatchPrefix(t *testing.T) {
 
 	idm := make(map[WatchID]struct{})
 
-	prefixMatch := true
 	val := []byte("bar")
-	keyWatch, keyPut := []byte("foo"), []byte("foobar")
+	keyWatch, keyEnd, keyPut := []byte("foo"), []byte("fop"), []byte("foobar")
 
 	for i := 0; i < 10; i++ {
-		id := w.Watch(keyWatch, prefixMatch, 0)
+		id := w.Watch(keyWatch, keyEnd, 0)
 		if _, ok := idm[id]; ok {
 			t.Errorf("#%d: unexpected duplicated id %x", i, id)
 		}
@@ -118,12 +117,12 @@ func TestWatcherWatchPrefix(t *testing.T) {
 		}
 	}
 
-	keyWatch1, keyPut1 := []byte("foo1"), []byte("foo1bar")
+	keyWatch1, keyEnd1, keyPut1 := []byte("foo1"), []byte("foo2"), []byte("foo1bar")
 	s.Put(keyPut1, val, lease.NoLease)
 
 	// unsynced watchers
 	for i := 10; i < 15; i++ {
-		id := w.Watch(keyWatch1, prefixMatch, 1)
+		id := w.Watch(keyWatch1, keyEnd1, 1)
 		if _, ok := idm[id]; ok {
 			t.Errorf("#%d: id %d exists", i, id)
 		}
@@ -159,7 +158,7 @@ func TestWatchStreamCancelWatcherByID(t *testing.T) {
 	w := s.NewWatchStream()
 	defer w.Close()
 
-	id := w.Watch([]byte("foo"), false, 0)
+	id := w.Watch([]byte("foo"), nil, 0)
 
 	tests := []struct {
 		cancelID WatchID