compression.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "compress/flate"
  7. "errors"
  8. "io"
  9. "strings"
  10. "sync"
  11. )
  12. var (
  13. flateWriterPool = sync.Pool{New: func() interface{} {
  14. fw, _ := flate.NewWriter(nil, 3)
  15. return fw
  16. }}
  17. flateReaderPool = sync.Pool{New: func() interface{} {
  18. return flate.NewReader(nil)
  19. }}
  20. )
  21. func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
  22. const tail =
  23. // Add four bytes as specified in RFC
  24. "\x00\x00\xff\xff" +
  25. // Add final block to squelch unexpected EOF error from flate reader.
  26. "\x01\x00\x00\xff\xff"
  27. fr, _ := flateReaderPool.Get().(io.ReadCloser)
  28. fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
  29. return &flateReadWrapper{fr}
  30. }
  31. func compressNoContextTakeover(w io.WriteCloser) io.WriteCloser {
  32. tw := &truncWriter{w: w}
  33. fw, _ := flateWriterPool.Get().(*flate.Writer)
  34. fw.Reset(tw)
  35. return &flateWriteWrapper{fw: fw, tw: tw}
  36. }
  37. // truncWriter is an io.Writer that writes all but the last four bytes of the
  38. // stream to another io.Writer.
  39. type truncWriter struct {
  40. w io.WriteCloser
  41. n int
  42. p [4]byte
  43. }
  44. func (w *truncWriter) Write(p []byte) (int, error) {
  45. n := 0
  46. // fill buffer first for simplicity.
  47. if w.n < len(w.p) {
  48. n = copy(w.p[w.n:], p)
  49. p = p[n:]
  50. w.n += n
  51. if len(p) == 0 {
  52. return n, nil
  53. }
  54. }
  55. m := len(p)
  56. if m > len(w.p) {
  57. m = len(w.p)
  58. }
  59. if nn, err := w.w.Write(w.p[:m]); err != nil {
  60. return n + nn, err
  61. }
  62. copy(w.p[:], w.p[m:])
  63. copy(w.p[len(w.p)-m:], p[len(p)-m:])
  64. nn, err := w.w.Write(p[:len(p)-m])
  65. return n + nn, err
  66. }
  67. type flateWriteWrapper struct {
  68. fw *flate.Writer
  69. tw *truncWriter
  70. }
  71. func (w *flateWriteWrapper) Write(p []byte) (int, error) {
  72. if w.fw == nil {
  73. return 0, errWriteClosed
  74. }
  75. return w.fw.Write(p)
  76. }
  77. func (w *flateWriteWrapper) Close() error {
  78. if w.fw == nil {
  79. return errWriteClosed
  80. }
  81. err1 := w.fw.Flush()
  82. flateWriterPool.Put(w.fw)
  83. w.fw = nil
  84. if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
  85. return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
  86. }
  87. err2 := w.tw.w.Close()
  88. if err1 != nil {
  89. return err1
  90. }
  91. return err2
  92. }
  93. type flateReadWrapper struct {
  94. fr io.ReadCloser
  95. }
  96. func (r *flateReadWrapper) Read(p []byte) (int, error) {
  97. if r.fr == nil {
  98. return 0, io.ErrClosedPipe
  99. }
  100. n, err := r.fr.Read(p)
  101. if err == io.EOF {
  102. // Preemptively place the reader back in the pool. This helps with
  103. // scenarios where the application does not call NextReader() soon after
  104. // this final read.
  105. r.Close()
  106. }
  107. return n, err
  108. }
  109. func (r *flateReadWrapper) Close() error {
  110. if r.fr == nil {
  111. return io.ErrClosedPipe
  112. }
  113. err := r.fr.Close()
  114. flateReaderPool.Put(r.fr)
  115. r.fr = nil
  116. return err
  117. }