framedec.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "bytes"
  7. "encoding/hex"
  8. "errors"
  9. "hash"
  10. "io"
  11. "sync"
  12. "github.com/klauspost/compress/zstd/internal/xxhash"
  13. )
// frameDec holds the state needed to decode one zstd frame: the values parsed
// from the frame header, the input being read, the history shared between
// blocks and the bookkeeping for the asynchronous decode routine.
type frameDec struct {
	// o holds the decoder options given to newFrameDec.
	o decoderOptions
	// crc accumulates the content checksum. It is created lazily by reset
	// the first time a frame announces a checksum, and reset for every
	// such frame afterwards.
	crc hash.Hash64
	// offset counts decoded bytes; startDecoder only advances it when
	// debug output is enabled.
	offset int64

	// WindowSize is the window size of the current frame, as set by reset.
	WindowSize uint64

	// maxWindowSize is the maximum window size to support.
	// It should never be bigger than max-int.
	maxWindowSize uint64

	// In order queue of blocks being decoded.
	decoding chan *blockDec

	// Frame history passed between blocks
	history history

	// rawInput is the input the frame is read from; reset stores its
	// argument here once the header has been parsed.
	rawInput byteBuffer

	// Byte buffer that can be reused for small input blocks.
	bBuf byteBuf

	// FrameContentSize is the Frame_Content_Size read from the header.
	// It is 0 when the header carries no such field.
	FrameContentSize uint64
	// frameDone has Done called on it when startDecoder returns.
	frameDone sync.WaitGroup

	// DictionaryID is the Dictionary_ID read from the header (0 if absent).
	DictionaryID uint32
	// HasCheckSum reports whether the frame header announces a content checksum.
	HasCheckSum bool
	// SingleSegment reports whether the Single_Segment flag is set in the header.
	SingleSegment bool

	// asyncRunning indicates whether the async routine processes input on 'decoding'.
	// It is guarded by asyncRunningMu.
	asyncRunningMu sync.Mutex
	asyncRunning   bool
}
const (
	// MinWindowSize is the minimum Window_Size: 1 KB.
	// reset rejects frames with a smaller window with ErrWindowSizeTooSmall.
	MinWindowSize = 1 << 10
	// MaxWindowSize is the upper bound (1 << 29 bytes) that newFrameDec uses
	// for the largest window size a frame may request.
	MaxWindowSize = 1 << 29
)
var (
	// frameMagic is the 4-byte magic number every regular frame must start with.
	frameMagic = []byte{0x28, 0xb5, 0x2f, 0xfd}
	// skippableFrameMagic is compared against bytes 1..3 of the magic number
	// to detect a skippable frame; byte 0 is matched separately by reset,
	// which only requires its high nibble to be 0x5.
	skippableFrameMagic = []byte{0x2a, 0x4d, 0x18}
)
  47. func newFrameDec(o decoderOptions) *frameDec {
  48. d := frameDec{
  49. o: o,
  50. maxWindowSize: MaxWindowSize,
  51. }
  52. if d.maxWindowSize > o.maxDecodedSize {
  53. d.maxWindowSize = o.maxDecodedSize
  54. }
  55. return &d
  56. }
  57. // reset will read the frame header and prepare for block decoding.
  58. // If nothing can be read from the input, io.EOF will be returned.
  59. // Any other error indicated that the stream contained data, but
  60. // there was a problem.
  61. func (d *frameDec) reset(br byteBuffer) error {
  62. d.HasCheckSum = false
  63. d.WindowSize = 0
  64. var b []byte
  65. for {
  66. b = br.readSmall(4)
  67. if b == nil {
  68. return io.EOF
  69. }
  70. if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 {
  71. if debug {
  72. println("Not skippable", hex.EncodeToString(b), hex.EncodeToString(skippableFrameMagic))
  73. }
  74. // Break if not skippable frame.
  75. break
  76. }
  77. // Read size to skip
  78. b = br.readSmall(4)
  79. if b == nil {
  80. println("Reading Frame Size EOF")
  81. return io.ErrUnexpectedEOF
  82. }
  83. n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
  84. println("Skipping frame with", n, "bytes.")
  85. err := br.skipN(int(n))
  86. if err != nil {
  87. if debug {
  88. println("Reading discarded frame", err)
  89. }
  90. return err
  91. }
  92. }
  93. if !bytes.Equal(b, frameMagic) {
  94. println("Got magic numbers: ", b, "want:", frameMagic)
  95. return ErrMagicMismatch
  96. }
  97. // Read Frame_Header_Descriptor
  98. fhd, err := br.readByte()
  99. if err != nil {
  100. println("Reading Frame_Header_Descriptor", err)
  101. return err
  102. }
  103. d.SingleSegment = fhd&(1<<5) != 0
  104. if fhd&(1<<3) != 0 {
  105. return errors.New("Reserved bit set on frame header")
  106. }
  107. // Read Window_Descriptor
  108. // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
  109. d.WindowSize = 0
  110. if !d.SingleSegment {
  111. wd, err := br.readByte()
  112. if err != nil {
  113. println("Reading Window_Descriptor", err)
  114. return err
  115. }
  116. printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
  117. windowLog := 10 + (wd >> 3)
  118. windowBase := uint64(1) << windowLog
  119. windowAdd := (windowBase / 8) * uint64(wd&0x7)
  120. d.WindowSize = windowBase + windowAdd
  121. }
  122. // Read Dictionary_ID
  123. // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id
  124. d.DictionaryID = 0
  125. if size := fhd & 3; size != 0 {
  126. if size == 3 {
  127. size = 4
  128. }
  129. b = br.readSmall(int(size))
  130. if b == nil {
  131. if debug {
  132. println("Reading Dictionary_ID", io.ErrUnexpectedEOF)
  133. }
  134. return io.ErrUnexpectedEOF
  135. }
  136. switch size {
  137. case 1:
  138. d.DictionaryID = uint32(b[0])
  139. case 2:
  140. d.DictionaryID = uint32(b[0]) | (uint32(b[1]) << 8)
  141. case 4:
  142. d.DictionaryID = uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
  143. }
  144. if debug {
  145. println("Dict size", size, "ID:", d.DictionaryID)
  146. }
  147. if d.DictionaryID != 0 {
  148. return ErrUnknownDictionary
  149. }
  150. }
  151. // Read Frame_Content_Size
  152. // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_content_size
  153. var fcsSize int
  154. v := fhd >> 6
  155. switch v {
  156. case 0:
  157. if d.SingleSegment {
  158. fcsSize = 1
  159. }
  160. default:
  161. fcsSize = 1 << v
  162. }
  163. d.FrameContentSize = 0
  164. if fcsSize > 0 {
  165. b := br.readSmall(fcsSize)
  166. if b == nil {
  167. println("Reading Frame content", io.ErrUnexpectedEOF)
  168. return io.ErrUnexpectedEOF
  169. }
  170. switch fcsSize {
  171. case 1:
  172. d.FrameContentSize = uint64(b[0])
  173. case 2:
  174. // When FCS_Field_Size is 2, the offset of 256 is added.
  175. d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) + 256
  176. case 4:
  177. d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24)
  178. case 8:
  179. d1 := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
  180. d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24)
  181. d.FrameContentSize = uint64(d1) | (uint64(d2) << 32)
  182. }
  183. if debug {
  184. println("field size bits:", v, "fcsSize:", fcsSize, "FrameContentSize:", d.FrameContentSize, hex.EncodeToString(b[:fcsSize]), "singleseg:", d.SingleSegment, "window:", d.WindowSize)
  185. }
  186. }
  187. // Move this to shared.
  188. d.HasCheckSum = fhd&(1<<2) != 0
  189. if d.HasCheckSum {
  190. if d.crc == nil {
  191. d.crc = xxhash.New()
  192. }
  193. d.crc.Reset()
  194. }
  195. if d.WindowSize == 0 && d.SingleSegment {
  196. // We may not need window in this case.
  197. d.WindowSize = d.FrameContentSize
  198. if d.WindowSize < MinWindowSize {
  199. d.WindowSize = MinWindowSize
  200. }
  201. }
  202. if d.WindowSize > d.maxWindowSize {
  203. printf("window size %d > max %d\n", d.WindowSize, d.maxWindowSize)
  204. return ErrWindowSizeExceeded
  205. }
  206. // The minimum Window_Size is 1 KB.
  207. if d.WindowSize < MinWindowSize {
  208. println("got window size: ", d.WindowSize)
  209. return ErrWindowSizeTooSmall
  210. }
  211. d.history.windowSize = int(d.WindowSize)
  212. d.history.maxSize = d.history.windowSize + maxBlockSize
  213. // history contains input - maybe we do something
  214. d.rawInput = br
  215. return nil
  216. }
  217. // next will start decoding the next block from stream.
  218. func (d *frameDec) next(block *blockDec) error {
  219. if debug {
  220. printf("decoding new block %p:%p", block, block.data)
  221. }
  222. err := block.reset(d.rawInput, d.WindowSize)
  223. if err != nil {
  224. println("block error:", err)
  225. // Signal the frame decoder we have a problem.
  226. d.sendErr(block, err)
  227. return err
  228. }
  229. block.input <- struct{}{}
  230. if debug {
  231. println("next block:", block)
  232. }
  233. d.asyncRunningMu.Lock()
  234. defer d.asyncRunningMu.Unlock()
  235. if !d.asyncRunning {
  236. return nil
  237. }
  238. if block.Last {
  239. // We indicate the frame is done by sending io.EOF
  240. d.decoding <- block
  241. return io.EOF
  242. }
  243. d.decoding <- block
  244. return nil
  245. }
  246. // sendEOF will queue an error block on the frame.
  247. // This will cause the frame decoder to return when it encounters the block.
  248. // Returns true if the decoder was added.
  249. func (d *frameDec) sendErr(block *blockDec, err error) bool {
  250. d.asyncRunningMu.Lock()
  251. defer d.asyncRunningMu.Unlock()
  252. if !d.asyncRunning {
  253. return false
  254. }
  255. println("sending error", err.Error())
  256. block.sendErr(err)
  257. d.decoding <- block
  258. return true
  259. }
  260. // checkCRC will check the checksum if the frame has one.
  261. // Will return ErrCRCMismatch if crc check failed, otherwise nil.
  262. func (d *frameDec) checkCRC() error {
  263. if !d.HasCheckSum {
  264. return nil
  265. }
  266. var tmp [4]byte
  267. got := d.crc.Sum64()
  268. // Flip to match file order.
  269. tmp[0] = byte(got >> 0)
  270. tmp[1] = byte(got >> 8)
  271. tmp[2] = byte(got >> 16)
  272. tmp[3] = byte(got >> 24)
  273. // We can overwrite upper tmp now
  274. want := d.rawInput.readSmall(4)
  275. if want == nil {
  276. println("CRC missing?")
  277. return io.ErrUnexpectedEOF
  278. }
  279. if !bytes.Equal(tmp[:], want) {
  280. if debug {
  281. println("CRC Check Failed:", tmp[:], "!=", want)
  282. }
  283. return ErrCRCMismatch
  284. }
  285. if debug {
  286. println("CRC ok", tmp[:])
  287. }
  288. return nil
  289. }
  290. func (d *frameDec) initAsync() {
  291. if !d.o.lowMem && !d.SingleSegment {
  292. // set max extra size history to 20MB.
  293. d.history.maxSize = d.history.windowSize + maxBlockSize*10
  294. }
  295. // re-alloc if more than one extra block size.
  296. if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize {
  297. d.history.b = make([]byte, 0, d.history.maxSize)
  298. }
  299. if cap(d.history.b) < d.history.maxSize {
  300. d.history.b = make([]byte, 0, d.history.maxSize)
  301. }
  302. if cap(d.decoding) < d.o.concurrent {
  303. d.decoding = make(chan *blockDec, d.o.concurrent)
  304. }
  305. if debug {
  306. h := d.history
  307. printf("history init. len: %d, cap: %d", len(h.b), cap(h.b))
  308. }
  309. d.asyncRunningMu.Lock()
  310. d.asyncRunning = true
  311. d.asyncRunningMu.Unlock()
  312. }
  313. // startDecoder will start decoding blocks and write them to the writer.
  314. // The decoder will stop as soon as an error occurs or at end of frame.
  315. // When the frame has finished decoding the *bufio.Reader
  316. // containing the remaining input will be sent on frameDec.frameDone.
  317. func (d *frameDec) startDecoder(output chan decodeOutput) {
  318. // TODO: Init to dictionary
  319. d.history.reset()
  320. written := int64(0)
  321. defer func() {
  322. d.asyncRunningMu.Lock()
  323. d.asyncRunning = false
  324. d.asyncRunningMu.Unlock()
  325. // Drain the currently decoding.
  326. d.history.error = true
  327. flushdone:
  328. for {
  329. select {
  330. case b := <-d.decoding:
  331. b.history <- &d.history
  332. output <- <-b.result
  333. default:
  334. break flushdone
  335. }
  336. }
  337. println("frame decoder done, signalling done")
  338. d.frameDone.Done()
  339. }()
  340. // Get decoder for first block.
  341. block := <-d.decoding
  342. block.history <- &d.history
  343. for {
  344. var next *blockDec
  345. // Get result
  346. r := <-block.result
  347. if r.err != nil {
  348. println("Result contained error", r.err)
  349. output <- r
  350. return
  351. }
  352. if debug {
  353. println("got result, from ", d.offset, "to", d.offset+int64(len(r.b)))
  354. d.offset += int64(len(r.b))
  355. }
  356. if !block.Last {
  357. // Send history to next block
  358. select {
  359. case next = <-d.decoding:
  360. if debug {
  361. println("Sending ", len(d.history.b), "bytes as history")
  362. }
  363. next.history <- &d.history
  364. default:
  365. // Wait until we have sent the block, so
  366. // other decoders can potentially get the decoder.
  367. next = nil
  368. }
  369. }
  370. // Add checksum, async to decoding.
  371. if d.HasCheckSum {
  372. n, err := d.crc.Write(r.b)
  373. if err != nil {
  374. r.err = err
  375. if n != len(r.b) {
  376. r.err = io.ErrShortWrite
  377. }
  378. output <- r
  379. return
  380. }
  381. }
  382. written += int64(len(r.b))
  383. if d.SingleSegment && uint64(written) > d.FrameContentSize {
  384. println("runDecoder: single segment and", uint64(written), ">", d.FrameContentSize)
  385. r.err = ErrFrameSizeExceeded
  386. output <- r
  387. return
  388. }
  389. if block.Last {
  390. r.err = d.checkCRC()
  391. output <- r
  392. return
  393. }
  394. output <- r
  395. if next == nil {
  396. // There was no decoder available, we wait for one now that we have sent to the writer.
  397. if debug {
  398. println("Sending ", len(d.history.b), " bytes as history")
  399. }
  400. next = <-d.decoding
  401. next.history <- &d.history
  402. }
  403. block = next
  404. }
  405. }
  406. // runDecoder will create a sync decoder that will decode a block of data.
  407. func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
  408. // TODO: Init to dictionary
  409. d.history.reset()
  410. saved := d.history.b
  411. // We use the history for output to avoid copying it.
  412. d.history.b = dst
  413. // Store input length, so we only check new data.
  414. crcStart := len(dst)
  415. var err error
  416. for {
  417. err = dec.reset(d.rawInput, d.WindowSize)
  418. if err != nil {
  419. break
  420. }
  421. if debug {
  422. println("next block:", dec)
  423. }
  424. err = dec.decodeBuf(&d.history)
  425. if err != nil || dec.Last {
  426. break
  427. }
  428. if uint64(len(d.history.b)) > d.o.maxDecodedSize {
  429. err = ErrDecoderSizeExceeded
  430. break
  431. }
  432. if d.SingleSegment && uint64(len(d.history.b)) > d.o.maxDecodedSize {
  433. println("runDecoder: single segment and", uint64(len(d.history.b)), ">", d.o.maxDecodedSize)
  434. err = ErrFrameSizeExceeded
  435. break
  436. }
  437. }
  438. dst = d.history.b
  439. if err == nil {
  440. if d.HasCheckSum {
  441. var n int
  442. n, err = d.crc.Write(dst[crcStart:])
  443. if err == nil {
  444. if n != len(dst)-crcStart {
  445. err = io.ErrShortWrite
  446. } else {
  447. err = d.checkCRC()
  448. }
  449. }
  450. }
  451. }
  452. d.history.b = saved
  453. return dst, err
  454. }