// decompress.go
  1. package huff0
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "github.com/klauspost/compress/fse"
  7. )
  8. type dTable struct {
  9. single []dEntrySingle
  10. double []dEntryDouble
  11. }
  12. // single-symbols decoding
  13. type dEntrySingle struct {
  14. entry uint16
  15. }
  16. // double-symbols decoding
  17. type dEntryDouble struct {
  18. seq uint16
  19. nBits uint8
  20. len uint8
  21. }
  22. // ReadTable will read a table from the input.
  23. // The size of the input may be larger than the table definition.
  24. // Any content remaining after the table definition will be returned.
  25. // If no Scratch is provided a new one is allocated.
  26. // The returned Scratch can be used for decoding input using this table.
  27. func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) {
  28. s, err = s.prepare(in)
  29. if err != nil {
  30. return s, nil, err
  31. }
  32. if len(in) <= 1 {
  33. return s, nil, errors.New("input too small for table")
  34. }
  35. iSize := in[0]
  36. in = in[1:]
  37. if iSize >= 128 {
  38. // Uncompressed
  39. oSize := iSize - 127
  40. iSize = (oSize + 1) / 2
  41. if int(iSize) > len(in) {
  42. return s, nil, errors.New("input too small for table")
  43. }
  44. for n := uint8(0); n < oSize; n += 2 {
  45. v := in[n/2]
  46. s.huffWeight[n] = v >> 4
  47. s.huffWeight[n+1] = v & 15
  48. }
  49. s.symbolLen = uint16(oSize)
  50. in = in[iSize:]
  51. } else {
  52. if len(in) <= int(iSize) {
  53. return s, nil, errors.New("input too small for table")
  54. }
  55. // FSE compressed weights
  56. s.fse.DecompressLimit = 255
  57. hw := s.huffWeight[:]
  58. s.fse.Out = hw
  59. b, err := fse.Decompress(in[:iSize], s.fse)
  60. s.fse.Out = nil
  61. if err != nil {
  62. return s, nil, err
  63. }
  64. if len(b) > 255 {
  65. return s, nil, errors.New("corrupt input: output table too large")
  66. }
  67. s.symbolLen = uint16(len(b))
  68. in = in[iSize:]
  69. }
  70. // collect weight stats
  71. var rankStats [16]uint32
  72. weightTotal := uint32(0)
  73. for _, v := range s.huffWeight[:s.symbolLen] {
  74. if v > tableLogMax {
  75. return s, nil, errors.New("corrupt input: weight too large")
  76. }
  77. v2 := v & 15
  78. rankStats[v2]++
  79. weightTotal += (1 << v2) >> 1
  80. }
  81. if weightTotal == 0 {
  82. return s, nil, errors.New("corrupt input: weights zero")
  83. }
  84. // get last non-null symbol weight (implied, total must be 2^n)
  85. {
  86. tableLog := highBit32(weightTotal) + 1
  87. if tableLog > tableLogMax {
  88. return s, nil, errors.New("corrupt input: tableLog too big")
  89. }
  90. s.actualTableLog = uint8(tableLog)
  91. // determine last weight
  92. {
  93. total := uint32(1) << tableLog
  94. rest := total - weightTotal
  95. verif := uint32(1) << highBit32(rest)
  96. lastWeight := highBit32(rest) + 1
  97. if verif != rest {
  98. // last value must be a clean power of 2
  99. return s, nil, errors.New("corrupt input: last value not power of two")
  100. }
  101. s.huffWeight[s.symbolLen] = uint8(lastWeight)
  102. s.symbolLen++
  103. rankStats[lastWeight]++
  104. }
  105. }
  106. if (rankStats[1] < 2) || (rankStats[1]&1 != 0) {
  107. // by construction : at least 2 elts of rank 1, must be even
  108. return s, nil, errors.New("corrupt input: min elt size, even check failed ")
  109. }
  110. // TODO: Choose between single/double symbol decoding
  111. // Calculate starting value for each rank
  112. {
  113. var nextRankStart uint32
  114. for n := uint8(1); n < s.actualTableLog+1; n++ {
  115. current := nextRankStart
  116. nextRankStart += rankStats[n] << (n - 1)
  117. rankStats[n] = current
  118. }
  119. }
  120. // fill DTable (always full size)
  121. tSize := 1 << tableLogMax
  122. if len(s.dt.single) != tSize {
  123. s.dt.single = make([]dEntrySingle, tSize)
  124. }
  125. for n, w := range s.huffWeight[:s.symbolLen] {
  126. if w == 0 {
  127. continue
  128. }
  129. length := (uint32(1) << w) >> 1
  130. d := dEntrySingle{
  131. entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8),
  132. }
  133. single := s.dt.single[rankStats[w] : rankStats[w]+length]
  134. for i := range single {
  135. single[i] = d
  136. }
  137. rankStats[w] += length
  138. }
  139. return s, in, nil
  140. }
  141. // Decompress1X will decompress a 1X encoded stream.
  142. // The length of the supplied input must match the end of a block exactly.
  143. // Before this is called, the table must be initialized with ReadTable unless
  144. // the encoder re-used the table.
  145. func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) {
  146. if len(s.dt.single) == 0 {
  147. return nil, errors.New("no table loaded")
  148. }
  149. var br bitReader
  150. err = br.init(in)
  151. if err != nil {
  152. return nil, err
  153. }
  154. s.Out = s.Out[:0]
  155. decode := func() byte {
  156. val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */
  157. v := s.dt.single[val]
  158. br.bitsRead += uint8(v.entry)
  159. return uint8(v.entry >> 8)
  160. }
  161. hasDec := func(v dEntrySingle) byte {
  162. br.bitsRead += uint8(v.entry)
  163. return uint8(v.entry >> 8)
  164. }
  165. // Avoid bounds check by always having full sized table.
  166. const tlSize = 1 << tableLogMax
  167. const tlMask = tlSize - 1
  168. dt := s.dt.single[:tlSize]
  169. // Use temp table to avoid bound checks/append penalty.
  170. var tmp = s.huffWeight[:256]
  171. var off uint8
  172. for br.off >= 8 {
  173. br.fillFast()
  174. tmp[off+0] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
  175. tmp[off+1] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
  176. br.fillFast()
  177. tmp[off+2] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
  178. tmp[off+3] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
  179. off += 4
  180. if off == 0 {
  181. if len(s.Out)+256 > s.MaxDecodedSize {
  182. br.close()
  183. return nil, ErrMaxDecodedSizeExceeded
  184. }
  185. s.Out = append(s.Out, tmp...)
  186. }
  187. }
  188. if len(s.Out)+int(off) > s.MaxDecodedSize {
  189. br.close()
  190. return nil, ErrMaxDecodedSizeExceeded
  191. }
  192. s.Out = append(s.Out, tmp[:off]...)
  193. for !br.finished() {
  194. br.fill()
  195. if len(s.Out) >= s.MaxDecodedSize {
  196. br.close()
  197. return nil, ErrMaxDecodedSizeExceeded
  198. }
  199. s.Out = append(s.Out, decode())
  200. }
  201. return s.Out, br.close()
  202. }
  203. // Decompress4X will decompress a 4X encoded stream.
  204. // Before this is called, the table must be initialized with ReadTable unless
  205. // the encoder re-used the table.
  206. // The length of the supplied input must match the end of a block exactly.
  207. // The destination size of the uncompressed data must be known and provided.
  208. func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) {
  209. if len(s.dt.single) == 0 {
  210. return nil, errors.New("no table loaded")
  211. }
  212. if len(in) < 6+(4*1) {
  213. return nil, errors.New("input too small")
  214. }
  215. if dstSize > s.MaxDecodedSize {
  216. return nil, ErrMaxDecodedSizeExceeded
  217. }
  218. // TODO: We do not detect when we overrun a buffer, except if the last one does.
  219. var br [4]bitReader
  220. start := 6
  221. for i := 0; i < 3; i++ {
  222. length := int(in[i*2]) | (int(in[i*2+1]) << 8)
  223. if start+length >= len(in) {
  224. return nil, errors.New("truncated input (or invalid offset)")
  225. }
  226. err = br[i].init(in[start : start+length])
  227. if err != nil {
  228. return nil, err
  229. }
  230. start += length
  231. }
  232. err = br[3].init(in[start:])
  233. if err != nil {
  234. return nil, err
  235. }
  236. // Prepare output
  237. if cap(s.Out) < dstSize {
  238. s.Out = make([]byte, 0, dstSize)
  239. }
  240. s.Out = s.Out[:dstSize]
  241. // destination, offset to match first output
  242. dstOut := s.Out
  243. dstEvery := (dstSize + 3) / 4
  244. const tlSize = 1 << tableLogMax
  245. const tlMask = tlSize - 1
  246. single := s.dt.single[:tlSize]
  247. decode := func(br *bitReader) byte {
  248. val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */
  249. v := single[val&tlMask]
  250. br.bitsRead += uint8(v.entry)
  251. return uint8(v.entry >> 8)
  252. }
  253. // Use temp table to avoid bound checks/append penalty.
  254. var tmp = s.huffWeight[:256]
  255. var off uint8
  256. var decoded int
  257. // Decode 2 values from each decoder/loop.
  258. const bufoff = 256 / 4
  259. bigloop:
  260. for {
  261. for i := range br {
  262. br := &br[i]
  263. if br.off < 4 {
  264. break bigloop
  265. }
  266. br.fillFast()
  267. }
  268. {
  269. const stream = 0
  270. val := br[stream].peekBitsFast(s.actualTableLog)
  271. v := single[val&tlMask]
  272. br[stream].bitsRead += uint8(v.entry)
  273. val2 := br[stream].peekBitsFast(s.actualTableLog)
  274. v2 := single[val2&tlMask]
  275. tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
  276. tmp[off+bufoff*stream] = uint8(v.entry >> 8)
  277. br[stream].bitsRead += uint8(v2.entry)
  278. }
  279. {
  280. const stream = 1
  281. val := br[stream].peekBitsFast(s.actualTableLog)
  282. v := single[val&tlMask]
  283. br[stream].bitsRead += uint8(v.entry)
  284. val2 := br[stream].peekBitsFast(s.actualTableLog)
  285. v2 := single[val2&tlMask]
  286. tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
  287. tmp[off+bufoff*stream] = uint8(v.entry >> 8)
  288. br[stream].bitsRead += uint8(v2.entry)
  289. }
  290. {
  291. const stream = 2
  292. val := br[stream].peekBitsFast(s.actualTableLog)
  293. v := single[val&tlMask]
  294. br[stream].bitsRead += uint8(v.entry)
  295. val2 := br[stream].peekBitsFast(s.actualTableLog)
  296. v2 := single[val2&tlMask]
  297. tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
  298. tmp[off+bufoff*stream] = uint8(v.entry >> 8)
  299. br[stream].bitsRead += uint8(v2.entry)
  300. }
  301. {
  302. const stream = 3
  303. val := br[stream].peekBitsFast(s.actualTableLog)
  304. v := single[val&tlMask]
  305. br[stream].bitsRead += uint8(v.entry)
  306. val2 := br[stream].peekBitsFast(s.actualTableLog)
  307. v2 := single[val2&tlMask]
  308. tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
  309. tmp[off+bufoff*stream] = uint8(v.entry >> 8)
  310. br[stream].bitsRead += uint8(v2.entry)
  311. }
  312. off += 2
  313. if off == bufoff {
  314. if bufoff > dstEvery {
  315. return nil, errors.New("corruption detected: stream overrun 1")
  316. }
  317. copy(dstOut, tmp[:bufoff])
  318. copy(dstOut[dstEvery:], tmp[bufoff:bufoff*2])
  319. copy(dstOut[dstEvery*2:], tmp[bufoff*2:bufoff*3])
  320. copy(dstOut[dstEvery*3:], tmp[bufoff*3:bufoff*4])
  321. off = 0
  322. dstOut = dstOut[bufoff:]
  323. decoded += 256
  324. // There must at least be 3 buffers left.
  325. if len(dstOut) < dstEvery*3 {
  326. return nil, errors.New("corruption detected: stream overrun 2")
  327. }
  328. }
  329. }
  330. if off > 0 {
  331. ioff := int(off)
  332. if len(dstOut) < dstEvery*3+ioff {
  333. return nil, errors.New("corruption detected: stream overrun 3")
  334. }
  335. copy(dstOut, tmp[:off])
  336. copy(dstOut[dstEvery:dstEvery+ioff], tmp[bufoff:bufoff*2])
  337. copy(dstOut[dstEvery*2:dstEvery*2+ioff], tmp[bufoff*2:bufoff*3])
  338. copy(dstOut[dstEvery*3:dstEvery*3+ioff], tmp[bufoff*3:bufoff*4])
  339. decoded += int(off) * 4
  340. dstOut = dstOut[off:]
  341. }
  342. // Decode remaining.
  343. for i := range br {
  344. offset := dstEvery * i
  345. br := &br[i]
  346. for !br.finished() {
  347. br.fill()
  348. if offset >= len(dstOut) {
  349. return nil, errors.New("corruption detected: stream overrun 4")
  350. }
  351. dstOut[offset] = decode(br)
  352. offset++
  353. }
  354. decoded += offset - dstEvery*i
  355. err = br.close()
  356. if err != nil {
  357. return nil, err
  358. }
  359. }
  360. if dstSize != decoded {
  361. return nil, errors.New("corruption detected: short output block")
  362. }
  363. return s.Out, nil
  364. }
  365. // matches will compare a decoding table to a coding table.
  366. // Errors are written to the writer.
  367. // Nothing will be written if table is ok.
  368. func (s *Scratch) matches(ct cTable, w io.Writer) {
  369. if s == nil || len(s.dt.single) == 0 {
  370. return
  371. }
  372. dt := s.dt.single[:1<<s.actualTableLog]
  373. tablelog := s.actualTableLog
  374. ok := 0
  375. broken := 0
  376. for sym, enc := range ct {
  377. errs := 0
  378. broken++
  379. if enc.nBits == 0 {
  380. for _, dec := range dt {
  381. if uint8(dec.entry>>8) == byte(sym) {
  382. fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym)
  383. errs++
  384. break
  385. }
  386. }
  387. if errs == 0 {
  388. broken--
  389. }
  390. continue
  391. }
  392. // Unused bits in input
  393. ub := tablelog - enc.nBits
  394. top := enc.val << ub
  395. // decoder looks at top bits.
  396. dec := dt[top]
  397. if uint8(dec.entry) != enc.nBits {
  398. fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, uint8(dec.entry))
  399. errs++
  400. }
  401. if uint8(dec.entry>>8) != uint8(sym) {
  402. fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, uint8(dec.entry>>8))
  403. errs++
  404. }
  405. if errs > 0 {
  406. fmt.Fprintf(w, "%d errros in base, stopping\n", errs)
  407. continue
  408. }
  409. // Ensure that all combinations are covered.
  410. for i := uint16(0); i < (1 << ub); i++ {
  411. vval := top | i
  412. dec := dt[vval]
  413. if uint8(dec.entry) != enc.nBits {
  414. fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, uint8(dec.entry))
  415. errs++
  416. }
  417. if uint8(dec.entry>>8) != uint8(sym) {
  418. fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, uint8(dec.entry>>8))
  419. errs++
  420. }
  421. if errs > 20 {
  422. fmt.Fprintf(w, "%d errros, stopping\n", errs)
  423. break
  424. }
  425. }
  426. if errs == 0 {
  427. ok++
  428. broken--
  429. }
  430. }
  431. if broken > 0 {
  432. fmt.Fprintf(w, "%d broken, %d ok\n", broken, ok)
  433. }
  434. }