瀏覽代碼

adt: Visit() interval trees in sorted order and terminate early

For all intervals [x, y), Visit will visit intervals in ascending order
sorted by x. Also fixes a bug where Visit would not terminate the search
when requested by the visitor function.
Anthony Romano 8 年之前
父節點
當前提交
25e3ce1feb
共有 2 個文件被更改,包括 108 次插入8 次删除
  1. 13 8
      pkg/adt/interval_tree.go
  2. 95 0
      pkg/adt/interval_tree_test.go

+ 13 - 8
pkg/adt/interval_tree.go

@@ -134,25 +134,29 @@ func (x *intervalNode) updateMax() {
 type nodeVisitor func(n *intervalNode) bool
 
 // visit will call a node visitor on each node that overlaps the given interval
-func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) {
+func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) bool {
 	if x == nil {
-		return
+		return true
 	}
 	v := iv.Compare(&x.iv.Ivl)
 	switch {
 	case v < 0:
-		x.left.visit(iv, nv)
+		if !x.left.visit(iv, nv) {
+			return false
+		}
 	case v > 0:
 		maxiv := Interval{x.iv.Ivl.Begin, x.max}
 		if maxiv.Compare(iv) == 0 {
-			x.left.visit(iv, nv)
-			x.right.visit(iv, nv)
+			if !x.left.visit(iv, nv) || !x.right.visit(iv, nv) {
+				return false
+			}
 		}
 	default:
-		nv(x)
-		x.left.visit(iv, nv)
-		x.right.visit(iv, nv)
+		if !x.left.visit(iv, nv) || !nv(x) || !x.right.visit(iv, nv) {
+			return false
+		}
 	}
+	return true
 }
 
 type IntervalValue struct {
@@ -406,6 +410,7 @@ func (ivt *IntervalTree) MaxHeight() int {
 type IntervalVisitor func(n *IntervalValue) bool
 
 // Visit calls a visitor function on every tree node intersecting the given interval.
+// It will visit each interval [x, y) in ascending order sorted on x.
 func (ivt *IntervalTree) Visit(ivl Interval, ivv IntervalVisitor) {
 	ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) })
 }

+ 95 - 0
pkg/adt/interval_tree_test.go

@@ -136,3 +136,98 @@ func TestIntervalTreeRandom(t *testing.T) {
 		t.Errorf("got ivt.Len() = %v, expected 0", ivt.Len())
 	}
 }
+
+// TestIntervalTreeSortedVisit tests that intervals are visited in sorted order.
+func TestIntervalTreeSortedVisit(t *testing.T) {
+	tests := []struct {
+		ivls       []Interval
+		visitRange Interval
+	}{
+		{
+			ivls:       []Interval{NewInt64Interval(1, 10), NewInt64Interval(2, 5), NewInt64Interval(3, 6)},
+			visitRange: NewInt64Interval(0, 100),
+		},
+		{
+			ivls:       []Interval{NewInt64Interval(1, 10), NewInt64Interval(10, 12), NewInt64Interval(3, 6)},
+			visitRange: NewInt64Interval(0, 100),
+		},
+		{
+			ivls:       []Interval{NewInt64Interval(2, 3), NewInt64Interval(3, 4), NewInt64Interval(6, 7), NewInt64Interval(5, 6)},
+			visitRange: NewInt64Interval(0, 100),
+		},
+		{
+			ivls: []Interval{
+				NewInt64Interval(2, 3),
+				NewInt64Interval(2, 4),
+				NewInt64Interval(3, 7),
+				NewInt64Interval(2, 5),
+				NewInt64Interval(3, 8),
+				NewInt64Interval(3, 5),
+			},
+			visitRange: NewInt64Interval(0, 100),
+		},
+	}
+	for i, tt := range tests {
+		ivt := &IntervalTree{}
+		for _, ivl := range tt.ivls {
+			ivt.Insert(ivl, struct{}{})
+		}
+		last := tt.ivls[0].Begin
+		count := 0
+		chk := func(iv *IntervalValue) bool {
+			if last.Compare(iv.Ivl.Begin) > 0 {
+				t.Errorf("#%d: expected less than %d, got interval %+v", i, last, iv.Ivl)
+			}
+			last = iv.Ivl.Begin
+			count++
+			return true
+		}
+		ivt.Visit(tt.visitRange, chk)
+		if count != len(tt.ivls) {
+			t.Errorf("#%d: did not cover all intervals. expected %d, got %d", i, len(tt.ivls), count)
+		}
+	}
+}
+
+// TestIntervalTreeVisitExit tests that visiting can be stopped.
+func TestIntervalTreeVisitExit(t *testing.T) {
+	ivls := []Interval{NewInt64Interval(1, 10), NewInt64Interval(2, 5), NewInt64Interval(3, 6), NewInt64Interval(4, 8)}
+	ivlRange := NewInt64Interval(0, 100)
+	tests := []struct {
+		f IntervalVisitor
+
+		wcount int
+	}{
+		{
+			f:      func(n *IntervalValue) bool { return false },
+			wcount: 1,
+		},
+		{
+			f:      func(n *IntervalValue) bool { return n.Ivl.Begin.Compare(ivls[0].Begin) <= 0 },
+			wcount: 2,
+		},
+		{
+			f:      func(n *IntervalValue) bool { return n.Ivl.Begin.Compare(ivls[2].Begin) < 0 },
+			wcount: 3,
+		},
+		{
+			f:      func(n *IntervalValue) bool { return true },
+			wcount: 4,
+		},
+	}
+
+	for i, tt := range tests {
+		ivt := &IntervalTree{}
+		for _, ivl := range ivls {
+			ivt.Insert(ivl, struct{}{})
+		}
+		count := 0
+		ivt.Visit(ivlRange, func(n *IntervalValue) bool {
+			count++
+			return tt.f(n)
+		})
+		if count != tt.wcount {
+			t.Errorf("#%d: expected count %d, got %d", i, tt.wcount, count)
+		}
+	}
+}