compression.go 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "compress/flate"
  7. "errors"
  8. "io"
  9. "strings"
  10. )
  11. func decompressNoContextTakeover(r io.Reader) io.Reader {
  12. const tail =
  13. // Add four bytes as specified in RFC
  14. "\x00\x00\xff\xff" +
  15. // Add final block to squelch unexpected EOF error from flate reader.
  16. "\x01\x00\x00\xff\xff"
  17. return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
  18. }
  19. func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
  20. tw := &truncWriter{w: w}
  21. fw, err := flate.NewWriter(tw, 3)
  22. return &flateWrapper{fw: fw, tw: tw}, err
  23. }
  24. // truncWriter is an io.Writer that writes all but the last four bytes of the
  25. // stream to another io.Writer.
  26. type truncWriter struct {
  27. w io.WriteCloser
  28. n int
  29. p [4]byte
  30. }
  31. func (w *truncWriter) Write(p []byte) (int, error) {
  32. n := 0
  33. // fill buffer first for simplicity.
  34. if w.n < len(w.p) {
  35. n = copy(w.p[w.n:], p)
  36. p = p[n:]
  37. w.n += n
  38. if len(p) == 0 {
  39. return n, nil
  40. }
  41. }
  42. m := len(p)
  43. if m > len(w.p) {
  44. m = len(w.p)
  45. }
  46. if nn, err := w.w.Write(w.p[:m]); err != nil {
  47. return n + nn, err
  48. }
  49. copy(w.p[:], w.p[m:])
  50. copy(w.p[len(w.p)-m:], p[len(p)-m:])
  51. nn, err := w.w.Write(p[:len(p)-m])
  52. return n + nn, err
  53. }
  54. type flateWrapper struct {
  55. fw *flate.Writer
  56. tw *truncWriter
  57. }
  58. func (w *flateWrapper) Write(p []byte) (int, error) {
  59. return w.fw.Write(p)
  60. }
  61. func (w *flateWrapper) Close() error {
  62. err1 := w.fw.Flush()
  63. if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
  64. return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
  65. }
  66. err2 := w.tw.w.Close()
  67. if err1 != nil {
  68. return err1
  69. }
  70. return err2
  71. }