瀏覽代碼

*: support checking that an interval tree's keys cover an entire interval

Anthony Romano 8 年之前
父節點
當前提交
f67bdc2eed
共有 4 個文件被更改,包括 106 次插入50 次删除
  1. 8 38
      auth/range_perm_cache.go
  2. 1 1
      mvcc/watcher_group.go
  3. 26 2
      pkg/adt/interval_tree.go
  4. 71 9
      pkg/adt/interval_tree_test.go

+ 8 - 38
auth/range_perm_cache.go

@@ -66,59 +66,29 @@ func getMergedPerms(tx backend.BatchTx, userName string) *unifiedRangePermission
 }
 
 func checkKeyInterval(cachedPerms *unifiedRangePermissions, key, rangeEnd string, permtyp authpb.Permission_Type) bool {
-	var tocheck *adt.IntervalTree
-
+	ivl := adt.NewStringInterval(key, rangeEnd)
 	switch permtyp {
 	case authpb.READ:
-		tocheck = cachedPerms.readPerms
+		return cachedPerms.readPerms.Contains(ivl)
 	case authpb.WRITE:
-		tocheck = cachedPerms.writePerms
+		return cachedPerms.writePerms.Contains(ivl)
 	default:
 		plog.Panicf("unknown auth type: %v", permtyp)
 	}
-
-	ivl := adt.NewStringInterval(key, rangeEnd)
-
-	isContiguous := true
-	var maxEnd, minBegin adt.Comparable
-
-	tocheck.Visit(ivl, func(n *adt.IntervalValue) bool {
-		if minBegin == nil {
-			minBegin = n.Ivl.Begin
-			maxEnd = n.Ivl.End
-			return true
-		}
-
-		if maxEnd.Compare(n.Ivl.Begin) < 0 {
-			isContiguous = false
-			return false
-		}
-
-		if n.Ivl.End.Compare(maxEnd) > 0 {
-			maxEnd = n.Ivl.End
-		}
-
-		return true
-	})
-
-	return isContiguous && maxEnd.Compare(ivl.End) >= 0 && minBegin.Compare(ivl.Begin) <= 0
+	return false
 }
 
 func checkKeyPoint(cachedPerms *unifiedRangePermissions, key string, permtyp authpb.Permission_Type) bool {
-	var tocheck *adt.IntervalTree
-
+	pt := adt.NewStringPoint(key)
 	switch permtyp {
 	case authpb.READ:
-		tocheck = cachedPerms.readPerms
+		return cachedPerms.readPerms.Intersects(pt)
 	case authpb.WRITE:
-		tocheck = cachedPerms.writePerms
+		return cachedPerms.writePerms.Intersects(pt)
 	default:
 		plog.Panicf("unknown auth type: %v", permtyp)
 	}
-
-	pt := adt.NewStringPoint(key)
-
-	return tocheck.Contains(pt)
+	return false
 }
 
 func (as *authStore) isRangeOpPermitted(tx backend.BatchTx, userName string, key, rangeEnd string, permtyp authpb.Permission_Type) bool {

+ 1 - 1
mvcc/watcher_group.go

@@ -183,7 +183,7 @@ func (wg *watcherGroup) add(wa *watcher) {
 // contains is whether the given key has a watcher in the group.
 func (wg *watcherGroup) contains(key string) bool {
 	_, ok := wg.keyWatchers[key]
-	return ok || wg.ranges.Contains(adt.NewStringAffinePoint(key))
+	return ok || wg.ranges.Intersects(adt.NewStringAffinePoint(key))
 }
 
 // size gives the number of unique watchers in the group.

+ 26 - 2
pkg/adt/interval_tree.go

@@ -437,8 +437,8 @@ func (ivt *IntervalTree) Find(ivl Interval) (ret *IntervalValue) {
 	return &n.iv
 }
 
-// Contains returns true if there is some tree node intersecting the given interval.
-func (ivt *IntervalTree) Contains(iv Interval) bool {
+// Intersects returns true if there is some tree node intersecting the given interval.
+func (ivt *IntervalTree) Intersects(iv Interval) bool {
 	x := ivt.root
 	for x != nil && iv.Compare(&x.iv.Ivl) != 0 {
 		if x.left != nil && x.left.max.Compare(iv.Begin) > 0 {
@@ -450,6 +450,30 @@ func (ivt *IntervalTree) Contains(iv Interval) bool {
 	return x != nil
 }
 
+// Contains returns true if the interval tree's keys cover the entire given interval.
+func (ivt *IntervalTree) Contains(ivl Interval) bool {
+	var maxEnd, minBegin Comparable
+
+	isContiguous := true
+	ivt.Visit(ivl, func(n *IntervalValue) bool {
+		if minBegin == nil {
+			minBegin = n.Ivl.Begin
+			maxEnd = n.Ivl.End
+			return true
+		}
+		if maxEnd.Compare(n.Ivl.Begin) < 0 {
+			isContiguous = false
+			return false
+		}
+		if n.Ivl.End.Compare(maxEnd) > 0 {
+			maxEnd = n.Ivl.End
+		}
+		return true
+	})
+
+	return isContiguous && minBegin != nil && maxEnd.Compare(ivl.End) >= 0 && minBegin.Compare(ivl.Begin) <= 0
+}
+
 // Stab returns a slice with all elements in the tree intersecting the interval.
 func (ivt *IntervalTree) Stab(iv Interval) (ivs []*IntervalValue) {
 	if ivt.count == 0 {

+ 71 - 9
pkg/adt/interval_tree_test.go

@@ -20,23 +20,23 @@ import (
 	"time"
 )
 
-func TestIntervalTreeContains(t *testing.T) {
+func TestIntervalTreeIntersects(t *testing.T) {
 	ivt := &IntervalTree{}
 	ivt.Insert(NewStringInterval("1", "3"), 123)
 
-	if ivt.Contains(NewStringPoint("0")) {
+	if ivt.Intersects(NewStringPoint("0")) {
 		t.Errorf("contains 0")
 	}
-	if !ivt.Contains(NewStringPoint("1")) {
+	if !ivt.Intersects(NewStringPoint("1")) {
 		t.Errorf("missing 1")
 	}
-	if !ivt.Contains(NewStringPoint("11")) {
+	if !ivt.Intersects(NewStringPoint("11")) {
 		t.Errorf("missing 11")
 	}
-	if !ivt.Contains(NewStringPoint("2")) {
+	if !ivt.Intersects(NewStringPoint("2")) {
 		t.Errorf("missing 2")
 	}
-	if ivt.Contains(NewStringPoint("3")) {
+	if ivt.Intersects(NewStringPoint("3")) {
 		t.Errorf("contains 3")
 	}
 }
@@ -44,10 +44,10 @@ func TestIntervalTreeContains(t *testing.T) {
 func TestIntervalTreeStringAffine(t *testing.T) {
 	ivt := &IntervalTree{}
 	ivt.Insert(NewStringAffineInterval("8", ""), 123)
-	if !ivt.Contains(NewStringAffinePoint("9")) {
+	if !ivt.Intersects(NewStringAffinePoint("9")) {
 		t.Errorf("missing 9")
 	}
-	if ivt.Contains(NewStringAffinePoint("7")) {
+	if ivt.Intersects(NewStringAffinePoint("7")) {
 		t.Errorf("contains 7")
 	}
 }
@@ -122,7 +122,7 @@ func TestIntervalTreeRandom(t *testing.T) {
 			if slen := len(ivt.Stab(NewInt64Point(v))); slen == 0 {
 				t.Fatalf("expected %v stab non-zero for [%+v)", v, xy)
 			}
-			if !ivt.Contains(NewInt64Point(v)) {
+			if !ivt.Intersects(NewInt64Point(v)) {
 				t.Fatalf("did not get %d as expected for [%+v)", v, xy)
 			}
 		}
@@ -231,3 +231,65 @@ func TestIntervalTreeVisitExit(t *testing.T) {
 		}
 	}
 }
+
+// TestIntervalTreeContains tests that contains returns true iff the ivt maps the entire interval.
+func TestIntervalTreeContains(t *testing.T) {
+	tests := []struct {
+		ivls   []Interval
+		chkIvl Interval
+
+		wContains bool
+	}{
+		{
+			ivls:   []Interval{NewInt64Interval(1, 10)},
+			chkIvl: NewInt64Interval(0, 100),
+
+			wContains: false,
+		},
+		{
+			ivls:   []Interval{NewInt64Interval(1, 10)},
+			chkIvl: NewInt64Interval(1, 10),
+
+			wContains: true,
+		},
+		{
+			ivls:   []Interval{NewInt64Interval(1, 10)},
+			chkIvl: NewInt64Interval(2, 8),
+
+			wContains: true,
+		},
+		{
+			ivls:   []Interval{NewInt64Interval(1, 5), NewInt64Interval(6, 10)},
+			chkIvl: NewInt64Interval(1, 10),
+
+			wContains: false,
+		},
+		{
+			ivls:   []Interval{NewInt64Interval(1, 5), NewInt64Interval(3, 10)},
+			chkIvl: NewInt64Interval(1, 10),
+
+			wContains: true,
+		},
+		{
+			ivls:   []Interval{NewInt64Interval(1, 4), NewInt64Interval(4, 7), NewInt64Interval(3, 10)},
+			chkIvl: NewInt64Interval(1, 10),
+
+			wContains: true,
+		},
+		{
+			ivls:   []Interval{},
+			chkIvl: NewInt64Interval(1, 10),
+
+			wContains: false,
+		},
+	}
+	for i, tt := range tests {
+		ivt := &IntervalTree{}
+		for _, ivl := range tt.ivls {
+			ivt.Insert(ivl, struct{}{})
+		}
+		if v := ivt.Contains(tt.chkIvl); v != tt.wContains {
+			t.Errorf("#%d: ivt.Contains got %v, expected %v", i, v, tt.wContains)
+		}
+	}
+}