compression.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "compress/flate"
  7. "errors"
  8. "io"
  9. "strings"
  10. "sync"
  11. )
  // flateWriterPool holds *flate.Writer values that flateWrapper.Close hands
// back for reuse by compressNoContextTakeover. The pool has no New function,
// so Get returns nil when it is empty and the caller must allocate a writer.
var flateWriterPool sync.Pool
  // decompressNoContextTakeover returns a reader that inflates one compressed
// message read from r. The sender strips the trailing sync marker, so the
// stream is completed here before it reaches the flate reader.
func decompressNoContextTakeover(r io.Reader) io.Reader {
	// Re-append the four bytes the sender removed (RFC 7692, section 7.2.2).
	const syncMarker = "\x00\x00\xff\xff"
	// An empty final block lets the flate reader end cleanly instead of
	// reporting an unexpected EOF.
	const finalBlock = "\x01\x00\x00\xff\xff"

	completed := io.MultiReader(r, strings.NewReader(syncMarker+finalBlock))
	return flate.NewReader(completed)
}
  23. func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
  24. tw := &truncWriter{w: w}
  25. i := flateWriterPool.Get()
  26. var fw *flate.Writer
  27. var err error
  28. if i == nil {
  29. fw, err = flate.NewWriter(tw, 3)
  30. } else {
  31. fw = i.(*flate.Writer)
  32. fw.Reset(tw)
  33. }
  34. return &flateWrapper{fw: fw, tw: tw}, err
  35. }
  // truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer. The withheld bytes stay in p so the owner can
// inspect them once the stream is finished.
type truncWriter struct {
	w io.WriteCloser
	n int     // number of bytes of p filled so far (at most len(p))
	p [4]byte // most recent bytes of the stream, not yet forwarded to w
}

// Write forwards p to the underlying writer, except that the final len(w.p)
// bytes seen so far are held back in w.p.
//
// On success it reports len(p) bytes consumed, counting the bytes retained in
// the buffer, as the io.Writer contract requires. On error it reports only the
// bytes of p that were consumed before the failure.
func (w *truncWriter) Write(p []byte) (int, error) {
	total := len(p)
	n := 0

	// Fill the holding buffer first for simplicity.
	if w.n < len(w.p) {
		n = copy(w.p[w.n:], p)
		p = p[n:]
		w.n += n
		if len(p) == 0 {
			return n, nil
		}
	}

	// m bytes of the old buffer are released to w; the last m bytes of p
	// replace them in the buffer.
	m := len(p)
	if m > len(w.p) {
		m = len(w.p)
	}

	if _, err := w.w.Write(w.p[:m]); err != nil {
		// None of the remaining bytes of p were consumed.
		return n, err
	}

	copy(w.p[:], w.p[m:])
	copy(w.p[len(w.p)-m:], p[len(p)-m:])

	if nn, err := w.w.Write(p[:len(p)-m]); err != nil {
		return n + nn, err
	}
	return total, nil
}
  66. type flateWrapper struct {
  67. fw *flate.Writer
  68. tw *truncWriter
  69. }
  70. func (w *flateWrapper) Write(p []byte) (int, error) {
  71. if w.fw == nil {
  72. return 0, errWriteClosed
  73. }
  74. return w.fw.Write(p)
  75. }
  76. func (w *flateWrapper) Close() error {
  77. if w.fw == nil {
  78. return errWriteClosed
  79. }
  80. err1 := w.fw.Flush()
  81. flateWriterPool.Put(w.fw)
  82. w.fw = nil
  83. if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
  84. return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
  85. }
  86. err2 := w.tw.w.Close()
  87. if err1 != nil {
  88. return err1
  89. }
  90. return err2
  91. }