gen_inflate.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. // +build generate
  2. //go:generate go run $GOFILE && gofmt -w inflate_gen.go
  3. package main
  4. import (
  5. "os"
  6. "strings"
  7. )
  8. func main() {
  9. f, err := os.Create("inflate_gen.go")
  10. if err != nil {
  11. panic(err)
  12. }
  13. defer f.Close()
  14. types := []string{"*bytes.Buffer", "*bytes.Reader", "*bufio.Reader", "*strings.Reader"}
  15. names := []string{"BytesBuffer", "BytesReader", "BufioReader", "StringsReader"}
  16. imports := []string{"bytes", "bufio", "io", "strings", "math/bits"}
  17. f.WriteString(`// Code generated by go generate gen_inflate.go. DO NOT EDIT.
  18. package flate
  19. import (
  20. `)
  21. for _, imp := range imports {
  22. f.WriteString("\t\"" + imp + "\"\n")
  23. }
  24. f.WriteString(")\n\n")
  25. template := `
  26. // Decode a single Huffman block from f.
  27. // hl and hd are the Huffman states for the lit/length values
  28. // and the distance values, respectively. If hd == nil, using the
  29. // fixed distance encoding associated with fixed Huffman blocks.
  30. func (f *decompressor) $FUNCNAME$() {
  31. const (
  32. stateInit = iota // Zero value must be stateInit
  33. stateDict
  34. )
  35. fr := f.r.($TYPE$)
  36. moreBits := func() error {
  37. c, err := fr.ReadByte()
  38. if err != nil {
  39. return noEOF(err)
  40. }
  41. f.roffset++
  42. f.b |= uint32(c) << f.nb
  43. f.nb += 8
  44. return nil
  45. }
  46. switch f.stepState {
  47. case stateInit:
  48. goto readLiteral
  49. case stateDict:
  50. goto copyHistory
  51. }
  52. readLiteral:
  53. // Read literal and/or (length, distance) according to RFC section 3.2.3.
  54. {
  55. var v int
  56. {
  57. // Inlined v, err := f.huffSym(f.hl)
  58. // Since a huffmanDecoder can be empty or be composed of a degenerate tree
  59. // with single element, huffSym must error on these two edge cases. In both
  60. // cases, the chunks slice will be 0 for the invalid sequence, leading it
  61. // satisfy the n == 0 check below.
  62. n := uint(f.hl.maxRead)
  63. // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers,
  64. // but is smart enough to keep local variables in registers, so use nb and b,
  65. // inline call to moreBits and reassign b,nb back to f on return.
  66. nb, b := f.nb, f.b
  67. for {
  68. for nb < n {
  69. c, err := fr.ReadByte()
  70. if err != nil {
  71. f.b = b
  72. f.nb = nb
  73. f.err = noEOF(err)
  74. return
  75. }
  76. f.roffset++
  77. b |= uint32(c) << (nb & 31)
  78. nb += 8
  79. }
  80. chunk := f.hl.chunks[b&(huffmanNumChunks-1)]
  81. n = uint(chunk & huffmanCountMask)
  82. if n > huffmanChunkBits {
  83. chunk = f.hl.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hl.linkMask]
  84. n = uint(chunk & huffmanCountMask)
  85. }
  86. if n <= nb {
  87. if n == 0 {
  88. f.b = b
  89. f.nb = nb
  90. if debugDecode {
  91. fmt.Println("huffsym: n==0")
  92. }
  93. f.err = CorruptInputError(f.roffset)
  94. return
  95. }
  96. f.b = b >> (n & 31)
  97. f.nb = nb - n
  98. v = int(chunk >> huffmanValueShift)
  99. break
  100. }
  101. }
  102. }
  103. var n uint // number of bits extra
  104. var length int
  105. var err error
  106. switch {
  107. case v < 256:
  108. f.dict.writeByte(byte(v))
  109. if f.dict.availWrite() == 0 {
  110. f.toRead = f.dict.readFlush()
  111. f.step = (*decompressor).$FUNCNAME$
  112. f.stepState = stateInit
  113. return
  114. }
  115. goto readLiteral
  116. case v == 256:
  117. f.finishBlock()
  118. return
  119. // otherwise, reference to older data
  120. case v < 265:
  121. length = v - (257 - 3)
  122. n = 0
  123. case v < 269:
  124. length = v*2 - (265*2 - 11)
  125. n = 1
  126. case v < 273:
  127. length = v*4 - (269*4 - 19)
  128. n = 2
  129. case v < 277:
  130. length = v*8 - (273*8 - 35)
  131. n = 3
  132. case v < 281:
  133. length = v*16 - (277*16 - 67)
  134. n = 4
  135. case v < 285:
  136. length = v*32 - (281*32 - 131)
  137. n = 5
  138. case v < maxNumLit:
  139. length = 258
  140. n = 0
  141. default:
  142. if debugDecode {
  143. fmt.Println(v, ">= maxNumLit")
  144. }
  145. f.err = CorruptInputError(f.roffset)
  146. return
  147. }
  148. if n > 0 {
  149. for f.nb < n {
  150. if err = moreBits(); err != nil {
  151. if debugDecode {
  152. fmt.Println("morebits n>0:", err)
  153. }
  154. f.err = err
  155. return
  156. }
  157. }
  158. length += int(f.b & uint32(1<<n-1))
  159. f.b >>= n
  160. f.nb -= n
  161. }
  162. var dist int
  163. if f.hd == nil {
  164. for f.nb < 5 {
  165. if err = moreBits(); err != nil {
  166. if debugDecode {
  167. fmt.Println("morebits f.nb<5:", err)
  168. }
  169. f.err = err
  170. return
  171. }
  172. }
  173. dist = int(bits.Reverse8(uint8(f.b & 0x1F << 3)))
  174. f.b >>= 5
  175. f.nb -= 5
  176. } else {
  177. if dist, err = f.huffSym(f.hd); err != nil {
  178. if debugDecode {
  179. fmt.Println("huffsym:", err)
  180. }
  181. f.err = err
  182. return
  183. }
  184. }
  185. switch {
  186. case dist < 4:
  187. dist++
  188. case dist < maxNumDist:
  189. nb := uint(dist-2) >> 1
  190. // have 1 bit in bottom of dist, need nb more.
  191. extra := (dist & 1) << nb
  192. for f.nb < nb {
  193. if err = moreBits(); err != nil {
  194. if debugDecode {
  195. fmt.Println("morebits f.nb<nb:", err)
  196. }
  197. f.err = err
  198. return
  199. }
  200. }
  201. extra |= int(f.b & uint32(1<<nb-1))
  202. f.b >>= nb
  203. f.nb -= nb
  204. dist = 1<<(nb+1) + 1 + extra
  205. default:
  206. if debugDecode {
  207. fmt.Println("dist too big:", dist, maxNumDist)
  208. }
  209. f.err = CorruptInputError(f.roffset)
  210. return
  211. }
  212. // No check on length; encoding can be prescient.
  213. if dist > f.dict.histSize() {
  214. if debugDecode {
  215. fmt.Println("dist > f.dict.histSize():", dist, f.dict.histSize())
  216. }
  217. f.err = CorruptInputError(f.roffset)
  218. return
  219. }
  220. f.copyLen, f.copyDist = length, dist
  221. goto copyHistory
  222. }
  223. copyHistory:
  224. // Perform a backwards copy according to RFC section 3.2.3.
  225. {
  226. cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen)
  227. if cnt == 0 {
  228. cnt = f.dict.writeCopy(f.copyDist, f.copyLen)
  229. }
  230. f.copyLen -= cnt
  231. if f.dict.availWrite() == 0 || f.copyLen > 0 {
  232. f.toRead = f.dict.readFlush()
  233. f.step = (*decompressor).$FUNCNAME$ // We need to continue this work
  234. f.stepState = stateDict
  235. return
  236. }
  237. goto readLiteral
  238. }
  239. }
  240. `
  241. for i, t := range types {
  242. s := strings.Replace(template, "$FUNCNAME$", "huffman"+names[i], -1)
  243. s = strings.Replace(s, "$TYPE$", t, -1)
  244. f.WriteString(s)
  245. }
  246. f.WriteString("func (f *decompressor) huffmanBlockDecoder() func() {\n")
  247. f.WriteString("\tswitch f.r.(type) {\n")
  248. for i, t := range types {
  249. f.WriteString("\t\tcase " + t + ":\n")
  250. f.WriteString("\t\t\treturn f.huffman" + names[i] + "\n")
  251. }
  252. f.WriteString("\t\tdefault:\n")
  253. f.WriteString("\t\t\treturn f.huffmanBlockGeneric")
  254. f.WriteString("\t}\n}\n")
  255. }