123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531 |
- // Copyright 2016 The etcd Authors
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package adt
- import (
- "math"
- )
// Comparable is implemented by values that support a three-way
// (trichotomic) comparison. The interval tree uses it both to order
// interval endpoints and to test intervals for overlap.
type Comparable interface {
	// Compare reports how the receiver orders relative to c:
	//   -1 when the receiver is less than c,
	//    0 when they are equal,
	//    1 when the receiver is greater than c.
	Compare(c Comparable) int
}
// rbcolor is the color of a red-black tree node.
type rbcolor int

// Node colors. black is the zero value of rbcolor.
const (
	black rbcolor = 0
	red   rbcolor = 1
)
- // Interval implements a Comparable interval [begin, end)
- // TODO: support different sorts of intervals: (a,b), [a,b], (a, b]
- type Interval struct {
- Begin Comparable
- End Comparable
- }
- // Compare on an interval gives == if the interval overlaps.
- func (ivl *Interval) Compare(c Comparable) int {
- ivl2 := c.(*Interval)
- ivbCmpBegin := ivl.Begin.Compare(ivl2.Begin)
- ivbCmpEnd := ivl.Begin.Compare(ivl2.End)
- iveCmpBegin := ivl.End.Compare(ivl2.Begin)
- // ivl is left of ivl2
- if ivbCmpBegin < 0 && iveCmpBegin <= 0 {
- return -1
- }
- // iv is right of iv2
- if ivbCmpEnd >= 0 {
- return 1
- }
- return 0
- }
- type intervalNode struct {
- // iv is the interval-value pair entry.
- iv IntervalValue
- // max endpoint of all descendent nodes.
- max Comparable
- // left and right are sorted by low endpoint of key interval
- left, right *intervalNode
- // parent is the direct ancestor of the node
- parent *intervalNode
- c rbcolor
- }
- func (x *intervalNode) color() rbcolor {
- if x == nil {
- return black
- }
- return x.c
- }
- func (n *intervalNode) height() int {
- if n == nil {
- return 0
- }
- ld := n.left.height()
- rd := n.right.height()
- if ld < rd {
- return rd + 1
- }
- return ld + 1
- }
- func (x *intervalNode) min() *intervalNode {
- for x.left != nil {
- x = x.left
- }
- return x
- }
- // successor is the next in-order node in the tree
- func (x *intervalNode) successor() *intervalNode {
- if x.right != nil {
- return x.right.min()
- }
- y := x.parent
- for y != nil && x == y.right {
- x = y
- y = y.parent
- }
- return y
- }
- // updateMax updates the maximum values for a node and its ancestors
- func (x *intervalNode) updateMax() {
- for x != nil {
- oldmax := x.max
- max := x.iv.Ivl.End
- if x.left != nil && x.left.max.Compare(max) > 0 {
- max = x.left.max
- }
- if x.right != nil && x.right.max.Compare(max) > 0 {
- max = x.right.max
- }
- if oldmax.Compare(max) == 0 {
- break
- }
- x.max = max
- x = x.parent
- }
- }
- type nodeVisitor func(n *intervalNode) bool
- // visit will call a node visitor on each node that overlaps the given interval
- func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) {
- if x == nil {
- return
- }
- v := iv.Compare(&x.iv.Ivl)
- switch {
- case v < 0:
- x.left.visit(iv, nv)
- case v > 0:
- maxiv := Interval{x.iv.Ivl.Begin, x.max}
- if maxiv.Compare(iv) == 0 {
- x.left.visit(iv, nv)
- x.right.visit(iv, nv)
- }
- default:
- nv(x)
- x.left.visit(iv, nv)
- x.right.visit(iv, nv)
- }
- }
- type IntervalValue struct {
- Ivl Interval
- Val interface{}
- }
- // IntervalTree represents a (mostly) textbook implementation of the
- // "Introduction to Algorithms" (Cormen et al, 2nd ed.) chapter 13 red-black tree
- // and chapter 14.3 interval tree with search supporting "stabbing queries".
- type IntervalTree struct {
- root *intervalNode
- count int
- }
- // Delete removes the node with the given interval from the tree, returning
- // true if a node is in fact removed.
- func (ivt *IntervalTree) Delete(ivl Interval) bool {
- z := ivt.find(ivl)
- if z == nil {
- return false
- }
- y := z
- if z.left != nil && z.right != nil {
- y = z.successor()
- }
- x := y.left
- if x == nil {
- x = y.right
- }
- if x != nil {
- x.parent = y.parent
- }
- if y.parent == nil {
- ivt.root = x
- } else {
- if y == y.parent.left {
- y.parent.left = x
- } else {
- y.parent.right = x
- }
- y.parent.updateMax()
- }
- if y != z {
- z.iv = y.iv
- z.updateMax()
- }
- if y.color() == black && x != nil {
- ivt.deleteFixup(x)
- }
- ivt.count--
- return true
- }
- func (ivt *IntervalTree) deleteFixup(x *intervalNode) {
- for x != ivt.root && x.color() == black && x.parent != nil {
- if x == x.parent.left {
- w := x.parent.right
- if w.color() == red {
- w.c = black
- x.parent.c = red
- ivt.rotateLeft(x.parent)
- w = x.parent.right
- }
- if w == nil {
- break
- }
- if w.left.color() == black && w.right.color() == black {
- w.c = red
- x = x.parent
- } else {
- if w.right.color() == black {
- w.left.c = black
- w.c = red
- ivt.rotateRight(w)
- w = x.parent.right
- }
- w.c = x.parent.color()
- x.parent.c = black
- w.right.c = black
- ivt.rotateLeft(x.parent)
- x = ivt.root
- }
- } else {
- // same as above but with left and right exchanged
- w := x.parent.left
- if w.color() == red {
- w.c = black
- x.parent.c = red
- ivt.rotateRight(x.parent)
- w = x.parent.left
- }
- if w == nil {
- break
- }
- if w.left.color() == black && w.right.color() == black {
- w.c = red
- x = x.parent
- } else {
- if w.left.color() == black {
- w.right.c = black
- w.c = red
- ivt.rotateLeft(w)
- w = x.parent.left
- }
- w.c = x.parent.color()
- x.parent.c = black
- w.left.c = black
- ivt.rotateRight(x.parent)
- x = ivt.root
- }
- }
- }
- if x != nil {
- x.c = black
- }
- }
- // Insert adds a node with the given interval into the tree.
- func (ivt *IntervalTree) Insert(ivl Interval, val interface{}) {
- var y *intervalNode
- z := &intervalNode{iv: IntervalValue{ivl, val}, max: ivl.End, c: red}
- x := ivt.root
- for x != nil {
- y = x
- if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 {
- x = x.left
- } else {
- x = x.right
- }
- }
- z.parent = y
- if y == nil {
- ivt.root = z
- } else {
- if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 {
- y.left = z
- } else {
- y.right = z
- }
- y.updateMax()
- }
- z.c = red
- ivt.insertFixup(z)
- ivt.count++
- }
- func (ivt *IntervalTree) insertFixup(z *intervalNode) {
- for z.parent != nil && z.parent.parent != nil && z.parent.color() == red {
- if z.parent == z.parent.parent.left {
- y := z.parent.parent.right
- if y.color() == red {
- y.c = black
- z.parent.c = black
- z.parent.parent.c = red
- z = z.parent.parent
- } else {
- if z == z.parent.right {
- z = z.parent
- ivt.rotateLeft(z)
- }
- z.parent.c = black
- z.parent.parent.c = red
- ivt.rotateRight(z.parent.parent)
- }
- } else {
- // same as then with left/right exchanged
- y := z.parent.parent.left
- if y.color() == red {
- y.c = black
- z.parent.c = black
- z.parent.parent.c = red
- z = z.parent.parent
- } else {
- if z == z.parent.left {
- z = z.parent
- ivt.rotateRight(z)
- }
- z.parent.c = black
- z.parent.parent.c = red
- ivt.rotateLeft(z.parent.parent)
- }
- }
- }
- ivt.root.c = black
- }
- // rotateLeft moves x so it is left of its right child
- func (ivt *IntervalTree) rotateLeft(x *intervalNode) {
- y := x.right
- x.right = y.left
- if y.left != nil {
- y.left.parent = x
- }
- x.updateMax()
- ivt.replaceParent(x, y)
- y.left = x
- y.updateMax()
- }
- // rotateLeft moves x so it is right of its left child
- func (ivt *IntervalTree) rotateRight(x *intervalNode) {
- if x == nil {
- return
- }
- y := x.left
- x.left = y.right
- if y.right != nil {
- y.right.parent = x
- }
- x.updateMax()
- ivt.replaceParent(x, y)
- y.right = x
- y.updateMax()
- }
- // replaceParent replaces x's parent with y
- func (ivt *IntervalTree) replaceParent(x *intervalNode, y *intervalNode) {
- y.parent = x.parent
- if x.parent == nil {
- ivt.root = y
- } else {
- if x == x.parent.left {
- x.parent.left = y
- } else {
- x.parent.right = y
- }
- x.parent.updateMax()
- }
- x.parent = y
- }
- // Len gives the number of elements in the tree
- func (ivt *IntervalTree) Len() int { return ivt.count }
- // Height is the number of levels in the tree; one node has height 1.
- func (ivt *IntervalTree) Height() int { return ivt.root.height() }
- // MaxHeight is the expected maximum tree height given the number of nodes
- func (ivt *IntervalTree) MaxHeight() int {
- return int((2 * math.Log2(float64(ivt.Len()+1))) + 0.5)
- }
- // IntervalVisitor is used on tree searchs; return false to stop searching.
- type IntervalVisitor func(n *IntervalValue) bool
- // Visit calls a visitor function on every tree node intersecting the given interval.
- func (ivt *IntervalTree) Visit(ivl Interval, ivv IntervalVisitor) {
- ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) })
- }
- // find the exact node for a given interval
- func (ivt *IntervalTree) find(ivl Interval) (ret *intervalNode) {
- f := func(n *intervalNode) bool {
- if n.iv.Ivl != ivl {
- return true
- }
- ret = n
- return false
- }
- ivt.root.visit(&ivl, f)
- return ret
- }
- // Find gets the IntervalValue for the node matching the given interval
- func (ivt *IntervalTree) Find(ivl Interval) (ret *IntervalValue) {
- n := ivt.find(ivl)
- if n == nil {
- return nil
- }
- return &n.iv
- }
- // Contains returns true if there is some tree node intersecting the given interval.
- func (ivt *IntervalTree) Contains(iv Interval) bool {
- x := ivt.root
- for x != nil && iv.Compare(&x.iv.Ivl) != 0 {
- if x.left != nil && x.left.max.Compare(iv.Begin) > 0 {
- x = x.left
- } else {
- x = x.right
- }
- }
- return x != nil
- }
- // Stab returns a slice with all elements in the tree intersecting the interval.
- func (ivt *IntervalTree) Stab(iv Interval) (ivs []*IntervalValue) {
- if ivt.count == 0 {
- return nil
- }
- f := func(n *IntervalValue) bool { ivs = append(ivs, n); return true }
- ivt.Visit(iv, f)
- return ivs
- }
- type StringComparable string
- func (s StringComparable) Compare(c Comparable) int {
- sc := c.(StringComparable)
- if s < sc {
- return -1
- }
- if s > sc {
- return 1
- }
- return 0
- }
- func NewStringInterval(begin, end string) Interval {
- return Interval{StringComparable(begin), StringComparable(end)}
- }
- func NewStringPoint(s string) Interval {
- return Interval{StringComparable(s), StringComparable(s + "\x00")}
- }
- // StringAffineComparable treats "" as > all other strings
- type StringAffineComparable string
- func (s StringAffineComparable) Compare(c Comparable) int {
- sc := c.(StringAffineComparable)
- if len(s) == 0 {
- if len(sc) == 0 {
- return 0
- }
- return 1
- }
- if len(sc) == 0 {
- return -1
- }
- if s < sc {
- return -1
- }
- if s > sc {
- return 1
- }
- return 0
- }
- func NewStringAffineInterval(begin, end string) Interval {
- return Interval{StringAffineComparable(begin), StringAffineComparable(end)}
- }
- func NewStringAffinePoint(s string) Interval {
- return NewStringAffineInterval(s, s+"\x00")
- }
- func NewInt64Interval(a int64, b int64) Interval {
- return Interval{Int64Comparable(a), Int64Comparable(b)}
- }
- func NewInt64Point(a int64) Interval {
- return Interval{Int64Comparable(a), Int64Comparable(a + 1)}
- }
- type Int64Comparable int64
- func (v Int64Comparable) Compare(c Comparable) int {
- vc := c.(Int64Comparable)
- cmp := v - vc
- if cmp < 0 {
- return -1
- }
- if cmp > 0 {
- return 1
- }
- return 0
- }
|