trie.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. package utilities
  2. import (
  3. "sort"
  4. )
  5. // DoubleArray is a Double Array implementation of trie on sequences of strings.
  6. type DoubleArray struct {
  7. // Encoding keeps an encoding from string to int
  8. Encoding map[string]int
  9. // Base is the base array of Double Array
  10. Base []int
  11. // Check is the check array of Double Array
  12. Check []int
  13. }
  14. // NewDoubleArray builds a DoubleArray from a set of sequences of strings.
  15. func NewDoubleArray(seqs [][]string) *DoubleArray {
  16. da := &DoubleArray{Encoding: make(map[string]int)}
  17. if len(seqs) == 0 {
  18. return da
  19. }
  20. encoded := registerTokens(da, seqs)
  21. sort.Sort(byLex(encoded))
  22. root := node{row: -1, col: -1, left: 0, right: len(encoded)}
  23. addSeqs(da, encoded, 0, root)
  24. for i := len(da.Base); i > 0; i-- {
  25. if da.Check[i-1] != 0 {
  26. da.Base = da.Base[:i]
  27. da.Check = da.Check[:i]
  28. break
  29. }
  30. }
  31. return da
  32. }
  33. func registerTokens(da *DoubleArray, seqs [][]string) [][]int {
  34. var result [][]int
  35. for _, seq := range seqs {
  36. var encoded []int
  37. for _, token := range seq {
  38. if _, ok := da.Encoding[token]; !ok {
  39. da.Encoding[token] = len(da.Encoding)
  40. }
  41. encoded = append(encoded, da.Encoding[token])
  42. }
  43. result = append(result, encoded)
  44. }
  45. for i := range result {
  46. result[i] = append(result[i], len(da.Encoding))
  47. }
  48. return result
  49. }
  50. type node struct {
  51. row, col int
  52. left, right int
  53. }
  54. func (n node) value(seqs [][]int) int {
  55. return seqs[n.row][n.col]
  56. }
  57. func (n node) children(seqs [][]int) []*node {
  58. var result []*node
  59. lastVal := int(-1)
  60. last := new(node)
  61. for i := n.left; i < n.right; i++ {
  62. if lastVal == seqs[i][n.col+1] {
  63. continue
  64. }
  65. last.right = i
  66. last = &node{
  67. row: i,
  68. col: n.col + 1,
  69. left: i,
  70. }
  71. result = append(result, last)
  72. }
  73. last.right = n.right
  74. return result
  75. }
  76. func addSeqs(da *DoubleArray, seqs [][]int, pos int, n node) {
  77. ensureSize(da, pos)
  78. children := n.children(seqs)
  79. var i int
  80. for i = 1; ; i++ {
  81. ok := func() bool {
  82. for _, child := range children {
  83. code := child.value(seqs)
  84. j := i + code
  85. ensureSize(da, j)
  86. if da.Check[j] != 0 {
  87. return false
  88. }
  89. }
  90. return true
  91. }()
  92. if ok {
  93. break
  94. }
  95. }
  96. da.Base[pos] = i
  97. for _, child := range children {
  98. code := child.value(seqs)
  99. j := i + code
  100. da.Check[j] = pos + 1
  101. }
  102. terminator := len(da.Encoding)
  103. for _, child := range children {
  104. code := child.value(seqs)
  105. if code == terminator {
  106. continue
  107. }
  108. j := i + code
  109. addSeqs(da, seqs, j, *child)
  110. }
  111. }
  112. func ensureSize(da *DoubleArray, i int) {
  113. for i >= len(da.Base) {
  114. da.Base = append(da.Base, make([]int, len(da.Base)+1)...)
  115. da.Check = append(da.Check, make([]int, len(da.Check)+1)...)
  116. }
  117. }
  118. type byLex [][]int
  119. func (l byLex) Len() int { return len(l) }
  120. func (l byLex) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
  121. func (l byLex) Less(i, j int) bool {
  122. si := l[i]
  123. sj := l[j]
  124. var k int
  125. for k = 0; k < len(si) && k < len(sj); k++ {
  126. if si[k] < sj[k] {
  127. return true
  128. }
  129. if si[k] > sj[k] {
  130. return false
  131. }
  132. }
  133. if k < len(sj) {
  134. return true
  135. }
  136. return false
  137. }
  138. // HasCommonPrefix determines if any sequence in the DoubleArray is a prefix of the given sequence.
  139. func (da *DoubleArray) HasCommonPrefix(seq []string) bool {
  140. if len(da.Base) == 0 {
  141. return false
  142. }
  143. var i int
  144. for _, t := range seq {
  145. code, ok := da.Encoding[t]
  146. if !ok {
  147. break
  148. }
  149. j := da.Base[i] + code
  150. if len(da.Check) <= j || da.Check[j] != i+1 {
  151. break
  152. }
  153. i = j
  154. }
  155. j := da.Base[i] + len(da.Encoding)
  156. if len(da.Check) <= j || da.Check[j] != i+1 {
  157. return false
  158. }
  159. return true
  160. }