interval_tree.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. // Copyright 2016 The etcd Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package adt
  15. import (
  16. "math"
  17. )
  18. // Comparable is an interface for trichotomic comparisons.
  19. type Comparable interface {
  20. // Compare gives the result of a 3-way comparison
  21. // a.Compare(b) = 1 => a > b
  22. // a.Compare(b) = 0 => a == b
  23. // a.Compare(b) = -1 => a < b
  24. Compare(c Comparable) int
  25. }
  26. type rbcolor int
  27. const (
  28. black rbcolor = iota
  29. red
  30. )
  31. // Interval implements a Comparable interval [begin, end)
  32. // TODO: support different sorts of intervals: (a,b), [a,b], (a, b]
  33. type Interval struct {
  34. Begin Comparable
  35. End Comparable
  36. }
  37. // Compare on an interval gives == if the interval overlaps.
  38. func (ivl *Interval) Compare(c Comparable) int {
  39. ivl2 := c.(*Interval)
  40. ivbCmpBegin := ivl.Begin.Compare(ivl2.Begin)
  41. ivbCmpEnd := ivl.Begin.Compare(ivl2.End)
  42. iveCmpBegin := ivl.End.Compare(ivl2.Begin)
  43. // ivl is left of ivl2
  44. if ivbCmpBegin < 0 && iveCmpBegin <= 0 {
  45. return -1
  46. }
  47. // iv is right of iv2
  48. if ivbCmpEnd >= 0 {
  49. return 1
  50. }
  51. return 0
  52. }
  53. type intervalNode struct {
  54. // iv is the interval-value pair entry.
  55. iv IntervalValue
  56. // max endpoint of all descendent nodes.
  57. max Comparable
  58. // left and right are sorted by low endpoint of key interval
  59. left, right *intervalNode
  60. // parent is the direct ancestor of the node
  61. parent *intervalNode
  62. c rbcolor
  63. }
  64. func (x *intervalNode) color() rbcolor {
  65. if x == nil {
  66. return black
  67. }
  68. return x.c
  69. }
  70. func (n *intervalNode) height() int {
  71. if n == nil {
  72. return 0
  73. }
  74. ld := n.left.height()
  75. rd := n.right.height()
  76. if ld < rd {
  77. return rd + 1
  78. }
  79. return ld + 1
  80. }
  81. func (x *intervalNode) min() *intervalNode {
  82. for x.left != nil {
  83. x = x.left
  84. }
  85. return x
  86. }
  87. // successor is the next in-order node in the tree
  88. func (x *intervalNode) successor() *intervalNode {
  89. if x.right != nil {
  90. return x.right.min()
  91. }
  92. y := x.parent
  93. for y != nil && x == y.right {
  94. x = y
  95. y = y.parent
  96. }
  97. return y
  98. }
  99. // updateMax updates the maximum values for a node and its ancestors
  100. func (x *intervalNode) updateMax() {
  101. for x != nil {
  102. oldmax := x.max
  103. max := x.iv.Ivl.End
  104. if x.left != nil && x.left.max.Compare(max) > 0 {
  105. max = x.left.max
  106. }
  107. if x.right != nil && x.right.max.Compare(max) > 0 {
  108. max = x.right.max
  109. }
  110. if oldmax.Compare(max) == 0 {
  111. break
  112. }
  113. x.max = max
  114. x = x.parent
  115. }
  116. }
  117. type nodeVisitor func(n *intervalNode) bool
  118. // visit will call a node visitor on each node that overlaps the given interval
  119. func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) bool {
  120. if x == nil {
  121. return true
  122. }
  123. v := iv.Compare(&x.iv.Ivl)
  124. switch {
  125. case v < 0:
  126. if !x.left.visit(iv, nv) {
  127. return false
  128. }
  129. case v > 0:
  130. maxiv := Interval{x.iv.Ivl.Begin, x.max}
  131. if maxiv.Compare(iv) == 0 {
  132. if !x.left.visit(iv, nv) || !x.right.visit(iv, nv) {
  133. return false
  134. }
  135. }
  136. default:
  137. if !x.left.visit(iv, nv) || !nv(x) || !x.right.visit(iv, nv) {
  138. return false
  139. }
  140. }
  141. return true
  142. }
  143. type IntervalValue struct {
  144. Ivl Interval
  145. Val interface{}
  146. }
  147. // IntervalTree represents a (mostly) textbook implementation of the
  148. // "Introduction to Algorithms" (Cormen et al, 2nd ed.) chapter 13 red-black tree
  149. // and chapter 14.3 interval tree with search supporting "stabbing queries".
  150. type IntervalTree struct {
  151. root *intervalNode
  152. count int
  153. }
  154. // Delete removes the node with the given interval from the tree, returning
  155. // true if a node is in fact removed.
  156. func (ivt *IntervalTree) Delete(ivl Interval) bool {
  157. z := ivt.find(ivl)
  158. if z == nil {
  159. return false
  160. }
  161. y := z
  162. if z.left != nil && z.right != nil {
  163. y = z.successor()
  164. }
  165. x := y.left
  166. if x == nil {
  167. x = y.right
  168. }
  169. if x != nil {
  170. x.parent = y.parent
  171. }
  172. if y.parent == nil {
  173. ivt.root = x
  174. } else {
  175. if y == y.parent.left {
  176. y.parent.left = x
  177. } else {
  178. y.parent.right = x
  179. }
  180. y.parent.updateMax()
  181. }
  182. if y != z {
  183. z.iv = y.iv
  184. z.updateMax()
  185. }
  186. if y.color() == black && x != nil {
  187. ivt.deleteFixup(x)
  188. }
  189. ivt.count--
  190. return true
  191. }
  192. func (ivt *IntervalTree) deleteFixup(x *intervalNode) {
  193. for x != ivt.root && x.color() == black && x.parent != nil {
  194. if x == x.parent.left {
  195. w := x.parent.right
  196. if w.color() == red {
  197. w.c = black
  198. x.parent.c = red
  199. ivt.rotateLeft(x.parent)
  200. w = x.parent.right
  201. }
  202. if w == nil {
  203. break
  204. }
  205. if w.left.color() == black && w.right.color() == black {
  206. w.c = red
  207. x = x.parent
  208. } else {
  209. if w.right.color() == black {
  210. w.left.c = black
  211. w.c = red
  212. ivt.rotateRight(w)
  213. w = x.parent.right
  214. }
  215. w.c = x.parent.color()
  216. x.parent.c = black
  217. w.right.c = black
  218. ivt.rotateLeft(x.parent)
  219. x = ivt.root
  220. }
  221. } else {
  222. // same as above but with left and right exchanged
  223. w := x.parent.left
  224. if w.color() == red {
  225. w.c = black
  226. x.parent.c = red
  227. ivt.rotateRight(x.parent)
  228. w = x.parent.left
  229. }
  230. if w == nil {
  231. break
  232. }
  233. if w.left.color() == black && w.right.color() == black {
  234. w.c = red
  235. x = x.parent
  236. } else {
  237. if w.left.color() == black {
  238. w.right.c = black
  239. w.c = red
  240. ivt.rotateLeft(w)
  241. w = x.parent.left
  242. }
  243. w.c = x.parent.color()
  244. x.parent.c = black
  245. w.left.c = black
  246. ivt.rotateRight(x.parent)
  247. x = ivt.root
  248. }
  249. }
  250. }
  251. if x != nil {
  252. x.c = black
  253. }
  254. }
  255. // Insert adds a node with the given interval into the tree.
  256. func (ivt *IntervalTree) Insert(ivl Interval, val interface{}) {
  257. var y *intervalNode
  258. z := &intervalNode{iv: IntervalValue{ivl, val}, max: ivl.End, c: red}
  259. x := ivt.root
  260. for x != nil {
  261. y = x
  262. if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 {
  263. x = x.left
  264. } else {
  265. x = x.right
  266. }
  267. }
  268. z.parent = y
  269. if y == nil {
  270. ivt.root = z
  271. } else {
  272. if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 {
  273. y.left = z
  274. } else {
  275. y.right = z
  276. }
  277. y.updateMax()
  278. }
  279. z.c = red
  280. ivt.insertFixup(z)
  281. ivt.count++
  282. }
  283. func (ivt *IntervalTree) insertFixup(z *intervalNode) {
  284. for z.parent != nil && z.parent.parent != nil && z.parent.color() == red {
  285. if z.parent == z.parent.parent.left {
  286. y := z.parent.parent.right
  287. if y.color() == red {
  288. y.c = black
  289. z.parent.c = black
  290. z.parent.parent.c = red
  291. z = z.parent.parent
  292. } else {
  293. if z == z.parent.right {
  294. z = z.parent
  295. ivt.rotateLeft(z)
  296. }
  297. z.parent.c = black
  298. z.parent.parent.c = red
  299. ivt.rotateRight(z.parent.parent)
  300. }
  301. } else {
  302. // same as then with left/right exchanged
  303. y := z.parent.parent.left
  304. if y.color() == red {
  305. y.c = black
  306. z.parent.c = black
  307. z.parent.parent.c = red
  308. z = z.parent.parent
  309. } else {
  310. if z == z.parent.left {
  311. z = z.parent
  312. ivt.rotateRight(z)
  313. }
  314. z.parent.c = black
  315. z.parent.parent.c = red
  316. ivt.rotateLeft(z.parent.parent)
  317. }
  318. }
  319. }
  320. ivt.root.c = black
  321. }
  322. // rotateLeft moves x so it is left of its right child
  323. func (ivt *IntervalTree) rotateLeft(x *intervalNode) {
  324. y := x.right
  325. x.right = y.left
  326. if y.left != nil {
  327. y.left.parent = x
  328. }
  329. x.updateMax()
  330. ivt.replaceParent(x, y)
  331. y.left = x
  332. y.updateMax()
  333. }
  334. // rotateLeft moves x so it is right of its left child
  335. func (ivt *IntervalTree) rotateRight(x *intervalNode) {
  336. if x == nil {
  337. return
  338. }
  339. y := x.left
  340. x.left = y.right
  341. if y.right != nil {
  342. y.right.parent = x
  343. }
  344. x.updateMax()
  345. ivt.replaceParent(x, y)
  346. y.right = x
  347. y.updateMax()
  348. }
  349. // replaceParent replaces x's parent with y
  350. func (ivt *IntervalTree) replaceParent(x *intervalNode, y *intervalNode) {
  351. y.parent = x.parent
  352. if x.parent == nil {
  353. ivt.root = y
  354. } else {
  355. if x == x.parent.left {
  356. x.parent.left = y
  357. } else {
  358. x.parent.right = y
  359. }
  360. x.parent.updateMax()
  361. }
  362. x.parent = y
  363. }
  364. // Len gives the number of elements in the tree
  365. func (ivt *IntervalTree) Len() int { return ivt.count }
  366. // Height is the number of levels in the tree; one node has height 1.
  367. func (ivt *IntervalTree) Height() int { return ivt.root.height() }
  368. // MaxHeight is the expected maximum tree height given the number of nodes
  369. func (ivt *IntervalTree) MaxHeight() int {
  370. return int((2 * math.Log2(float64(ivt.Len()+1))) + 0.5)
  371. }
  372. // IntervalVisitor is used on tree searches; return false to stop searching.
  373. type IntervalVisitor func(n *IntervalValue) bool
  374. // Visit calls a visitor function on every tree node intersecting the given interval.
  375. // It will visit each interval [x, y) in ascending order sorted on x.
  376. func (ivt *IntervalTree) Visit(ivl Interval, ivv IntervalVisitor) {
  377. ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) })
  378. }
  379. // find the exact node for a given interval
  380. func (ivt *IntervalTree) find(ivl Interval) (ret *intervalNode) {
  381. f := func(n *intervalNode) bool {
  382. if n.iv.Ivl != ivl {
  383. return true
  384. }
  385. ret = n
  386. return false
  387. }
  388. ivt.root.visit(&ivl, f)
  389. return ret
  390. }
  391. // Find gets the IntervalValue for the node matching the given interval
  392. func (ivt *IntervalTree) Find(ivl Interval) (ret *IntervalValue) {
  393. n := ivt.find(ivl)
  394. if n == nil {
  395. return nil
  396. }
  397. return &n.iv
  398. }
  399. // Intersects returns true if there is some tree node intersecting the given interval.
  400. func (ivt *IntervalTree) Intersects(iv Interval) bool {
  401. x := ivt.root
  402. for x != nil && iv.Compare(&x.iv.Ivl) != 0 {
  403. if x.left != nil && x.left.max.Compare(iv.Begin) > 0 {
  404. x = x.left
  405. } else {
  406. x = x.right
  407. }
  408. }
  409. return x != nil
  410. }
  411. // Contains returns true if the interval tree's keys cover the entire given interval.
  412. func (ivt *IntervalTree) Contains(ivl Interval) bool {
  413. var maxEnd, minBegin Comparable
  414. isContiguous := true
  415. ivt.Visit(ivl, func(n *IntervalValue) bool {
  416. if minBegin == nil {
  417. minBegin = n.Ivl.Begin
  418. maxEnd = n.Ivl.End
  419. return true
  420. }
  421. if maxEnd.Compare(n.Ivl.Begin) < 0 {
  422. isContiguous = false
  423. return false
  424. }
  425. if n.Ivl.End.Compare(maxEnd) > 0 {
  426. maxEnd = n.Ivl.End
  427. }
  428. return true
  429. })
  430. return isContiguous && minBegin != nil && maxEnd.Compare(ivl.End) >= 0 && minBegin.Compare(ivl.Begin) <= 0
  431. }
  432. // Stab returns a slice with all elements in the tree intersecting the interval.
  433. func (ivt *IntervalTree) Stab(iv Interval) (ivs []*IntervalValue) {
  434. if ivt.count == 0 {
  435. return nil
  436. }
  437. f := func(n *IntervalValue) bool { ivs = append(ivs, n); return true }
  438. ivt.Visit(iv, f)
  439. return ivs
  440. }
  441. type StringComparable string
  442. func (s StringComparable) Compare(c Comparable) int {
  443. sc := c.(StringComparable)
  444. if s < sc {
  445. return -1
  446. }
  447. if s > sc {
  448. return 1
  449. }
  450. return 0
  451. }
  452. func NewStringInterval(begin, end string) Interval {
  453. return Interval{StringComparable(begin), StringComparable(end)}
  454. }
  455. func NewStringPoint(s string) Interval {
  456. return Interval{StringComparable(s), StringComparable(s + "\x00")}
  457. }
  458. // StringAffineComparable treats "" as > all other strings
  459. type StringAffineComparable string
  460. func (s StringAffineComparable) Compare(c Comparable) int {
  461. sc := c.(StringAffineComparable)
  462. if len(s) == 0 {
  463. if len(sc) == 0 {
  464. return 0
  465. }
  466. return 1
  467. }
  468. if len(sc) == 0 {
  469. return -1
  470. }
  471. if s < sc {
  472. return -1
  473. }
  474. if s > sc {
  475. return 1
  476. }
  477. return 0
  478. }
  479. func NewStringAffineInterval(begin, end string) Interval {
  480. return Interval{StringAffineComparable(begin), StringAffineComparable(end)}
  481. }
  482. func NewStringAffinePoint(s string) Interval {
  483. return NewStringAffineInterval(s, s+"\x00")
  484. }
  485. func NewInt64Interval(a int64, b int64) Interval {
  486. return Interval{Int64Comparable(a), Int64Comparable(b)}
  487. }
  488. func NewInt64Point(a int64) Interval {
  489. return Interval{Int64Comparable(a), Int64Comparable(a + 1)}
  490. }
  491. type Int64Comparable int64
  492. func (v Int64Comparable) Compare(c Comparable) int {
  493. vc := c.(Int64Comparable)
  494. cmp := v - vc
  495. if cmp < 0 {
  496. return -1
  497. }
  498. if cmp > 0 {
  499. return 1
  500. }
  501. return 0
  502. }