compress.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651
  1. package huff0
  2. import (
  3. "fmt"
  4. "runtime"
  5. "sync"
  6. )
  7. // Compress1X will compress the input.
  8. // The output can be decoded using Decompress1X.
  9. // Supply a Scratch object. The scratch object contains state about re-use,
  10. // So when sharing across independent encodes, be sure to set the re-use policy.
  11. func Compress1X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
  12. s, err = s.prepare(in)
  13. if err != nil {
  14. return nil, false, err
  15. }
  16. return compress(in, s, s.compress1X)
  17. }
  18. // Compress4X will compress the input. The input is split into 4 independent blocks
  19. // and compressed similar to Compress1X.
  20. // The output can be decoded using Decompress4X.
  21. // Supply a Scratch object. The scratch object contains state about re-use,
  22. // So when sharing across independent encodes, be sure to set the re-use policy.
  23. func Compress4X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
  24. s, err = s.prepare(in)
  25. if err != nil {
  26. return nil, false, err
  27. }
  28. if false {
  29. // TODO: compress4Xp only slightly faster.
  30. const parallelThreshold = 8 << 10
  31. if len(in) < parallelThreshold || runtime.GOMAXPROCS(0) == 1 {
  32. return compress(in, s, s.compress4X)
  33. }
  34. return compress(in, s, s.compress4Xp)
  35. }
  36. return compress(in, s, s.compress4X)
  37. }
  38. func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error)) (out []byte, reUsed bool, err error) {
  39. // Nuke previous table if we cannot reuse anyway.
  40. if s.Reuse == ReusePolicyNone {
  41. s.prevTable = s.prevTable[:0]
  42. }
  43. // Create histogram, if none was provided.
  44. maxCount := s.maxCount
  45. var canReuse = false
  46. if maxCount == 0 {
  47. maxCount, canReuse = s.countSimple(in)
  48. } else {
  49. canReuse = s.canUseTable(s.prevTable)
  50. }
  51. // We want the output size to be less than this:
  52. wantSize := len(in)
  53. if s.WantLogLess > 0 {
  54. wantSize -= wantSize >> s.WantLogLess
  55. }
  56. // Reset for next run.
  57. s.clearCount = true
  58. s.maxCount = 0
  59. if maxCount >= len(in) {
  60. if maxCount > len(in) {
  61. return nil, false, fmt.Errorf("maxCount (%d) > length (%d)", maxCount, len(in))
  62. }
  63. if len(in) == 1 {
  64. return nil, false, ErrIncompressible
  65. }
  66. // One symbol, use RLE
  67. return nil, false, ErrUseRLE
  68. }
  69. if maxCount == 1 || maxCount < (len(in)>>7) {
  70. // Each symbol present maximum once or too well distributed.
  71. return nil, false, ErrIncompressible
  72. }
  73. if s.Reuse == ReusePolicyPrefer && canReuse {
  74. keepTable := s.cTable
  75. keepTL := s.actualTableLog
  76. s.cTable = s.prevTable
  77. s.actualTableLog = s.prevTableLog
  78. s.Out, err = compressor(in)
  79. s.cTable = keepTable
  80. s.actualTableLog = keepTL
  81. if err == nil && len(s.Out) < wantSize {
  82. s.OutData = s.Out
  83. return s.Out, true, nil
  84. }
  85. // Do not attempt to re-use later.
  86. s.prevTable = s.prevTable[:0]
  87. }
  88. // Calculate new table.
  89. err = s.buildCTable()
  90. if err != nil {
  91. return nil, false, err
  92. }
  93. if false && !s.canUseTable(s.cTable) {
  94. panic("invalid table generated")
  95. }
  96. if s.Reuse == ReusePolicyAllow && canReuse {
  97. hSize := len(s.Out)
  98. oldSize := s.prevTable.estimateSize(s.count[:s.symbolLen])
  99. newSize := s.cTable.estimateSize(s.count[:s.symbolLen])
  100. if oldSize <= hSize+newSize || hSize+12 >= wantSize {
  101. // Retain cTable even if we re-use.
  102. keepTable := s.cTable
  103. keepTL := s.actualTableLog
  104. s.cTable = s.prevTable
  105. s.actualTableLog = s.prevTableLog
  106. s.Out, err = compressor(in)
  107. // Restore ctable.
  108. s.cTable = keepTable
  109. s.actualTableLog = keepTL
  110. if err != nil {
  111. return nil, false, err
  112. }
  113. if len(s.Out) >= wantSize {
  114. return nil, false, ErrIncompressible
  115. }
  116. s.OutData = s.Out
  117. return s.Out, true, nil
  118. }
  119. }
  120. // Use new table
  121. err = s.cTable.write(s)
  122. if err != nil {
  123. s.OutTable = nil
  124. return nil, false, err
  125. }
  126. s.OutTable = s.Out
  127. // Compress using new table
  128. s.Out, err = compressor(in)
  129. if err != nil {
  130. s.OutTable = nil
  131. return nil, false, err
  132. }
  133. if len(s.Out) >= wantSize {
  134. s.OutTable = nil
  135. return nil, false, ErrIncompressible
  136. }
  137. // Move current table into previous.
  138. s.prevTable, s.prevTableLog, s.cTable = s.cTable, s.actualTableLog, s.prevTable[:0]
  139. s.OutData = s.Out[len(s.OutTable):]
  140. return s.Out, false, nil
  141. }
  142. func (s *Scratch) compress1X(src []byte) ([]byte, error) {
  143. return s.compress1xDo(s.Out, src)
  144. }
  145. func (s *Scratch) compress1xDo(dst, src []byte) ([]byte, error) {
  146. var bw = bitWriter{out: dst}
  147. // N is length divisible by 4.
  148. n := len(src)
  149. n -= n & 3
  150. cTable := s.cTable[:256]
  151. // Encode last bytes.
  152. for i := len(src) & 3; i > 0; i-- {
  153. bw.encSymbol(cTable, src[n+i-1])
  154. }
  155. n -= 4
  156. if s.actualTableLog <= 8 {
  157. for ; n >= 0; n -= 4 {
  158. tmp := src[n : n+4]
  159. // tmp should be len 4
  160. bw.flush32()
  161. bw.encTwoSymbols(cTable, tmp[3], tmp[2])
  162. bw.encTwoSymbols(cTable, tmp[1], tmp[0])
  163. }
  164. } else {
  165. for ; n >= 0; n -= 4 {
  166. tmp := src[n : n+4]
  167. // tmp should be len 4
  168. bw.flush32()
  169. bw.encTwoSymbols(cTable, tmp[3], tmp[2])
  170. bw.flush32()
  171. bw.encTwoSymbols(cTable, tmp[1], tmp[0])
  172. }
  173. }
  174. err := bw.close()
  175. return bw.out, err
  176. }
  177. var sixZeros [6]byte
  178. func (s *Scratch) compress4X(src []byte) ([]byte, error) {
  179. if len(src) < 12 {
  180. return nil, ErrIncompressible
  181. }
  182. segmentSize := (len(src) + 3) / 4
  183. // Add placeholder for output length
  184. offsetIdx := len(s.Out)
  185. s.Out = append(s.Out, sixZeros[:]...)
  186. for i := 0; i < 4; i++ {
  187. toDo := src
  188. if len(toDo) > segmentSize {
  189. toDo = toDo[:segmentSize]
  190. }
  191. src = src[len(toDo):]
  192. var err error
  193. idx := len(s.Out)
  194. s.Out, err = s.compress1xDo(s.Out, toDo)
  195. if err != nil {
  196. return nil, err
  197. }
  198. // Write compressed length as little endian before block.
  199. if i < 3 {
  200. // Last length is not written.
  201. length := len(s.Out) - idx
  202. s.Out[i*2+offsetIdx] = byte(length)
  203. s.Out[i*2+offsetIdx+1] = byte(length >> 8)
  204. }
  205. }
  206. return s.Out, nil
  207. }
  208. // compress4Xp will compress 4 streams using separate goroutines.
  209. func (s *Scratch) compress4Xp(src []byte) ([]byte, error) {
  210. if len(src) < 12 {
  211. return nil, ErrIncompressible
  212. }
  213. // Add placeholder for output length
  214. s.Out = s.Out[:6]
  215. segmentSize := (len(src) + 3) / 4
  216. var wg sync.WaitGroup
  217. var errs [4]error
  218. wg.Add(4)
  219. for i := 0; i < 4; i++ {
  220. toDo := src
  221. if len(toDo) > segmentSize {
  222. toDo = toDo[:segmentSize]
  223. }
  224. src = src[len(toDo):]
  225. // Separate goroutine for each block.
  226. go func(i int) {
  227. s.tmpOut[i], errs[i] = s.compress1xDo(s.tmpOut[i][:0], toDo)
  228. wg.Done()
  229. }(i)
  230. }
  231. wg.Wait()
  232. for i := 0; i < 4; i++ {
  233. if errs[i] != nil {
  234. return nil, errs[i]
  235. }
  236. o := s.tmpOut[i]
  237. // Write compressed length as little endian before block.
  238. if i < 3 {
  239. // Last length is not written.
  240. s.Out[i*2] = byte(len(o))
  241. s.Out[i*2+1] = byte(len(o) >> 8)
  242. }
  243. // Write output.
  244. s.Out = append(s.Out, o...)
  245. }
  246. return s.Out, nil
  247. }
  248. // countSimple will create a simple histogram in s.count.
  249. // Returns the biggest count.
  250. // Does not update s.clearCount.
  251. func (s *Scratch) countSimple(in []byte) (max int, reuse bool) {
  252. reuse = true
  253. for _, v := range in {
  254. s.count[v]++
  255. }
  256. m := uint32(0)
  257. if len(s.prevTable) > 0 {
  258. for i, v := range s.count[:] {
  259. if v > m {
  260. m = v
  261. }
  262. if v > 0 {
  263. s.symbolLen = uint16(i) + 1
  264. if i >= len(s.prevTable) {
  265. reuse = false
  266. } else {
  267. if s.prevTable[i].nBits == 0 {
  268. reuse = false
  269. }
  270. }
  271. }
  272. }
  273. return int(m), reuse
  274. }
  275. for i, v := range s.count[:] {
  276. if v > m {
  277. m = v
  278. }
  279. if v > 0 {
  280. s.symbolLen = uint16(i) + 1
  281. }
  282. }
  283. return int(m), false
  284. }
  285. func (s *Scratch) canUseTable(c cTable) bool {
  286. if len(c) < int(s.symbolLen) {
  287. return false
  288. }
  289. for i, v := range s.count[:s.symbolLen] {
  290. if v != 0 && c[i].nBits == 0 {
  291. return false
  292. }
  293. }
  294. return true
  295. }
  296. func (s *Scratch) validateTable(c cTable) bool {
  297. if len(c) < int(s.symbolLen) {
  298. return false
  299. }
  300. for i, v := range s.count[:s.symbolLen] {
  301. if v != 0 {
  302. if c[i].nBits == 0 {
  303. return false
  304. }
  305. if c[i].nBits > s.actualTableLog {
  306. return false
  307. }
  308. }
  309. }
  310. return true
  311. }
  312. // minTableLog provides the minimum logSize to safely represent a distribution.
  313. func (s *Scratch) minTableLog() uint8 {
  314. minBitsSrc := highBit32(uint32(s.br.remain())) + 1
  315. minBitsSymbols := highBit32(uint32(s.symbolLen-1)) + 2
  316. if minBitsSrc < minBitsSymbols {
  317. return uint8(minBitsSrc)
  318. }
  319. return uint8(minBitsSymbols)
  320. }
  321. // optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
  322. func (s *Scratch) optimalTableLog() {
  323. tableLog := s.TableLog
  324. minBits := s.minTableLog()
  325. maxBitsSrc := uint8(highBit32(uint32(s.br.remain()-1))) - 1
  326. if maxBitsSrc < tableLog {
  327. // Accuracy can be reduced
  328. tableLog = maxBitsSrc
  329. }
  330. if minBits > tableLog {
  331. tableLog = minBits
  332. }
  333. // Need a minimum to safely represent all symbol values
  334. if tableLog < minTablelog {
  335. tableLog = minTablelog
  336. }
  337. if tableLog > tableLogMax {
  338. tableLog = tableLogMax
  339. }
  340. s.actualTableLog = tableLog
  341. }
  // cTableEntry is the Huffman code of one symbol.
type cTableEntry struct {
	val   uint16 // code value, assigned per code length by buildCTable
	nBits uint8  // code length in bits; 0 means the symbol has no code
	// The struct occupies 4 bytes, so one byte of padding is spare.
}
  347. const huffNodesMask = huffNodesLen - 1
  // buildCTable builds the Huffman code for the histogram in s.count[:s.symbolLen]
  // and stores it in s.cTable. The table has one entry per symbol, length
  // s.symbolLen, and capacity of at least maxSymbolValue+1.
  //
  // Steps:
  //   1. Choose the table log (optimalTableLog).
  //   2. Sort symbols by decreasing count (huffSort).
  //   3. Build the tree in s.nodes and derive code lengths from tree depth.
  //   4. Limit the lengths with setMaxHeight, which also sets s.actualTableLog.
  //   5. Assign code values rank by rank.
  //
  // It returns an error only if the limited code length exceeds tableLogMax.
  // NOTE(review): assumes at least two symbols have a non-zero count, since it
  // reads huffNode[lowS-1]. compress rejects single-symbol input before calling
  // this — confirm for any other caller.
  348. func (s *Scratch) buildCTable() error {
  349. s.optimalTableLog()
  350. s.huffSort()
  // Reuse the existing table storage when it is large enough; clear the entries.
  351. if cap(s.cTable) < maxSymbolValue+1 {
  352. s.cTable = make([]cTableEntry, s.symbolLen, maxSymbolValue+1)
  353. } else {
  354. s.cTable = s.cTable[:s.symbolLen]
  355. for i := range s.cTable {
  356. s.cTable[i] = cTableEntry{}
  357. }
  358. }
  // Leaves occupy huffNode[0:symbolLen], sorted by decreasing count.
  // Internal nodes are appended from index startNode onwards.
  359. var startNode = int16(s.symbolLen)
  360. nonNullRank := s.symbolLen - 1
  361. nodeNb := int16(startNode)
  362. huffNode := s.nodes[1 : huffNodesLen+1]
  363. // This overlays the slice above, but allows "-1" index lookups.
  364. // Different from reference implementation.
  365. huffNode0 := s.nodes[0 : huffNodesLen+1]
  // Skip trailing zero-count symbols. nonNullRank ends at the smallest non-zero count.
  366. for huffNode[nonNullRank].count == 0 {
  367. nonNullRank--
  368. }
  369. lowS := int16(nonNullRank)
  370. nodeRoot := nodeNb + lowS - 1
  371. lowN := nodeNb
  // The two smallest leaves form the first internal node.
  372. huffNode[nodeNb].count = huffNode[lowS].count + huffNode[lowS-1].count
  373. huffNode[lowS].parent, huffNode[lowS-1].parent = uint16(nodeNb), uint16(nodeNb)
  374. nodeNb++
  375. lowS -= 2
  // Placeholder counts for internal nodes not created yet, so they are never
  // chosen before a real node.
  376. for n := nodeNb; n <= nodeRoot; n++ {
  377. huffNode[n].count = 1 << 30
  378. }
  379. // fake entry, strong barrier
  // huffNode0[0] is huffNode[-1]. Once all leaves are used, lowS reaches -1
  // and this count (1<<31) loses every comparison.
  380. huffNode0[0].count = 1 << 31
  381. // create parents
  // Each step merges the two smallest available nodes. Candidates are:
  //   - the smallest remaining leaf (lowS, walking down), or
  //   - the oldest unmerged internal node (lowN, walking up).
  // Indexes are shifted by one into huffNode0 so that lowS == -1 is valid.
  382. for nodeNb <= nodeRoot {
  383. var n1, n2 int16
  384. if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
  385. n1 = lowS
  386. lowS--
  387. } else {
  388. n1 = lowN
  389. lowN++
  390. }
  391. if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
  392. n2 = lowS
  393. lowS--
  394. } else {
  395. n2 = lowN
  396. lowN++
  397. }
  398. huffNode[nodeNb].count = huffNode0[n1+1].count + huffNode0[n2+1].count
  399. huffNode0[n1+1].parent, huffNode0[n2+1].parent = uint16(nodeNb), uint16(nodeNb)
  400. nodeNb++
  401. }
  402. // distribute weights (unlimited tree height)
  // A node's code length is its depth: root = 0, child = parent + 1.
  // Parents always have higher indexes, so walking downwards sees each parent first.
  403. huffNode[nodeRoot].nbBits = 0
  404. for n := nodeRoot - 1; n >= startNode; n-- {
  405. huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
  406. }
  407. for n := uint16(0); n <= nonNullRank; n++ {
  408. huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
  409. }
  // Limit code lengths to s.actualTableLog. The return value is the longest
  // length actually used.
  410. s.actualTableLog = s.setMaxHeight(int(nonNullRank))
  411. maxNbBits := s.actualTableLog
  412. // fill result into tree (val, nbBits)
  413. if maxNbBits > tableLogMax {
  414. return fmt.Errorf("internal error: maxNbBits (%d) > tableLogMax (%d)", maxNbBits, tableLogMax)
  415. }
  // Count how many symbols use each code length.
  416. var nbPerRank [tableLogMax + 1]uint16
  417. var valPerRank [16]uint16
  418. for _, v := range huffNode[:nonNullRank+1] {
  419. nbPerRank[v.nbBits]++
  420. }
  421. // determine starting value per rank
  // Walk from the longest length to the shortest. Each length's first code
  // value is derived from the codes handed out to the longer lengths.
  422. {
  423. min := uint16(0)
  424. for n := maxNbBits; n > 0; n-- {
  425. // get starting value within each rank
  426. valPerRank[n] = min
  427. min += nbPerRank[n]
  428. min >>= 1
  429. }
  430. }
  431. // push nbBits per symbol, symbol order
  432. for _, v := range huffNode[:nonNullRank+1] {
  433. s.cTable[v.symbol].nBits = v.nbBits
  434. }
  435. // assign value within rank, symbol order
  // Symbols with the same length get consecutive values in symbol order.
  // Zero-length (absent) symbols use slot 0, which is never read for encoding.
  436. t := s.cTable[:s.symbolLen]
  437. for n, val := range t {
  438. nbits := val.nBits & 15
  439. v := valPerRank[nbits]
  440. t[n].val = v
  441. valPerRank[nbits] = v + 1
  442. }
  443. return nil
  444. }
  445. // huffSort will sort symbols, decreasing order.
  446. func (s *Scratch) huffSort() {
  447. type rankPos struct {
  448. base uint32
  449. current uint32
  450. }
  451. // Clear nodes
  452. nodes := s.nodes[:huffNodesLen+1]
  453. s.nodes = nodes
  454. nodes = nodes[1 : huffNodesLen+1]
  455. // Sort into buckets based on length of symbol count.
  456. var rank [32]rankPos
  457. for _, v := range s.count[:s.symbolLen] {
  458. r := highBit32(v+1) & 31
  459. rank[r].base++
  460. }
  461. // maxBitLength is log2(BlockSizeMax) + 1
  462. const maxBitLength = 18 + 1
  463. for n := maxBitLength; n > 0; n-- {
  464. rank[n-1].base += rank[n].base
  465. }
  466. for n := range rank[:maxBitLength] {
  467. rank[n].current = rank[n].base
  468. }
  469. for n, c := range s.count[:s.symbolLen] {
  470. r := (highBit32(c+1) + 1) & 31
  471. pos := rank[r].current
  472. rank[r].current++
  473. prev := nodes[(pos-1)&huffNodesMask]
  474. for pos > rank[r].base && c > prev.count {
  475. nodes[pos&huffNodesMask] = prev
  476. pos--
  477. prev = nodes[(pos-1)&huffNodesMask]
  478. }
  479. nodes[pos&huffNodesMask] = nodeElt{count: c, symbol: byte(n)}
  480. }
  481. return
  482. }
  // setMaxHeight limits the code lengths (nbBits) of the nodes in s.nodes[1:]
  // to s.actualTableLog and returns the resulting maximum code length.
  //
  // The nodes are expected in decreasing count order (as huffSort leaves them).
  // lastNonNull is the index of the last node with a non-zero count; its nbBits
  // is taken as the longest length.
  //
  // If no code is longer than the limit, that longest length is returned as-is.
  // Otherwise:
  //   1. Every too-long code is cut to the limit. This over-spends the code
  //      space, and the excess is tracked in totalCost.
  //   2. The excess is repaid by lengthening codes of low-count symbols.
  //   3. Any overshoot from step 2 is given back.
  // The structure (see the HUF_MAX_TABLELOG comment) appears to follow zstd's
  // HUF_setMaxHeight. NOTE(review): confirm against the reference before changing it.
  483. func (s *Scratch) setMaxHeight(lastNonNull int) uint8 {
  484. maxNbBits := s.actualTableLog
  485. huffNode := s.nodes[1 : huffNodesLen+1]
  486. //huffNode = huffNode[: huffNodesLen]
  487. largestBits := huffNode[lastNonNull].nbBits
  488. // early exit : no elt > maxNbBits
  489. if largestBits <= maxNbBits {
  490. return largestBits
  491. }
  // totalCost is counted in units of 2^-largestBits. Cutting a code from
  // nbBits down to maxNbBits adds 2^(largestBits-maxNbBits) - 2^(largestBits-nbBits).
  492. totalCost := int(0)
  493. baseCost := int(1) << (largestBits - maxNbBits)
  494. n := uint32(lastNonNull)
  495. for huffNode[n].nbBits > maxNbBits {
  496. totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits))
  497. huffNode[n].nbBits = maxNbBits
  498. n--
  499. }
  500. // n stops at huffNode[n].nbBits <= maxNbBits
  501. for huffNode[n].nbBits == maxNbBits {
  502. n--
  503. }
  504. // n end at index of smallest symbol using < maxNbBits
  505. // renorm totalCost
  // From here on totalCost is counted in units of 2^-maxNbBits.
  506. totalCost >>= largestBits - maxNbBits /* note : totalCost is necessarily a multiple of baseCost */
  507. // repay normalized cost
  508. {
  509. const noSymbol = 0xF0F0F0F0
  // rankLast[k] is the index of the last (lowest-count) node whose code length
  // is maxNbBits-k, or noSymbol if there is none.
  510. var rankLast [tableLogMax + 2]uint32
  511. for i := range rankLast[:] {
  512. rankLast[i] = noSymbol
  513. }
  514. // Get pos of last (smallest) symbol per rank
  515. {
  516. currentNbBits := uint8(maxNbBits)
  517. for pos := int(n); pos >= 0; pos-- {
  518. if huffNode[pos].nbBits >= currentNbBits {
  519. continue
  520. }
  521. currentNbBits = huffNode[pos].nbBits // < maxNbBits
  522. rankLast[maxNbBits-currentNbBits] = uint32(pos)
  523. }
  524. }
  // Lengthening one code of rank k by one bit repays 2^(k-1) units.
  // The starting k is highBit32(totalCost)+1. Assuming highBit32 returns the
  // index of the highest set bit, that is the largest rank whose repayment
  // does not exceed totalCost.
  // The loop then steps down to rank k-1 while lengthening one rank-k symbol
  // would cost more output bits (its count) than two rank-(k-1) symbols.
  525. for totalCost > 0 {
  526. nBitsToDecrease := uint8(highBit32(uint32(totalCost))) + 1
  527. for ; nBitsToDecrease > 1; nBitsToDecrease-- {
  528. highPos := rankLast[nBitsToDecrease]
  529. lowPos := rankLast[nBitsToDecrease-1]
  530. if highPos == noSymbol {
  531. continue
  532. }
  533. if lowPos == noSymbol {
  534. break
  535. }
  536. highTotal := huffNode[highPos].count
  537. lowTotal := 2 * huffNode[lowPos].count
  538. if highTotal <= lowTotal {
  539. break
  540. }
  541. }
  542. // only triggered when no more rank 1 symbol left => find closest one (note : there is necessarily at least one !)
  543. // HUF_MAX_TABLELOG test just to please gcc 5+; but it should not be necessary
  544. // FIXME: try to remove
  545. for (nBitsToDecrease <= tableLogMax) && (rankLast[nBitsToDecrease] == noSymbol) {
  546. nBitsToDecrease++
  547. }
  // Lengthen the chosen symbol. It moves from rank k to rank k-1, so update
  // rankLast for both ranks.
  548. totalCost -= 1 << (nBitsToDecrease - 1)
  549. if rankLast[nBitsToDecrease-1] == noSymbol {
  550. // this rank is no longer empty
  551. rankLast[nBitsToDecrease-1] = rankLast[nBitsToDecrease]
  552. }
  553. huffNode[rankLast[nBitsToDecrease]].nbBits++
  554. if rankLast[nBitsToDecrease] == 0 {
  555. /* special case, reached largest symbol */
  556. rankLast[nBitsToDecrease] = noSymbol
  557. } else {
  558. rankLast[nBitsToDecrease]--
  559. if huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits-nBitsToDecrease {
  560. rankLast[nBitsToDecrease] = noSymbol /* this rank is now empty */
  561. }
  562. }
  563. }
  // Each overshoot step shortens one code from maxNbBits to maxNbBits-1
  // (rank 0 -> rank 1), which gives back one unit.
  564. for totalCost < 0 { /* Sometimes, cost correction overshoot */
  565. if rankLast[1] == noSymbol { /* special case : no rank 1 symbol (using maxNbBits-1); let's create one from largest rank 0 (using maxNbBits) */
  566. for huffNode[n].nbBits == maxNbBits {
  567. n--
  568. }
  569. huffNode[n+1].nbBits--
  570. rankLast[1] = n + 1
  571. totalCost++
  572. continue
  573. }
  574. huffNode[rankLast[1]+1].nbBits--
  575. rankLast[1]++
  576. totalCost++
  577. }
  578. }
  579. return maxNbBits
  580. }
  // nodeElt is one node of the Huffman tree built by buildCTable. It is either
// a leaf (written by huffSort) or an internal node.
type nodeElt struct {
	count  uint32 // frequency; for internal nodes, the sum of both children
	parent uint16 // index of the parent node within s.nodes[1:]
	symbol byte   // symbol value; only set for leaves
	nbBits uint8  // code length in bits (tree depth, possibly limited by setMaxHeight)
}