seqdec.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "errors"
  7. "fmt"
  8. "io"
  9. )
// seq represents a single sequence: a run of literals followed by a match copy.
type seq struct {
	litLen   uint32 // number of literals preceding the match
	matchLen uint32 // match length; String() adds zstdMinMatch when printing, so this is stored without that base
	offset   uint32 // offset code: 0 is invalid, 1-3 are repeat offsets, >3 encodes (real offset + 3) — see String()

	// Codes are stored here for the encoder
	// so they only have to be looked up once.
	llCode, mlCode, ofCode uint8
}
  18. func (s seq) String() string {
  19. if s.offset <= 3 {
  20. if s.offset == 0 {
  21. return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset: INVALID (0)")
  22. }
  23. return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset, " (repeat)")
  24. }
  25. return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset-3, " (new)")
  26. }
// seqCompMode specifies how a sequence symbol table is represented in the
// compressed stream (values match the zstd Symbol_Compression_Modes).
type seqCompMode uint8

const (
	compModePredefined seqCompMode = iota // use the spec's predefined table
	compModeRLE                           // single symbol repeated for the whole block
	compModeFSE                           // an FSE table description follows in the stream
	compModeRepeat                        // reuse the table from the previous block
)
// sequenceDec holds the FSE decoder and running state for one of the three
// sequence streams (literal lengths, offsets, match lengths).
type sequenceDec struct {
	// decoder keeps track of the current state and updates it from the bitstream.
	fse    *fseDecoder
	state  fseState
	repeat bool // table should be reused from history (see mergeHistory)
}
  40. // init the state of the decoder with input from stream.
  41. func (s *sequenceDec) init(br *bitReader) error {
  42. if s.fse == nil {
  43. return errors.New("sequence decoder not defined")
  44. }
  45. s.state.init(br, s.fse.actualTableLog, s.fse.dt[:1<<s.fse.actualTableLog])
  46. return nil
  47. }
// sequenceDecs contains all 3 sequence decoders and their state.
type sequenceDecs struct {
	litLengths   sequenceDec
	offsets      sequenceDec
	matchLengths sequenceDec
	prevOffset   [3]int // most recently used offsets, for repeat-offset codes
	hist         []byte // history window from previous blocks
	literals     []byte // literal bytes remaining for the current block
	out          []byte // decoded output; decode appends to this
	maxBits      uint8  // sum of the three tables' max extra bits (set in initialize)
}
  59. // initialize all 3 decoders from the stream input.
  60. func (s *sequenceDecs) initialize(br *bitReader, hist *history, literals, out []byte) error {
  61. if err := s.litLengths.init(br); err != nil {
  62. return errors.New("litLengths:" + err.Error())
  63. }
  64. if err := s.offsets.init(br); err != nil {
  65. return errors.New("offsets:" + err.Error())
  66. }
  67. if err := s.matchLengths.init(br); err != nil {
  68. return errors.New("matchLengths:" + err.Error())
  69. }
  70. s.literals = literals
  71. s.hist = hist.b
  72. s.prevOffset = hist.recentOffsets
  73. s.maxBits = s.litLengths.fse.maxBits + s.offsets.fse.maxBits + s.matchLengths.fse.maxBits
  74. s.out = out
  75. return nil
  76. }
// decode reads seqs sequences from the stream and expands them into s.out,
// copying matches from the provided history and from already-decoded output.
// Any literals left after the last sequence are appended at the end.
func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error {
	startSize := len(s.out)
	// Grab full sizes tables, to avoid bounds checks.
	llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize]
	llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state

	for i := seqs - 1; i >= 0; i-- {
		if br.overread() {
			printf("reading sequence %d, exceeded available data\n", seqs-i)
			return io.ErrUnexpectedEOF
		}
		var litLen, matchOff, matchLen int
		// Use the no-bounds-check fast path while enough unread bytes remain
		// to cover the worst-case extra bits of one sequence.
		if br.off > 4+((maxOffsetBits+16+16)>>3) {
			litLen, matchOff, matchLen = s.nextFast(br, llState, mlState, ofState)
			br.fillFast()
		} else {
			litLen, matchOff, matchLen = s.next(br, llState, mlState, ofState)
			br.fill()
		}

		if debugSequences {
			println("Seq", seqs-i-1, "Litlen:", litLen, "matchOff:", matchOff, "(abs) matchLen:", matchLen)
		}

		// Sanity checks on corrupt input before any copying happens.
		if litLen > len(s.literals) {
			return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", litLen, len(s.literals))
		}
		size := litLen + matchLen + len(s.out)
		if size-startSize > maxBlockSize {
			return fmt.Errorf("output (%d) bigger than max block size", size)
		}
		if size > cap(s.out) {
			// Not enough size, will be extremely rarely triggered,
			// but could be if destination slice is too small for sync operations.
			// We add maxBlockSize to the capacity.
			s.out = append(s.out, make([]byte, maxBlockSize)...)
			s.out = s.out[:len(s.out)-maxBlockSize]
		}
		if matchLen > maxMatchLen {
			return fmt.Errorf("match len (%d) bigger than max allowed length", matchLen)
		}
		if matchOff > len(s.out)+len(hist)+litLen {
			return fmt.Errorf("match offset (%d) bigger than current history (%d)", matchOff, len(s.out)+len(hist)+litLen)
		}
		if matchOff == 0 && matchLen > 0 {
			return fmt.Errorf("zero matchoff and matchlen > 0")
		}

		// Emit the literals, then resolve the match.
		s.out = append(s.out, s.literals[:litLen]...)
		s.literals = s.literals[litLen:]
		out := s.out

		// Copy from history.
		// TODO: Blocks without history could be made to ignore this completely.
		if v := matchOff - len(s.out); v > 0 {
			// v is the start position in history from end.
			start := len(s.hist) - v
			if matchLen > v {
				// Some goes into current block.
				// Copy remainder of history
				out = append(out, s.hist[start:]...)
				matchOff -= v
				matchLen -= v
			} else {
				// Match is contained entirely in history.
				out = append(out, s.hist[start:start+matchLen]...)
				matchLen = 0
			}
		}
		// We must be in current buffer now
		if matchLen > 0 {
			start := len(s.out) - matchOff
			if matchLen <= len(s.out)-start {
				// No overlap
				out = append(out, s.out[start:start+matchLen]...)
			} else {
				// Overlapping copy
				// Extend destination slice and copy one byte at the time.
				// Forward byte-by-byte copy is required so earlier-written
				// bytes can be re-read (LZ77 run replication).
				out = out[:len(out)+matchLen]
				src := out[start : start+matchLen]
				// Destination is the space we just added.
				dst := out[len(out)-matchLen:]
				dst = dst[:len(src)]
				for i := range src {
					dst[i] = src[i]
				}
			}
		}
		s.out = out
		if i == 0 {
			// This is the last sequence, so we shouldn't update state.
			break
		}
		// Manually inlined, ~ 5-20% faster
		// Update all 3 states at once. Approx 20% faster.
		nBits := llState.nbBits() + mlState.nbBits() + ofState.nbBits()
		if nBits == 0 {
			llState = llTable[llState.newState()&maxTableMask]
			mlState = mlTable[mlState.newState()&maxTableMask]
			ofState = ofTable[ofState.newState()&maxTableMask]
		} else {
			// One combined read; ll bits are topmost, then ml, then of.
			// The &31/&15 masks keep shift amounts and bitMask indexes
			// provably in range so bounds checks are elided.
			bits := br.getBitsFast(nBits)
			lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31))
			llState = llTable[(llState.newState()+lowBits)&maxTableMask]

			lowBits = uint16(bits >> (ofState.nbBits() & 31))
			lowBits &= bitMask[mlState.nbBits()&15]
			mlState = mlTable[(mlState.newState()+lowBits)&maxTableMask]

			lowBits = uint16(bits) & bitMask[ofState.nbBits()&15]
			ofState = ofTable[(ofState.newState()+lowBits)&maxTableMask]
		}
	}

	// Add final literals
	s.out = append(s.out, s.literals...)
	return nil
}
// update advances all three decoder states from the stream, one read each,
// in the fixed order litLengths, matchLengths, offsets.
// At least 27 bits must be available in the bit reader.
func (s *sequenceDecs) update(br *bitReader) {
	// Max 8 bits
	s.litLengths.state.next(br)
	// Max 9 bits
	s.matchLengths.state.next(br)
	// Max 8 bits
	s.offsets.state.next(br)
}
  196. var bitMask [16]uint16
  197. func init() {
  198. for i := range bitMask[:] {
  199. bitMask[i] = uint16((1 << uint(i)) - 1)
  200. }
  201. }
// updateAlt advances all three decoder states using a single combined bit
// read instead of three separate ones. At least 27 bits must be available.
func (s *sequenceDecs) updateAlt(br *bitReader) {
	// Update all 3 states at once. Approx 20% faster.
	a, b, c := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state

	nBits := a.nbBits() + b.nbBits() + c.nbBits()
	if nBits == 0 {
		// No bits required; transition to the new states directly.
		s.litLengths.state.state = s.litLengths.state.dt[a.newState()]
		s.matchLengths.state.state = s.matchLengths.state.dt[b.newState()]
		s.offsets.state.state = s.offsets.state.dt[c.newState()]
		return
	}
	// One combined read: litLengths' bits are topmost, then matchLengths',
	// then offsets'. The &31/&15 masks keep shift amounts and bitMask
	// indexes provably in range so bounds checks are elided.
	bits := br.getBitsFast(nBits)
	lowBits := uint16(bits >> ((c.nbBits() + b.nbBits()) & 31))
	s.litLengths.state.state = s.litLengths.state.dt[a.newState()+lowBits]

	lowBits = uint16(bits >> (c.nbBits() & 31))
	lowBits &= bitMask[b.nbBits()&15]
	s.matchLengths.state.state = s.matchLengths.state.dt[b.newState()+lowBits]

	lowBits = uint16(bits) & bitMask[c.nbBits()&15]
	s.offsets.state.state = s.offsets.state.dt[c.newState()+lowBits]
}
// nextFast decodes the next (literal length, match offset, match length)
// triple from the current states. It may only be used when there are at
// least 4 unused bytes left on the stream when done, since it fills the
// bit reader with the unchecked fast path. The offset-history adjustment
// is inlined here rather than calling adjustOffset.
func (s *sequenceDecs) nextFast(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) {
	// Final will not read from stream.
	ll, llB := llState.final()
	ml, mlB := mlState.final()
	mo, moB := ofState.final()

	// extra bits are stored in reverse order.
	br.fillFast()
	mo += br.getBits(moB)
	if s.maxBits > 32 {
		// Offset bits alone may exhaust the first fill; top up before
		// reading match-length and literal-length extra bits.
		br.fillFast()
	}
	ml += br.getBits(mlB)
	ll += br.getBits(llB)

	if moB > 1 {
		// A real (non-repeat) offset: push it onto the recent-offset history.
		s.prevOffset[2] = s.prevOffset[1]
		s.prevOffset[1] = s.prevOffset[0]
		s.prevOffset[0] = mo
		return
	}
	// mo = s.adjustOffset(mo, ll, moB)
	// Inlined for rather big speedup
	if ll == 0 {
		// There is an exception though, when current sequence's literals_length = 0.
		// In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
		// an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
		mo++
	}
	if mo == 0 {
		// Repeated_Offset1: reuse the most recent offset, history unchanged.
		mo = s.prevOffset[0]
		return
	}
	var temp int
	if mo == 3 {
		temp = s.prevOffset[0] - 1
	} else {
		temp = s.prevOffset[mo]
	}
	if temp == 0 {
		// 0 is not valid; input is corrupted; force offset to 1
		println("temp was 0")
		temp = 1
	}
	// Rotate the resolved offset to the front of the history.
	if mo != 1 {
		s.prevOffset[2] = s.prevOffset[1]
	}
	s.prevOffset[1] = s.prevOffset[0]
	s.prevOffset[0] = temp
	mo = temp
	return
}
// next decodes the next (literal length, match offset, match length) triple
// from the current states. Unlike nextFast it uses the checked fill path,
// so it is safe near the end of the stream.
func (s *sequenceDecs) next(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) {
	// Final will not read from stream.
	ll, llB := llState.final()
	ml, mlB := mlState.final()
	mo, moB := ofState.final()

	// extra bits are stored in reverse order.
	br.fill()
	if s.maxBits <= 32 {
		// All extra bits fit after a single fill.
		mo += br.getBits(moB)
		ml += br.getBits(mlB)
		ll += br.getBits(llB)
	} else {
		mo += br.getBits(moB)
		br.fill()
		// matchlength+literal length, max 32 bits
		ml += br.getBits(mlB)
		ll += br.getBits(llB)
	}
	mo = s.adjustOffset(mo, ll, moB)
	return
}
  294. func (s *sequenceDecs) adjustOffset(offset, litLen int, offsetB uint8) int {
  295. if offsetB > 1 {
  296. s.prevOffset[2] = s.prevOffset[1]
  297. s.prevOffset[1] = s.prevOffset[0]
  298. s.prevOffset[0] = offset
  299. return offset
  300. }
  301. if litLen == 0 {
  302. // There is an exception though, when current sequence's literals_length = 0.
  303. // In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
  304. // an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
  305. offset++
  306. }
  307. if offset == 0 {
  308. return s.prevOffset[0]
  309. }
  310. var temp int
  311. if offset == 3 {
  312. temp = s.prevOffset[0] - 1
  313. } else {
  314. temp = s.prevOffset[offset]
  315. }
  316. if temp == 0 {
  317. // 0 is not valid; input is corrupted; force offset to 1
  318. println("temp was 0")
  319. temp = 1
  320. }
  321. if offset != 1 {
  322. s.prevOffset[2] = s.prevOffset[1]
  323. }
  324. s.prevOffset[1] = s.prevOffset[0]
  325. s.prevOffset[0] = temp
  326. return temp
  327. }
  328. // mergeHistory will merge history.
  329. func (s *sequenceDecs) mergeHistory(hist *sequenceDecs) (*sequenceDecs, error) {
  330. for i := uint(0); i < 3; i++ {
  331. var sNew, sHist *sequenceDec
  332. switch i {
  333. default:
  334. // same as "case 0":
  335. sNew = &s.litLengths
  336. sHist = &hist.litLengths
  337. case 1:
  338. sNew = &s.offsets
  339. sHist = &hist.offsets
  340. case 2:
  341. sNew = &s.matchLengths
  342. sHist = &hist.matchLengths
  343. }
  344. if sNew.repeat {
  345. if sHist.fse == nil {
  346. return nil, fmt.Errorf("sequence stream %d, repeat requested, but no history", i)
  347. }
  348. continue
  349. }
  350. if sNew.fse == nil {
  351. return nil, fmt.Errorf("sequence stream %d, no fse found", i)
  352. }
  353. if sHist.fse != nil && !sHist.fse.preDefined {
  354. fseDecoderPool.Put(sHist.fse)
  355. }
  356. sHist.fse = sNew.fse
  357. }
  358. return hist, nil
  359. }