encoder.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "crypto/rand"
  7. "fmt"
  8. "io"
  9. rdebug "runtime/debug"
  10. "sync"
  11. "github.com/klauspost/compress/zstd/internal/xxhash"
  12. )
  13. // Encoder provides encoding to Zstandard.
  14. // An Encoder can be used for either compressing a stream via the
  15. // io.WriteCloser interface supported by the Encoder or as multiple independent
  16. // tasks via the EncodeAll function.
  17. // Smaller encodes are encouraged to use the EncodeAll function.
  18. // Use NewWriter to create a new instance.
  19. type Encoder struct {
  20. o encoderOptions
  21. encoders chan encoder
  22. state encoderState
  23. init sync.Once
  24. }
  25. type encoder interface {
  26. Encode(blk *blockEnc, src []byte)
  27. EncodeNoHist(blk *blockEnc, src []byte)
  28. Block() *blockEnc
  29. CRC() *xxhash.Digest
  30. AppendCRC([]byte) []byte
  31. WindowSize(size int) int32
  32. UseBlock(*blockEnc)
  33. Reset()
  34. }
  35. type encoderState struct {
  36. w io.Writer
  37. filling []byte
  38. current []byte
  39. previous []byte
  40. encoder encoder
  41. writing *blockEnc
  42. err error
  43. writeErr error
  44. nWritten int64
  45. headerWritten bool
  46. eofWritten bool
  47. // This waitgroup indicates an encode is running.
  48. wg sync.WaitGroup
  49. // This waitgroup indicates we have a block encoding/writing.
  50. wWg sync.WaitGroup
  51. }
  52. // NewWriter will create a new Zstandard encoder.
  53. // If the encoder will be used for encoding blocks a nil writer can be used.
  54. func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
  55. initPredefined()
  56. var e Encoder
  57. e.o.setDefault()
  58. for _, o := range opts {
  59. err := o(&e.o)
  60. if err != nil {
  61. return nil, err
  62. }
  63. }
  64. if w != nil {
  65. e.Reset(w)
  66. }
  67. return &e, nil
  68. }
  69. func (e *Encoder) initialize() {
  70. if e.o.concurrent == 0 {
  71. e.o.setDefault()
  72. }
  73. e.encoders = make(chan encoder, e.o.concurrent)
  74. for i := 0; i < e.o.concurrent; i++ {
  75. e.encoders <- e.o.encoder()
  76. }
  77. }
  78. // Reset will re-initialize the writer and new writes will encode to the supplied writer
  79. // as a new, independent stream.
  80. func (e *Encoder) Reset(w io.Writer) {
  81. s := &e.state
  82. s.wg.Wait()
  83. s.wWg.Wait()
  84. if cap(s.filling) == 0 {
  85. s.filling = make([]byte, 0, e.o.blockSize)
  86. }
  87. if cap(s.current) == 0 {
  88. s.current = make([]byte, 0, e.o.blockSize)
  89. }
  90. if cap(s.previous) == 0 {
  91. s.previous = make([]byte, 0, e.o.blockSize)
  92. }
  93. if s.encoder == nil {
  94. s.encoder = e.o.encoder()
  95. }
  96. if s.writing == nil {
  97. s.writing = &blockEnc{}
  98. s.writing.init()
  99. }
  100. s.writing.initNewEncode()
  101. s.filling = s.filling[:0]
  102. s.current = s.current[:0]
  103. s.previous = s.previous[:0]
  104. s.encoder.Reset()
  105. s.headerWritten = false
  106. s.eofWritten = false
  107. s.w = w
  108. s.err = nil
  109. s.nWritten = 0
  110. s.writeErr = nil
  111. }
  112. // Write data to the encoder.
  113. // Input data will be buffered and as the buffer fills up
  114. // content will be compressed and written to the output.
  115. // When done writing, use Close to flush the remaining output
  116. // and write CRC if requested.
  117. func (e *Encoder) Write(p []byte) (n int, err error) {
  118. s := &e.state
  119. for len(p) > 0 {
  120. if len(p)+len(s.filling) < e.o.blockSize {
  121. if e.o.crc {
  122. _, _ = s.encoder.CRC().Write(p)
  123. }
  124. s.filling = append(s.filling, p...)
  125. return n + len(p), nil
  126. }
  127. add := p
  128. if len(p)+len(s.filling) > e.o.blockSize {
  129. add = add[:e.o.blockSize-len(s.filling)]
  130. }
  131. if e.o.crc {
  132. _, _ = s.encoder.CRC().Write(add)
  133. }
  134. s.filling = append(s.filling, add...)
  135. p = p[len(add):]
  136. n += len(add)
  137. if len(s.filling) < e.o.blockSize {
  138. return n, nil
  139. }
  140. err := e.nextBlock(false)
  141. if err != nil {
  142. return n, err
  143. }
  144. if debugAsserts && len(s.filling) > 0 {
  145. panic(len(s.filling))
  146. }
  147. }
  148. return n, nil
  149. }
  150. // nextBlock will synchronize and start compressing input in e.state.filling.
  151. // If an error has occurred during encoding it will be returned.
  152. func (e *Encoder) nextBlock(final bool) error {
  153. s := &e.state
  154. // Wait for current block.
  155. s.wg.Wait()
  156. if s.err != nil {
  157. return s.err
  158. }
  159. if len(s.filling) > e.o.blockSize {
  160. return fmt.Errorf("block > maxStoreBlockSize")
  161. }
  162. if !s.headerWritten {
  163. var tmp [maxHeaderSize]byte
  164. fh := frameHeader{
  165. ContentSize: 0,
  166. WindowSize: uint32(s.encoder.WindowSize(0)),
  167. SingleSegment: false,
  168. Checksum: e.o.crc,
  169. DictID: 0,
  170. }
  171. dst, err := fh.appendTo(tmp[:0])
  172. if err != nil {
  173. return err
  174. }
  175. s.headerWritten = true
  176. s.wWg.Wait()
  177. var n2 int
  178. n2, s.err = s.w.Write(dst)
  179. if s.err != nil {
  180. return s.err
  181. }
  182. s.nWritten += int64(n2)
  183. }
  184. if s.eofWritten {
  185. // Ensure we only write it once.
  186. final = false
  187. }
  188. if len(s.filling) == 0 {
  189. // Final block, but no data.
  190. if final {
  191. enc := s.encoder
  192. blk := enc.Block()
  193. blk.reset(nil)
  194. blk.last = true
  195. blk.encodeRaw(nil)
  196. s.wWg.Wait()
  197. _, s.err = s.w.Write(blk.output)
  198. s.nWritten += int64(len(blk.output))
  199. s.eofWritten = true
  200. }
  201. return s.err
  202. }
  203. // Move blocks forward.
  204. s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
  205. s.wg.Add(1)
  206. go func(src []byte) {
  207. if debug {
  208. println("Adding block,", len(src), "bytes, final:", final)
  209. }
  210. defer func() {
  211. if r := recover(); r != nil {
  212. s.err = fmt.Errorf("panic while encoding: %v", r)
  213. rdebug.PrintStack()
  214. }
  215. s.wg.Done()
  216. }()
  217. enc := s.encoder
  218. blk := enc.Block()
  219. enc.Encode(blk, src)
  220. blk.last = final
  221. if final {
  222. s.eofWritten = true
  223. }
  224. // Wait for pending writes.
  225. s.wWg.Wait()
  226. if s.writeErr != nil {
  227. s.err = s.writeErr
  228. return
  229. }
  230. // Transfer encoders from previous write block.
  231. blk.swapEncoders(s.writing)
  232. // Transfer recent offsets to next.
  233. enc.UseBlock(s.writing)
  234. s.writing = blk
  235. s.wWg.Add(1)
  236. go func() {
  237. defer func() {
  238. if r := recover(); r != nil {
  239. s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
  240. rdebug.PrintStack()
  241. }
  242. s.wWg.Done()
  243. }()
  244. err := errIncompressible
  245. // If we got the exact same number of literals as input,
  246. // assume the literals cannot be compressed.
  247. if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
  248. err = blk.encode(e.o.noEntropy)
  249. }
  250. switch err {
  251. case errIncompressible:
  252. if debug {
  253. println("Storing incompressible block as raw")
  254. }
  255. blk.encodeRaw(src)
  256. // In fast mode, we do not transfer offsets, so we don't have to deal with changing the.
  257. case nil:
  258. default:
  259. s.writeErr = err
  260. return
  261. }
  262. _, s.writeErr = s.w.Write(blk.output)
  263. s.nWritten += int64(len(blk.output))
  264. }()
  265. }(s.current)
  266. return nil
  267. }
  268. // ReadFrom reads data from r until EOF or error.
  269. // The return value n is the number of bytes read.
  270. // Any error except io.EOF encountered during the read is also returned.
  271. //
  272. // The Copy function uses ReaderFrom if available.
  273. func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
  274. if debug {
  275. println("Using ReadFrom")
  276. }
  277. // Maybe handle stuff queued?
  278. e.state.filling = e.state.filling[:e.o.blockSize]
  279. src := e.state.filling
  280. for {
  281. n2, err := r.Read(src)
  282. _, _ = e.state.encoder.CRC().Write(src[:n2])
  283. // src is now the unfilled part...
  284. src = src[n2:]
  285. n += int64(n2)
  286. switch err {
  287. case io.EOF:
  288. e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
  289. if debug {
  290. println("ReadFrom: got EOF final block:", len(e.state.filling))
  291. }
  292. return n, e.nextBlock(true)
  293. default:
  294. if debug {
  295. println("ReadFrom: got error:", err)
  296. }
  297. e.state.err = err
  298. return n, err
  299. case nil:
  300. }
  301. if len(src) > 0 {
  302. if debug {
  303. println("ReadFrom: got space left in source:", len(src))
  304. }
  305. continue
  306. }
  307. err = e.nextBlock(false)
  308. if err != nil {
  309. return n, err
  310. }
  311. e.state.filling = e.state.filling[:e.o.blockSize]
  312. src = e.state.filling
  313. }
  314. }
  315. // Flush will send the currently written data to output
  316. // and block until everything has been written.
  317. // This should only be used on rare occasions where pushing the currently queued data is critical.
  318. func (e *Encoder) Flush() error {
  319. s := &e.state
  320. if len(s.filling) > 0 {
  321. err := e.nextBlock(false)
  322. if err != nil {
  323. return err
  324. }
  325. }
  326. s.wg.Wait()
  327. s.wWg.Wait()
  328. if s.err != nil {
  329. return s.err
  330. }
  331. return s.writeErr
  332. }
  333. // Close will flush the final output and close the stream.
  334. // The function will block until everything has been written.
  335. // The Encoder can still be re-used after calling this.
  336. func (e *Encoder) Close() error {
  337. s := &e.state
  338. if s.encoder == nil {
  339. return nil
  340. }
  341. err := e.nextBlock(true)
  342. if err != nil {
  343. return err
  344. }
  345. s.wg.Wait()
  346. s.wWg.Wait()
  347. if s.err != nil {
  348. return s.err
  349. }
  350. if s.writeErr != nil {
  351. return s.writeErr
  352. }
  353. // Write CRC
  354. if e.o.crc && s.err == nil {
  355. // heap alloc.
  356. var tmp [4]byte
  357. _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
  358. s.nWritten += 4
  359. }
  360. // Add padding with content from crypto/rand.Reader
  361. if s.err == nil && e.o.pad > 0 {
  362. add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
  363. frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
  364. if err != nil {
  365. return err
  366. }
  367. _, s.err = s.w.Write(frame)
  368. }
  369. return s.err
  370. }
  371. // EncodeAll will encode all input in src and append it to dst.
  372. // This function can be called concurrently, but each call will only run on a single goroutine.
  373. // If empty input is given, nothing is returned, unless WithZeroFrames is specified.
  374. // Encoded blocks can be concatenated and the result will be the combined input stream.
  375. // Data compressed with EncodeAll can be decoded with the Decoder,
  376. // using either a stream or DecodeAll.
  377. func (e *Encoder) EncodeAll(src, dst []byte) []byte {
  378. if len(src) == 0 {
  379. if e.o.fullZero {
  380. // Add frame header.
  381. fh := frameHeader{
  382. ContentSize: 0,
  383. WindowSize: MinWindowSize,
  384. SingleSegment: true,
  385. // Adding a checksum would be a waste of space.
  386. Checksum: false,
  387. DictID: 0,
  388. }
  389. dst, _ = fh.appendTo(dst)
  390. // Write raw block as last one only.
  391. var blk blockHeader
  392. blk.setSize(0)
  393. blk.setType(blockTypeRaw)
  394. blk.setLast(true)
  395. dst = blk.appendTo(dst)
  396. }
  397. return dst
  398. }
  399. e.init.Do(e.initialize)
  400. enc := <-e.encoders
  401. defer func() {
  402. // Release encoder reference to last block.
  403. enc.Reset()
  404. e.encoders <- enc
  405. }()
  406. enc.Reset()
  407. blk := enc.Block()
  408. // Use single segments when above minimum window and below 1MB.
  409. single := len(src) < 1<<20 && len(src) > MinWindowSize
  410. if e.o.single != nil {
  411. single = *e.o.single
  412. }
  413. fh := frameHeader{
  414. ContentSize: uint64(len(src)),
  415. WindowSize: uint32(enc.WindowSize(len(src))),
  416. SingleSegment: single,
  417. Checksum: e.o.crc,
  418. DictID: 0,
  419. }
  420. // If less than 1MB, allocate a buffer up front.
  421. if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 {
  422. dst = make([]byte, 0, len(src))
  423. }
  424. dst, err := fh.appendTo(dst)
  425. if err != nil {
  426. panic(err)
  427. }
  428. if len(src) <= e.o.blockSize && len(src) <= maxBlockSize {
  429. // Slightly faster with no history and everything in one block.
  430. if e.o.crc {
  431. _, _ = enc.CRC().Write(src)
  432. }
  433. blk.reset(nil)
  434. blk.last = true
  435. enc.EncodeNoHist(blk, src)
  436. // If we got the exact same number of literals as input,
  437. // assume the literals cannot be compressed.
  438. err := errIncompressible
  439. oldout := blk.output
  440. if len(blk.literals) != len(src) || len(src) != e.o.blockSize {
  441. // Output directly to dst
  442. blk.output = dst
  443. err = blk.encode(e.o.noEntropy)
  444. }
  445. switch err {
  446. case errIncompressible:
  447. if debug {
  448. println("Storing incompressible block as raw")
  449. }
  450. dst = blk.encodeRawTo(dst, src)
  451. case nil:
  452. dst = blk.output
  453. default:
  454. panic(err)
  455. }
  456. blk.output = oldout
  457. } else {
  458. for len(src) > 0 {
  459. todo := src
  460. if len(todo) > e.o.blockSize {
  461. todo = todo[:e.o.blockSize]
  462. }
  463. src = src[len(todo):]
  464. if e.o.crc {
  465. _, _ = enc.CRC().Write(todo)
  466. }
  467. blk.reset(nil)
  468. blk.pushOffsets()
  469. enc.Encode(blk, todo)
  470. if len(src) == 0 {
  471. blk.last = true
  472. }
  473. err := errIncompressible
  474. // If we got the exact same number of literals as input,
  475. // assume the literals cannot be compressed.
  476. if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
  477. err = blk.encode(e.o.noEntropy)
  478. }
  479. switch err {
  480. case errIncompressible:
  481. if debug {
  482. println("Storing incompressible block as raw")
  483. }
  484. dst = blk.encodeRawTo(dst, todo)
  485. blk.popOffsets()
  486. case nil:
  487. dst = append(dst, blk.output...)
  488. default:
  489. panic(err)
  490. }
  491. }
  492. }
  493. if e.o.crc {
  494. dst = enc.AppendCRC(dst)
  495. }
  496. // Add padding with content from crypto/rand.Reader
  497. if e.o.pad > 0 {
  498. add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
  499. dst, err = skippableFrame(dst, add, rand.Reader)
  500. if err != nil {
  501. panic(err)
  502. }
  503. }
  504. return dst
  505. }