compression.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "compress/flate"
  7. "errors"
  8. "io"
  9. "strings"
  10. "sync"
  11. )
  // Pools of flate codecs reused across messages so the compressor and
// decompressor state is not reallocated for every message. Neither pool sets
// New, so Get returns nil when a pool is empty and callers allocate on demand.
//
// flateWriterPool holds *flate.Writer values; flateReaderPool holds the
// io.ReadCloser values produced by flate.NewReader.
var (
	flateWriterPool sync.Pool
	flateReaderPool sync.Pool
)
  16. func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
  17. const tail =
  18. // Add four bytes as specified in RFC
  19. "\x00\x00\xff\xff" +
  20. // Add final block to squelch unexpected EOF error from flate reader.
  21. "\x01\x00\x00\xff\xff"
  22. i := flateReaderPool.Get()
  23. if i == nil {
  24. i = flate.NewReader(nil)
  25. }
  26. i.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
  27. return &flateReadWrapper{i.(io.ReadCloser)}
  28. }
  29. func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
  30. tw := &truncWriter{w: w}
  31. i := flateWriterPool.Get()
  32. var fw *flate.Writer
  33. var err error
  34. if i == nil {
  35. fw, err = flate.NewWriter(tw, 3)
  36. } else {
  37. fw = i.(*flate.Writer)
  38. fw.Reset(tw)
  39. }
  40. return &flateWriteWrapper{fw: fw, tw: tw}, err
  41. }
  // truncWriter is an io.Writer that forwards everything written to it to w,
// except the final four bytes of the stream. Those bytes are held back in p,
// so the owner can inspect them (for example, to check for the
// 0x00 0x00 0xff 0xff marker produced by a flate flush) instead of sending
// them on.
type truncWriter struct {
	w io.WriteCloser // destination for all but the last four bytes
	n int            // number of valid bytes currently held in p
	p [4]byte        // the most recently written bytes, withheld from w
}

// Write appends p to the stream. Bytes that slide out of the four-byte
// holdback window are forwarded to w; the newest bytes remain in the window.
//
// On success every byte of p has been accepted, either forwarded or buffered,
// so Write returns len(p) and a nil error, as the io.Writer contract requires.
// If w fails or writes short, the count covers only the bytes of p that were
// consumed before the failure.
func (w *truncWriter) Write(p []byte) (int, error) {
	n := 0

	// Fill the holdback buffer first; the sliding logic below can then
	// assume the buffer is full.
	if w.n < len(w.p) {
		n = copy(w.p[w.n:], p)
		p = p[n:]
		w.n += n
		if len(p) == 0 {
			return n, nil
		}
	}

	// The buffer is full. Its m oldest bytes leave the window and are
	// forwarded; the last m bytes of p take their place.
	m := len(p)
	if m > len(w.p) {
		m = len(w.p)
	}

	nn, err := w.w.Write(w.p[:m])
	if err == nil && nn < m {
		err = io.ErrShortWrite
	}
	if err != nil {
		// The forwarded bytes came from earlier calls, so only the n bytes of
		// this p that reached the buffer were consumed.
		return n, err
	}

	copy(w.p[:], w.p[m:])
	copy(w.p[len(w.p)-m:], p[len(p)-m:])

	head := p[:len(p)-m]
	nn, err = w.w.Write(head)
	if err == nil && nn < len(head) {
		err = io.ErrShortWrite
	}
	if err != nil {
		return n + nn, err
	}
	// The tail of p is now held in the buffer, so all of p has been consumed.
	return n + len(p), nil
}
  72. type flateWriteWrapper struct {
  73. fw *flate.Writer
  74. tw *truncWriter
  75. }
  76. func (w *flateWriteWrapper) Write(p []byte) (int, error) {
  77. if w.fw == nil {
  78. return 0, errWriteClosed
  79. }
  80. return w.fw.Write(p)
  81. }
  82. func (w *flateWriteWrapper) Close() error {
  83. if w.fw == nil {
  84. return errWriteClosed
  85. }
  86. err1 := w.fw.Flush()
  87. flateWriterPool.Put(w.fw)
  88. w.fw = nil
  89. if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
  90. return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
  91. }
  92. err2 := w.tw.w.Close()
  93. if err1 != nil {
  94. return err1
  95. }
  96. return err2
  97. }
  98. type flateReadWrapper struct {
  99. fr io.ReadCloser
  100. }
  101. func (r *flateReadWrapper) Read(p []byte) (int, error) {
  102. if r.fr == nil {
  103. return 0, io.ErrClosedPipe
  104. }
  105. n, err := r.fr.Read(p)
  106. if err == io.EOF {
  107. // Preemptively place the reader back in the pool. This helps with
  108. // scenarios where the application does not call NextReader() soon after
  109. // this final read.
  110. r.Close()
  111. }
  112. return n, err
  113. }
  114. func (r *flateReadWrapper) Close() error {
  115. if r.fr == nil {
  116. return io.ErrClosedPipe
  117. }
  118. err := r.fr.Close()
  119. flateReaderPool.Put(r.fr)
  120. r.fr = nil
  121. return err
  122. }