Переглянути джерело

pkg/adt: fix interval tree black-height property based on rbtree

Author: xkey <xk33430@ly.com>
ref. https://github.com/etcd-io/etcd/pull/10978

Signed-off-by: Gyuho Lee <leegyuho@amazon.com>
xkey 6 роки тому
батько
коміт
036bd1ab09
2 змінених файлів з 101 додано та 86 видалено
  1. 100 85
      pkg/adt/interval_tree.go
  2. 1 1
      pkg/adt/interval_tree_test.go

+ 100 - 85
pkg/adt/interval_tree.go

@@ -87,39 +87,39 @@ type intervalNode struct {
 	c      rbcolor
 }
 
-func (x *intervalNode) color() rbcolor {
-	if x == nil {
+func (x *intervalNode) color(sentinel *intervalNode) rbcolor {
+	if x == sentinel {
 		return black
 	}
 	return x.c
 }
 
-func (x *intervalNode) height() int {
-	if x == nil {
+func (x *intervalNode) height(sentinel *intervalNode) int {
+	if x == sentinel {
 		return 0
 	}
-	ld := x.left.height()
-	rd := x.right.height()
+	ld := x.left.height(sentinel)
+	rd := x.right.height(sentinel)
 	if ld < rd {
 		return rd + 1
 	}
 	return ld + 1
 }
 
-func (x *intervalNode) min() *intervalNode {
-	for x.left != nil {
+func (x *intervalNode) min(sentinel *intervalNode) *intervalNode {
+	for x.left != sentinel {
 		x = x.left
 	}
 	return x
 }
 
 // successor is the next in-order node in the tree
-func (x *intervalNode) successor() *intervalNode {
-	if x.right != nil {
-		return x.right.min()
+func (x *intervalNode) successor(sentinel *intervalNode) *intervalNode {
+	if x.right != sentinel {
+		return x.right.min(sentinel)
 	}
 	y := x.parent
-	for y != nil && x == y.right {
+	for y != sentinel && x == y.right {
 		x = y
 		y = y.parent
 	}
@@ -127,14 +127,14 @@ func (x *intervalNode) successor() *intervalNode {
 }
 
 // updateMax updates the maximum values for a node and its ancestors
-func (x *intervalNode) updateMax() {
-	for x != nil {
+func (x *intervalNode) updateMax(sentinel *intervalNode) {
+	for x != sentinel {
 		oldmax := x.max
 		max := x.iv.Ivl.End
-		if x.left != nil && x.left.max.Compare(max) > 0 {
+		if x.left != sentinel && x.left.max.Compare(max) > 0 {
 			max = x.left.max
 		}
-		if x.right != nil && x.right.max.Compare(max) > 0 {
+		if x.right != sentinel && x.right.max.Compare(max) > 0 {
 			max = x.right.max
 		}
 		if oldmax.Compare(max) == 0 {
@@ -148,25 +148,25 @@ func (x *intervalNode) updateMax() {
 type nodeVisitor func(n *intervalNode) bool
 
 // visit will call a node visitor on each node that overlaps the given interval
-func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) bool {
-	if x == nil {
+func (x *intervalNode) visit(iv *Interval, sentinel *intervalNode, nv nodeVisitor) bool {
+	if x == sentinel {
 		return true
 	}
 	v := iv.Compare(&x.iv.Ivl)
 	switch {
 	case v < 0:
-		if !x.left.visit(iv, nv) {
+		if !x.left.visit(iv, sentinel, nv) {
 			return false
 		}
 	case v > 0:
 		maxiv := Interval{x.iv.Ivl.Begin, x.max}
 		if maxiv.Compare(iv) == 0 {
-			if !x.left.visit(iv, nv) || !x.right.visit(iv, nv) {
+			if !x.left.visit(iv, sentinel, nv) || !x.right.visit(iv, sentinel, nv) {
 				return false
 			}
 		}
 	default:
-		if !x.left.visit(iv, nv) || !nv(x) || !x.right.visit(iv, nv) {
+		if !x.left.visit(iv, sentinel, nv) || !nv(x) || !x.right.visit(iv, sentinel, nv) {
 			return false
 		}
 	}
@@ -211,9 +211,18 @@ type IntervalTree interface {
 
 // NewIntervalTree returns a new interval tree.
 func NewIntervalTree() IntervalTree {
+	sentinel := &intervalNode{
+		iv:     IntervalValue{},
+		max:    nil,
+		left:   nil,
+		right:  nil,
+		parent: nil,
+		c:      black,
+	}
 	return &intervalTree{
-		root:  nil,
-		count: 0,
+		root:     sentinel,
+		count:    0,
+		sentinel: sentinel,
 	}
 }
 
@@ -221,9 +230,11 @@ type intervalTree struct {
 	root  *intervalNode
 	count int
 
-	// TODO: use 'sentinel' as a dummy object to simplify boundary conditions
+	// red-black NIL node
+	// use 'sentinel' as a dummy object to simplify boundary conditions
 	// use the sentinel to treat a nil child of a node x as an ordinary node whose parent is x
 	// use one shared sentinel to represent all nil leaves and the root's parent
+	sentinel *intervalNode
 }
 
 // TODO: make this consistent with textbook implementation
@@ -263,24 +274,25 @@ type intervalTree struct {
 // true if a node is in fact removed.
 func (ivt *intervalTree) Delete(ivl Interval) bool {
 	z := ivt.find(ivl)
-	if z == nil {
+	if z == ivt.sentinel {
 		return false
 	}
 
 	y := z
-	if z.left != nil && z.right != nil {
-		y = z.successor()
+	if z.left != ivt.sentinel && z.right != ivt.sentinel {
+		y = z.successor(ivt.sentinel)
 	}
 
-	x := y.left
-	if x == nil {
+	x := ivt.sentinel
+	if y.left != ivt.sentinel {
+		x = y.left
+	} else if y.right != ivt.sentinel {
 		x = y.right
 	}
-	if x != nil {
-		x.parent = y.parent
-	}
 
-	if y.parent == nil {
+	x.parent = y.parent
+
+	if y.parent == ivt.sentinel {
 		ivt.root = x
 	} else {
 		if y == y.parent.left {
@@ -288,14 +300,14 @@ func (ivt *intervalTree) Delete(ivl Interval) bool {
 		} else {
 			y.parent.right = x
 		}
-		y.parent.updateMax()
+		y.parent.updateMax(ivt.sentinel)
 	}
 	if y != z {
 		z.iv = y.iv
-		z.updateMax()
+		z.updateMax(ivt.sentinel)
 	}
 
-	if y.color() == black && x != nil {
+	if y.color(ivt.sentinel) == black {
 		ivt.deleteFixup(x)
 	}
 
@@ -348,10 +360,10 @@ func (ivt *intervalTree) Delete(ivl Interval) bool {
 //	40. x.color = BLACK
 //
 func (ivt *intervalTree) deleteFixup(x *intervalNode) {
-	for x != ivt.root && x.color() == black && x.parent != nil {
+	for x != ivt.root && x.color(ivt.sentinel) == black {
 		if x == x.parent.left { // line 3-20
 			w := x.parent.right
-			if w.color() == red {
+			if w.color(ivt.sentinel) == red {
 				w.c = black
 				x.parent.c = red
 				ivt.rotateLeft(x.parent)
@@ -360,28 +372,26 @@ func (ivt *intervalTree) deleteFixup(x *intervalNode) {
 			if w == nil {
 				break
 			}
-			if w.left.color() == black && w.right.color() == black {
+			if w.left.color(ivt.sentinel) == black && w.right.color(ivt.sentinel) == black {
 				w.c = red
 				x = x.parent
 			} else {
-				if w.right.color() == black {
+				if w.right.color(ivt.sentinel) == black {
 					w.left.c = black
 					w.c = red
 					ivt.rotateRight(w)
 					w = x.parent.right
 				}
-				w.c = x.parent.color()
+				w.c = x.parent.color(ivt.sentinel)
 				x.parent.c = black
 				w.right.c = black
 				ivt.rotateLeft(x.parent)
 				x = ivt.root
 			}
-
 		} else { // line 22-38
-
 			// same as above but with left and right exchanged
 			w := x.parent.left
-			if w.color() == red {
+			if w.color(ivt.sentinel) == red {
 				w.c = black
 				x.parent.c = red
 				ivt.rotateRight(x.parent)
@@ -390,17 +400,17 @@ func (ivt *intervalTree) deleteFixup(x *intervalNode) {
 			if w == nil {
 				break
 			}
-			if w.left.color() == black && w.right.color() == black {
+			if w.left.color(ivt.sentinel) == black && w.right.color(ivt.sentinel) == black {
 				w.c = red
 				x = x.parent
 			} else {
-				if w.left.color() == black {
+				if w.left.color(ivt.sentinel) == black {
 					w.right.c = black
 					w.c = red
 					ivt.rotateLeft(w)
 					w = x.parent.left
 				}
-				w.c = x.parent.color()
+				w.c = x.parent.color(ivt.sentinel)
 				x.parent.c = black
 				w.left.c = black
 				ivt.rotateRight(x.parent)
@@ -419,9 +429,9 @@ func (ivt *intervalTree) createIntervalNode(ivl Interval, val interface{}) *inte
 		iv:     IntervalValue{ivl, val},
 		max:    ivl.End,
 		c:      red,
-		left:   nil,
-		right:  nil,
-		parent: nil,
+		left:   ivt.sentinel,
+		right:  ivt.sentinel,
+		parent: ivt.sentinel,
 	}
 }
 
@@ -458,10 +468,10 @@ func (ivt *intervalTree) createIntervalNode(ivl Interval, val interface{}) *inte
 
 // Insert adds a node with the given interval into the tree.
 func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
-	var y *intervalNode
+	y := ivt.sentinel
 	z := ivt.createIntervalNode(ivl, val)
 	x := ivt.root
-	for x != nil {
+	for x != ivt.sentinel {
 		y = x
 		if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 {
 			x = x.left
@@ -471,7 +481,7 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
 	}
 
 	z.parent = y
-	if y == nil {
+	if y == ivt.sentinel {
 		ivt.root = z
 	} else {
 		if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 {
@@ -479,7 +489,7 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
 		} else {
 			y.right = z
 		}
-		y.updateMax()
+		y.updateMax(ivt.sentinel)
 	}
 	z.c = red
 
@@ -522,10 +532,11 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
 //	30. T.root.color = BLACK
 //
 func (ivt *intervalTree) insertFixup(z *intervalNode) {
-	for z.parent != nil && z.parent.parent != nil && z.parent.color() == red {
+	for z.parent.color(ivt.sentinel) == red {
 		if z.parent == z.parent.parent.left { // line 3-15
+
 			y := z.parent.parent.right
-			if y.color() == red {
+			if y.color(ivt.sentinel) == red {
 				y.c = black
 				z.parent.c = black
 				z.parent.parent.c = red
@@ -542,7 +553,7 @@ func (ivt *intervalTree) insertFixup(z *intervalNode) {
 		} else { // line 16-28
 			// same as then with left/right exchanged
 			y := z.parent.parent.left
-			if y.color() == red {
+			if y.color(ivt.sentinel) == red {
 				y.c = black
 				z.parent.c = black
 				z.parent.parent.c = red
@@ -588,23 +599,27 @@ func (ivt *intervalTree) insertFixup(z *intervalNode) {
 //	18. x.p = y
 //
 func (ivt *intervalTree) rotateLeft(x *intervalNode) {
+	// rotateLeft x must have right child
+	if x.right == ivt.sentinel {
+		return
+	}
+
 	// line 2-3
 	y := x.right
 	x.right = y.left
 
 	// line 5-6
-	if y.left != nil {
+	if y.left != ivt.sentinel {
 		y.left.parent = x
 	}
-
-	x.updateMax()
+	x.updateMax(ivt.sentinel)
 
 	// line 10-15, 18
 	ivt.replaceParent(x, y)
 
 	// line 17
 	y.left = x
-	y.updateMax()
+	y.updateMax(ivt.sentinel)
 }
 
 // rotateRight moves x so it is right of its left child
@@ -630,7 +645,8 @@ func (ivt *intervalTree) rotateLeft(x *intervalNode) {
 //	18. x.p = y
 //
 func (ivt *intervalTree) rotateRight(x *intervalNode) {
-	if x == nil {
+	// rotateRight x must have left child
+	if x.left == ivt.sentinel {
 		return
 	}
 
@@ -639,24 +655,23 @@ func (ivt *intervalTree) rotateRight(x *intervalNode) {
 	x.left = y.right
 
 	// line 5-6
-	if y.right != nil {
+	if y.right != ivt.sentinel {
 		y.right.parent = x
 	}
-
-	x.updateMax()
+	x.updateMax(ivt.sentinel)
 
 	// line 10-15, 18
 	ivt.replaceParent(x, y)
 
 	// line 17
 	y.right = x
-	y.updateMax()
+	y.updateMax(ivt.sentinel)
 }
 
 // replaceParent replaces x's parent with y
 func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) {
 	y.parent = x.parent
-	if x.parent == nil {
+	if x.parent == ivt.sentinel {
 		ivt.root = y
 	} else {
 		if x == x.parent.left {
@@ -664,7 +679,7 @@ func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) {
 		} else {
 			x.parent.right = y
 		}
-		x.parent.updateMax()
+		x.parent.updateMax(ivt.sentinel)
 	}
 	x.parent = y
 }
@@ -673,7 +688,7 @@ func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) {
 func (ivt *intervalTree) Len() int { return ivt.count }
 
 // Height is the number of levels in the tree; one node has height 1.
-func (ivt *intervalTree) Height() int { return ivt.root.height() }
+func (ivt *intervalTree) Height() int { return ivt.root.height(ivt.sentinel) }
 
 // MaxHeight is the expected maximum tree height given the number of nodes
 func (ivt *intervalTree) MaxHeight() int {
@@ -686,11 +701,12 @@ type IntervalVisitor func(n *IntervalValue) bool
 // Visit calls a visitor function on every tree node intersecting the given interval.
 // It will visit each interval [x, y) in ascending order sorted on x.
 func (ivt *intervalTree) Visit(ivl Interval, ivv IntervalVisitor) {
-	ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) })
+	ivt.root.visit(&ivl, ivt.sentinel, func(n *intervalNode) bool { return ivv(&n.iv) })
 }
 
 // find the exact node for a given interval
-func (ivt *intervalTree) find(ivl Interval) (ret *intervalNode) {
+func (ivt *intervalTree) find(ivl Interval) *intervalNode {
+	ret := ivt.sentinel
 	f := func(n *intervalNode) bool {
 		if n.iv.Ivl != ivl {
 			return true
@@ -698,14 +714,14 @@ func (ivt *intervalTree) find(ivl Interval) (ret *intervalNode) {
 		ret = n
 		return false
 	}
-	ivt.root.visit(&ivl, f)
+	ivt.root.visit(&ivl, ivt.sentinel, f)
 	return ret
 }
 
 // Find gets the IntervalValue for the node matching the given interval
 func (ivt *intervalTree) Find(ivl Interval) (ret *IntervalValue) {
 	n := ivt.find(ivl)
-	if n == nil {
+	if n == ivt.sentinel {
 		return nil
 	}
 	return &n.iv
@@ -714,14 +730,14 @@ func (ivt *intervalTree) Find(ivl Interval) (ret *IntervalValue) {
 // Intersects returns true if there is some tree node intersecting the given interval.
 func (ivt *intervalTree) Intersects(iv Interval) bool {
 	x := ivt.root
-	for x != nil && iv.Compare(&x.iv.Ivl) != 0 {
-		if x.left != nil && x.left.max.Compare(iv.Begin) > 0 {
+	for x != ivt.sentinel && iv.Compare(&x.iv.Ivl) != 0 {
+		if x.left != ivt.sentinel && x.left.max.Compare(iv.Begin) > 0 {
 			x = x.left
 		} else {
 			x = x.right
 		}
 	}
-	return x != nil
+	return x != ivt.sentinel
 }
 
 // Contains returns true if the interval tree's keys cover the entire given interval.
@@ -789,7 +805,7 @@ func (vi visitedInterval) String() string {
 // visitLevel traverses tree in level order.
 // used for testing
 func (ivt *intervalTree) visitLevel() []visitedInterval {
-	if ivt.root == nil {
+	if ivt.root == ivt.sentinel {
 		return nil
 	}
 
@@ -804,22 +820,21 @@ func (ivt *intervalTree) visitLevel() []visitedInterval {
 		f := queue[0]
 		queue = queue[1:]
 
-		ivt := visitedInterval{
+		vi := visitedInterval{
 			root:  f.node.iv.Ivl,
-			color: f.node.color(),
+			color: f.node.color(ivt.sentinel),
 			depth: f.depth,
 		}
-
-		if f.node.left != nil {
-			ivt.left = f.node.left.iv.Ivl
+		if f.node.left != ivt.sentinel {
+			vi.left = f.node.left.iv.Ivl
 			queue = append(queue, pair{f.node.left, f.depth + 1})
 		}
-		if f.node.right != nil {
-			ivt.right = f.node.right.iv.Ivl
+		if f.node.right != ivt.sentinel {
+			vi.right = f.node.right.iv.Ivl
 			queue = append(queue, pair{f.node.right, f.depth + 1})
 		}
 
-		rs = append(rs, ivt)
+		rs = append(rs, vi)
 	}
 
 	return rs

+ 1 - 1
pkg/adt/interval_tree_test.go

@@ -298,7 +298,7 @@ func TestIntervalTreeDelete(t *testing.T) {
 		//        /     \                                                              /
 		// [238,239]   [292,293]                                                [953,954]
 		//
-		t.Logf("level order after deleting '11' expected %v, got %v", expectedAfterDelete11, visitsAfterDelete11)
+		t.Fatalf("level order after deleting '11' expected %v, got %v", expectedAfterDelete11, visitsAfterDelete11)
 	}
 }