fse_encoder.go 19 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726
  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "errors"
  7. "fmt"
  8. "math"
  9. )
  10. const (
  11. // For encoding we only support up to
  12. maxEncTableLog = 8
  13. maxEncTablesize = 1 << maxTableLog
  14. maxEncTableMask = (1 << maxTableLog) - 1
  15. minEncTablelog = 5
  16. maxEncSymbolValue = maxMatchLengthSymbol
  17. )
  18. // Scratch provides temporary storage for compression and decompression.
  19. type fseEncoder struct {
  20. symbolLen uint16 // Length of active part of the symbol table.
  21. actualTableLog uint8 // Selected tablelog.
  22. ct cTable // Compression tables.
  23. maxCount int // count of the most probable symbol
  24. zeroBits bool // no bits has prob > 50%.
  25. clearCount bool // clear count
  26. useRLE bool // This encoder is for RLE
  27. preDefined bool // This encoder is predefined.
  28. reUsed bool // Set to know when the encoder has been reused.
  29. rleVal uint8 // RLE Symbol
  30. maxBits uint8 // Maximum output bits after transform.
  31. // TODO: Technically zstd should be fine with 64 bytes.
  32. count [256]uint32
  33. norm [256]int16
  34. }
  35. // cTable contains tables used for compression.
  36. type cTable struct {
  37. tableSymbol []byte
  38. stateTable []uint16
  39. symbolTT []symbolTransform
  40. }
  41. // symbolTransform contains the state transform for a symbol.
  42. type symbolTransform struct {
  43. deltaNbBits uint32
  44. deltaFindState int16
  45. outBits uint8
  46. }
  47. // String prints values as a human readable string.
  48. func (s symbolTransform) String() string {
  49. return fmt.Sprintf("{deltabits: %08x, findstate:%d outbits:%d}", s.deltaNbBits, s.deltaFindState, s.outBits)
  50. }
  51. // Histogram allows to populate the histogram and skip that step in the compression,
  52. // It otherwise allows to inspect the histogram when compression is done.
  53. // To indicate that you have populated the histogram call HistogramFinished
  54. // with the value of the highest populated symbol, as well as the number of entries
  55. // in the most populated entry. These are accepted at face value.
  56. // The returned slice will always be length 256.
  57. func (s *fseEncoder) Histogram() []uint32 {
  58. return s.count[:]
  59. }
  60. // HistogramFinished can be called to indicate that the histogram has been populated.
  61. // maxSymbol is the index of the highest set symbol of the next data segment.
  62. // maxCount is the number of entries in the most populated entry.
  63. // These are accepted at face value.
  64. func (s *fseEncoder) HistogramFinished(maxSymbol uint8, maxCount int) {
  65. s.maxCount = maxCount
  66. s.symbolLen = uint16(maxSymbol) + 1
  67. s.clearCount = maxCount != 0
  68. }
  69. // prepare will prepare and allocate scratch tables used for both compression and decompression.
  70. func (s *fseEncoder) prepare() (*fseEncoder, error) {
  71. if s == nil {
  72. s = &fseEncoder{}
  73. }
  74. s.useRLE = false
  75. if s.clearCount && s.maxCount == 0 {
  76. for i := range s.count {
  77. s.count[i] = 0
  78. }
  79. s.clearCount = false
  80. }
  81. return s, nil
  82. }
  83. // allocCtable will allocate tables needed for compression.
  84. // If existing tables a re big enough, they are simply re-used.
  85. func (s *fseEncoder) allocCtable() {
  86. tableSize := 1 << s.actualTableLog
  87. // get tableSymbol that is big enough.
  88. if cap(s.ct.tableSymbol) < int(tableSize) {
  89. s.ct.tableSymbol = make([]byte, tableSize)
  90. }
  91. s.ct.tableSymbol = s.ct.tableSymbol[:tableSize]
  92. ctSize := tableSize
  93. if cap(s.ct.stateTable) < ctSize {
  94. s.ct.stateTable = make([]uint16, ctSize)
  95. }
  96. s.ct.stateTable = s.ct.stateTable[:ctSize]
  97. if cap(s.ct.symbolTT) < 256 {
  98. s.ct.symbolTT = make([]symbolTransform, 256)
  99. }
  100. s.ct.symbolTT = s.ct.symbolTT[:256]
  101. }
  102. // buildCTable will populate the compression table so it is ready to be used.
  103. func (s *fseEncoder) buildCTable() error {
  104. tableSize := uint32(1 << s.actualTableLog)
  105. highThreshold := tableSize - 1
  106. var cumul [256]int16
  107. s.allocCtable()
  108. tableSymbol := s.ct.tableSymbol[:tableSize]
  109. // symbol start positions
  110. {
  111. cumul[0] = 0
  112. for ui, v := range s.norm[:s.symbolLen-1] {
  113. u := byte(ui) // one less than reference
  114. if v == -1 {
  115. // Low proba symbol
  116. cumul[u+1] = cumul[u] + 1
  117. tableSymbol[highThreshold] = u
  118. highThreshold--
  119. } else {
  120. cumul[u+1] = cumul[u] + v
  121. }
  122. }
  123. // Encode last symbol separately to avoid overflowing u
  124. u := int(s.symbolLen - 1)
  125. v := s.norm[s.symbolLen-1]
  126. if v == -1 {
  127. // Low proba symbol
  128. cumul[u+1] = cumul[u] + 1
  129. tableSymbol[highThreshold] = byte(u)
  130. highThreshold--
  131. } else {
  132. cumul[u+1] = cumul[u] + v
  133. }
  134. if uint32(cumul[s.symbolLen]) != tableSize {
  135. return fmt.Errorf("internal error: expected cumul[s.symbolLen] (%d) == tableSize (%d)", cumul[s.symbolLen], tableSize)
  136. }
  137. cumul[s.symbolLen] = int16(tableSize) + 1
  138. }
  139. // Spread symbols
  140. s.zeroBits = false
  141. {
  142. step := tableStep(tableSize)
  143. tableMask := tableSize - 1
  144. var position uint32
  145. // if any symbol > largeLimit, we may have 0 bits output.
  146. largeLimit := int16(1 << (s.actualTableLog - 1))
  147. for ui, v := range s.norm[:s.symbolLen] {
  148. symbol := byte(ui)
  149. if v > largeLimit {
  150. s.zeroBits = true
  151. }
  152. for nbOccurrences := int16(0); nbOccurrences < v; nbOccurrences++ {
  153. tableSymbol[position] = symbol
  154. position = (position + step) & tableMask
  155. for position > highThreshold {
  156. position = (position + step) & tableMask
  157. } /* Low proba area */
  158. }
  159. }
  160. // Check if we have gone through all positions
  161. if position != 0 {
  162. return errors.New("position!=0")
  163. }
  164. }
  165. // Build table
  166. table := s.ct.stateTable
  167. {
  168. tsi := int(tableSize)
  169. for u, v := range tableSymbol {
  170. // TableU16 : sorted by symbol order; gives next state value
  171. table[cumul[v]] = uint16(tsi + u)
  172. cumul[v]++
  173. }
  174. }
  175. // Build Symbol Transformation Table
  176. {
  177. total := int16(0)
  178. symbolTT := s.ct.symbolTT[:s.symbolLen]
  179. tableLog := s.actualTableLog
  180. tl := (uint32(tableLog) << 16) - (1 << tableLog)
  181. for i, v := range s.norm[:s.symbolLen] {
  182. switch v {
  183. case 0:
  184. case -1, 1:
  185. symbolTT[i].deltaNbBits = tl
  186. symbolTT[i].deltaFindState = int16(total - 1)
  187. total++
  188. default:
  189. maxBitsOut := uint32(tableLog) - highBit(uint32(v-1))
  190. minStatePlus := uint32(v) << maxBitsOut
  191. symbolTT[i].deltaNbBits = (maxBitsOut << 16) - minStatePlus
  192. symbolTT[i].deltaFindState = int16(total - v)
  193. total += v
  194. }
  195. }
  196. if total != int16(tableSize) {
  197. return fmt.Errorf("total mismatch %d (got) != %d (want)", total, tableSize)
  198. }
  199. }
  200. return nil
  201. }
  202. var rtbTable = [...]uint32{0, 473195, 504333, 520860, 550000, 700000, 750000, 830000}
  203. func (s *fseEncoder) setRLE(val byte) {
  204. s.allocCtable()
  205. s.actualTableLog = 0
  206. s.ct.stateTable = s.ct.stateTable[:1]
  207. s.ct.symbolTT[val] = symbolTransform{
  208. deltaFindState: 0,
  209. deltaNbBits: 0,
  210. }
  211. if debug {
  212. println("setRLE: val", val, "symbolTT", s.ct.symbolTT[val])
  213. }
  214. s.rleVal = val
  215. s.useRLE = true
  216. }
  217. // setBits will set output bits for the transform.
  218. // if nil is provided, the number of bits is equal to the index.
  219. func (s *fseEncoder) setBits(transform []byte) {
  220. if s.reUsed || s.preDefined {
  221. return
  222. }
  223. if s.useRLE {
  224. if transform == nil {
  225. s.ct.symbolTT[s.rleVal].outBits = s.rleVal
  226. s.maxBits = s.rleVal
  227. return
  228. }
  229. s.maxBits = transform[s.rleVal]
  230. s.ct.symbolTT[s.rleVal].outBits = s.maxBits
  231. return
  232. }
  233. if transform == nil {
  234. for i := range s.ct.symbolTT[:s.symbolLen] {
  235. s.ct.symbolTT[i].outBits = uint8(i)
  236. }
  237. s.maxBits = uint8(s.symbolLen - 1)
  238. return
  239. }
  240. s.maxBits = 0
  241. for i, v := range transform[:s.symbolLen] {
  242. s.ct.symbolTT[i].outBits = v
  243. if v > s.maxBits {
  244. // We could assume bits always going up, but we play safe.
  245. s.maxBits = v
  246. }
  247. }
  248. }
  249. // normalizeCount will normalize the count of the symbols so
  250. // the total is equal to the table size.
  251. // If successful, compression tables will also be made ready.
  252. func (s *fseEncoder) normalizeCount(length int) error {
  253. if s.reUsed {
  254. return nil
  255. }
  256. s.optimalTableLog(length)
  257. var (
  258. tableLog = s.actualTableLog
  259. scale = 62 - uint64(tableLog)
  260. step = (1 << 62) / uint64(length)
  261. vStep = uint64(1) << (scale - 20)
  262. stillToDistribute = int16(1 << tableLog)
  263. largest int
  264. largestP int16
  265. lowThreshold = (uint32)(length >> tableLog)
  266. )
  267. if s.maxCount == length {
  268. s.useRLE = true
  269. return nil
  270. }
  271. s.useRLE = false
  272. for i, cnt := range s.count[:s.symbolLen] {
  273. // already handled
  274. // if (count[s] == s.length) return 0; /* rle special case */
  275. if cnt == 0 {
  276. s.norm[i] = 0
  277. continue
  278. }
  279. if cnt <= lowThreshold {
  280. s.norm[i] = -1
  281. stillToDistribute--
  282. } else {
  283. proba := (int16)((uint64(cnt) * step) >> scale)
  284. if proba < 8 {
  285. restToBeat := vStep * uint64(rtbTable[proba])
  286. v := uint64(cnt)*step - (uint64(proba) << scale)
  287. if v > restToBeat {
  288. proba++
  289. }
  290. }
  291. if proba > largestP {
  292. largestP = proba
  293. largest = i
  294. }
  295. s.norm[i] = proba
  296. stillToDistribute -= proba
  297. }
  298. }
  299. if -stillToDistribute >= (s.norm[largest] >> 1) {
  300. // corner case, need another normalization method
  301. err := s.normalizeCount2(length)
  302. if err != nil {
  303. return err
  304. }
  305. if debugAsserts {
  306. err = s.validateNorm()
  307. if err != nil {
  308. return err
  309. }
  310. }
  311. return s.buildCTable()
  312. }
  313. s.norm[largest] += stillToDistribute
  314. if debugAsserts {
  315. err := s.validateNorm()
  316. if err != nil {
  317. return err
  318. }
  319. }
  320. return s.buildCTable()
  321. }
  322. // Secondary normalization method.
  323. // To be used when primary method fails.
  324. func (s *fseEncoder) normalizeCount2(length int) error {
  325. const notYetAssigned = -2
  326. var (
  327. distributed uint32
  328. total = uint32(length)
  329. tableLog = s.actualTableLog
  330. lowThreshold = uint32(total >> tableLog)
  331. lowOne = uint32((total * 3) >> (tableLog + 1))
  332. )
  333. for i, cnt := range s.count[:s.symbolLen] {
  334. if cnt == 0 {
  335. s.norm[i] = 0
  336. continue
  337. }
  338. if cnt <= lowThreshold {
  339. s.norm[i] = -1
  340. distributed++
  341. total -= cnt
  342. continue
  343. }
  344. if cnt <= lowOne {
  345. s.norm[i] = 1
  346. distributed++
  347. total -= cnt
  348. continue
  349. }
  350. s.norm[i] = notYetAssigned
  351. }
  352. toDistribute := (1 << tableLog) - distributed
  353. if (total / toDistribute) > lowOne {
  354. // risk of rounding to zero
  355. lowOne = uint32((total * 3) / (toDistribute * 2))
  356. for i, cnt := range s.count[:s.symbolLen] {
  357. if (s.norm[i] == notYetAssigned) && (cnt <= lowOne) {
  358. s.norm[i] = 1
  359. distributed++
  360. total -= cnt
  361. continue
  362. }
  363. }
  364. toDistribute = (1 << tableLog) - distributed
  365. }
  366. if distributed == uint32(s.symbolLen)+1 {
  367. // all values are pretty poor;
  368. // probably incompressible data (should have already been detected);
  369. // find max, then give all remaining points to max
  370. var maxV int
  371. var maxC uint32
  372. for i, cnt := range s.count[:s.symbolLen] {
  373. if cnt > maxC {
  374. maxV = i
  375. maxC = cnt
  376. }
  377. }
  378. s.norm[maxV] += int16(toDistribute)
  379. return nil
  380. }
  381. if total == 0 {
  382. // all of the symbols were low enough for the lowOne or lowThreshold
  383. for i := uint32(0); toDistribute > 0; i = (i + 1) % (uint32(s.symbolLen)) {
  384. if s.norm[i] > 0 {
  385. toDistribute--
  386. s.norm[i]++
  387. }
  388. }
  389. return nil
  390. }
  391. var (
  392. vStepLog = 62 - uint64(tableLog)
  393. mid = uint64((1 << (vStepLog - 1)) - 1)
  394. rStep = (((1 << vStepLog) * uint64(toDistribute)) + mid) / uint64(total) // scale on remaining
  395. tmpTotal = mid
  396. )
  397. for i, cnt := range s.count[:s.symbolLen] {
  398. if s.norm[i] == notYetAssigned {
  399. var (
  400. end = tmpTotal + uint64(cnt)*rStep
  401. sStart = uint32(tmpTotal >> vStepLog)
  402. sEnd = uint32(end >> vStepLog)
  403. weight = sEnd - sStart
  404. )
  405. if weight < 1 {
  406. return errors.New("weight < 1")
  407. }
  408. s.norm[i] = int16(weight)
  409. tmpTotal = end
  410. }
  411. }
  412. return nil
  413. }
  414. // optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
  415. func (s *fseEncoder) optimalTableLog(length int) {
  416. tableLog := uint8(maxEncTableLog)
  417. minBitsSrc := highBit(uint32(length)) + 1
  418. minBitsSymbols := highBit(uint32(s.symbolLen-1)) + 2
  419. minBits := uint8(minBitsSymbols)
  420. if minBitsSrc < minBitsSymbols {
  421. minBits = uint8(minBitsSrc)
  422. }
  423. maxBitsSrc := uint8(highBit(uint32(length-1))) - 2
  424. if maxBitsSrc < tableLog {
  425. // Accuracy can be reduced
  426. tableLog = maxBitsSrc
  427. }
  428. if minBits > tableLog {
  429. tableLog = minBits
  430. }
  431. // Need a minimum to safely represent all symbol values
  432. if tableLog < minEncTablelog {
  433. tableLog = minEncTablelog
  434. }
  435. if tableLog > maxEncTableLog {
  436. tableLog = maxEncTableLog
  437. }
  438. s.actualTableLog = tableLog
  439. }
  440. // validateNorm validates the normalized histogram table.
  441. func (s *fseEncoder) validateNorm() (err error) {
  442. var total int
  443. for _, v := range s.norm[:s.symbolLen] {
  444. if v >= 0 {
  445. total += int(v)
  446. } else {
  447. total -= int(v)
  448. }
  449. }
  450. defer func() {
  451. if err == nil {
  452. return
  453. }
  454. fmt.Printf("selected TableLog: %d, Symbol length: %d\n", s.actualTableLog, s.symbolLen)
  455. for i, v := range s.norm[:s.symbolLen] {
  456. fmt.Printf("%3d: %5d -> %4d \n", i, s.count[i], v)
  457. }
  458. }()
  459. if total != (1 << s.actualTableLog) {
  460. return fmt.Errorf("warning: Total == %d != %d", total, 1<<s.actualTableLog)
  461. }
  462. for i, v := range s.count[s.symbolLen:] {
  463. if v != 0 {
  464. return fmt.Errorf("warning: Found symbol out of range, %d after cut", i)
  465. }
  466. }
  467. return nil
  468. }
// writeCount appends the normalized histogram (s.norm) to out as an FSE
// table description header and returns the extended slice.
// This is read back by readNCount.
//
// In RLE mode only the RLE symbol byte is appended. Predefined and reused
// encoders write nothing and return out unchanged. On an internal
// inconsistency a nil slice and an error are returned.
func (s *fseEncoder) writeCount(out []byte) ([]byte, error) {
	if s.useRLE {
		return append(out, s.rleVal), nil
	}
	if s.preDefined || s.reUsed {
		// Never write predefined.
		return out, nil
	}

	var (
		tableLog  = s.actualTableLog
		tableSize = 1 << tableLog
		previous0 bool   // previous symbol had a zero normalized count
		charnum   uint16 // next symbol to write

		// maximum header size plus 2 extra bytes for final output if bitCount == 0.
		maxHeaderSize = ((int(s.symbolLen) * int(tableLog)) >> 3) + 3 + 2

		// Write Table Size: the first 4 bits hold tableLog - minEncTablelog.
		bitStream = uint32(tableLog - minEncTablelog)
		bitCount  = uint(4)
		remaining = int16(tableSize + 1) /* +1 for extra accuracy */
		threshold = int16(tableSize)
		nbBits    = uint(tableLog + 1)
		outP      = len(out)
	)
	// Ensure capacity (adding headroom when growing), then extend out so the
	// header bytes can be written by index starting at outP.
	if cap(out) < outP+maxHeaderSize {
		out = append(out, make([]byte, maxHeaderSize*3)...)
		out = out[:len(out)-maxHeaderSize*3]
	}
	out = out[:outP+maxHeaderSize]

	// stops at 1
	for remaining > 1 {
		if previous0 {
			// Encode the length of the run of zero-count symbols that
			// starts at charnum, using 2-bit repeat codes.
			start := charnum
			for s.norm[charnum] == 0 {
				charnum++
			}
			// 0xFFFF is eight 2-bit codes of 3, i.e. 24 zero-count symbols;
			// it is flushed directly as two bytes.
			for charnum >= start+24 {
				start += 24
				bitStream += uint32(0xFFFF) << bitCount
				out[outP] = byte(bitStream)
				out[outP+1] = byte(bitStream >> 8)
				outP += 2
				bitStream >>= 16
			}
			// Each 2-bit code of 3 stands for 3 more zero-count symbols.
			for charnum >= start+3 {
				start += 3
				bitStream += 3 << bitCount
				bitCount += 2
			}
			// A final 2-bit code (0..2) gives the leftover run length.
			bitStream += uint32(charnum-start) << bitCount
			bitCount += 2
			if bitCount > 16 {
				out[outP] = byte(bitStream)
				out[outP+1] = byte(bitStream >> 8)
				outP += 2
				bitStream >>= 16
				bitCount -= 16
			}
		}
		count := s.norm[charnum]
		charnum++
		// Values below max are written with one bit fewer (see below).
		max := (2*threshold - 1) - remaining
		// A -1 (low probability) entry consumes one unit, the same as 1.
		if count < 0 {
			remaining += count
		} else {
			remaining -= count
		}
		count++ // +1 for extra accuracy
		if count >= threshold {
			count += max // [0..max[ [max..threshold[ (...) [threshold+max 2*threshold[
		}
		bitStream += uint32(count) << bitCount
		bitCount += nbBits
		if count < max {
			bitCount--
		}
		// count was shifted by +1, so 1 means this symbol's norm was 0.
		previous0 = count == 1
		if remaining < 1 {
			return nil, errors.New("internal error: remaining < 1")
		}
		// Narrow the field width as the remaining probability shrinks.
		for remaining < threshold {
			nbBits--
			threshold >>= 1
		}
		if bitCount > 16 {
			out[outP] = byte(bitStream)
			out[outP+1] = byte(bitStream >> 8)
			outP += 2
			bitStream >>= 16
			bitCount -= 16
		}
	}

	if outP+2 > len(out) {
		return nil, fmt.Errorf("internal error: %d > %d, maxheader: %d, sl: %d, tl: %d, normcount: %v", outP+2, len(out), maxHeaderSize, s.symbolLen, int(tableLog), s.norm[:s.symbolLen])
	}
	// Write the final partial bits; only the bytes actually used are kept.
	out[outP] = byte(bitStream)
	out[outP+1] = byte(bitStream >> 8)
	outP += int((bitCount + 7) / 8)

	if charnum > s.symbolLen {
		return nil, errors.New("internal error: charnum > s.symbolLen")
	}
	return out[:outP], nil
}
  573. // Approximate symbol cost, as fractional value, using fixed-point format (accuracyLog fractional bits)
  574. // note 1 : assume symbolValue is valid (<= maxSymbolValue)
  575. // note 2 : if freq[symbolValue]==0, @return a fake cost of tableLog+1 bits *
  576. func (s *fseEncoder) bitCost(symbolValue uint8, accuracyLog uint32) uint32 {
  577. minNbBits := s.ct.symbolTT[symbolValue].deltaNbBits >> 16
  578. threshold := (minNbBits + 1) << 16
  579. if debugAsserts {
  580. if !(s.actualTableLog < 16) {
  581. panic("!s.actualTableLog < 16")
  582. }
  583. // ensure enough room for renormalization double shift
  584. if !(uint8(accuracyLog) < 31-s.actualTableLog) {
  585. panic("!uint8(accuracyLog) < 31-s.actualTableLog")
  586. }
  587. }
  588. tableSize := uint32(1) << s.actualTableLog
  589. deltaFromThreshold := threshold - (s.ct.symbolTT[symbolValue].deltaNbBits + tableSize)
  590. // linear interpolation (very approximate)
  591. normalizedDeltaFromThreshold := (deltaFromThreshold << accuracyLog) >> s.actualTableLog
  592. bitMultiplier := uint32(1) << accuracyLog
  593. if debugAsserts {
  594. if s.ct.symbolTT[symbolValue].deltaNbBits+tableSize > threshold {
  595. panic("s.ct.symbolTT[symbolValue].deltaNbBits+tableSize > threshold")
  596. }
  597. if normalizedDeltaFromThreshold > bitMultiplier {
  598. panic("normalizedDeltaFromThreshold > bitMultiplier")
  599. }
  600. }
  601. return (minNbBits+1)*bitMultiplier - normalizedDeltaFromThreshold
  602. }
  603. // Returns the cost in bits of encoding the distribution in count using ctable.
  604. // Histogram should only be up to the last non-zero symbol.
  605. // Returns an -1 if ctable cannot represent all the symbols in count.
  606. func (s *fseEncoder) approxSize(hist []uint32) uint32 {
  607. if int(s.symbolLen) < len(hist) {
  608. // More symbols than we have.
  609. return math.MaxUint32
  610. }
  611. if s.useRLE {
  612. // We will never reuse RLE encoders.
  613. return math.MaxUint32
  614. }
  615. const kAccuracyLog = 8
  616. badCost := (uint32(s.actualTableLog) + 1) << kAccuracyLog
  617. var cost uint32
  618. for i, v := range hist {
  619. if v == 0 {
  620. continue
  621. }
  622. if s.norm[i] == 0 {
  623. return math.MaxUint32
  624. }
  625. bitCost := s.bitCost(uint8(i), kAccuracyLog)
  626. if bitCost > badCost {
  627. return math.MaxUint32
  628. }
  629. cost += v * bitCost
  630. }
  631. return cost >> kAccuracyLog
  632. }
  633. // maxHeaderSize returns the maximum header size in bits.
  634. // This is not exact size, but we want a penalty for new tables anyway.
  635. func (s *fseEncoder) maxHeaderSize() uint32 {
  636. if s.preDefined {
  637. return 0
  638. }
  639. if s.useRLE {
  640. return 8
  641. }
  642. return (((uint32(s.symbolLen) * uint32(s.actualTableLog)) >> 3) + 3) * 8
  643. }
  644. // cState contains the compression state of a stream.
  645. type cState struct {
  646. bw *bitWriter
  647. stateTable []uint16
  648. state uint16
  649. }
  650. // init will initialize the compression state to the first symbol of the stream.
  651. func (c *cState) init(bw *bitWriter, ct *cTable, first symbolTransform) {
  652. c.bw = bw
  653. c.stateTable = ct.stateTable
  654. if len(c.stateTable) == 1 {
  655. // RLE
  656. c.stateTable[0] = uint16(0)
  657. c.state = 0
  658. return
  659. }
  660. nbBitsOut := (first.deltaNbBits + (1 << 15)) >> 16
  661. im := int32((nbBitsOut << 16) - first.deltaNbBits)
  662. lu := (im >> nbBitsOut) + int32(first.deltaFindState)
  663. c.state = c.stateTable[lu]
  664. return
  665. }
  666. // encode the output symbol provided and write it to the bitstream.
  667. func (c *cState) encode(symbolTT symbolTransform) {
  668. nbBitsOut := (uint32(c.state) + symbolTT.deltaNbBits) >> 16
  669. dstState := int32(c.state>>(nbBitsOut&15)) + int32(symbolTT.deltaFindState)
  670. c.bw.addBits16NC(c.state, uint8(nbBitsOut))
  671. c.state = c.stateTable[dstState]
  672. }
  673. // flush will write the tablelog to the output and flush the remaining full bytes.
  674. func (c *cState) flush(tableLog uint8) {
  675. c.bw.flush32()
  676. c.bw.addBits16NC(c.state, tableLog)
  677. }