| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package websocket
- import (
- "compress/flate"
- "errors"
- "io"
- "strings"
- )
// decompressNoContextTakeover returns a reader that inflates the raw DEFLATE
// data read from r, where r supplies a permessage-deflate message payload
// whose trailing 0x00 0x00 0xff 0xff bytes were stripped by the sender.
func decompressNoContextTakeover(r io.Reader) io.Reader {
	// RFC 7692 section 7.2.2: the receiver appends these four octets to the
	// payload before inflating it.
	const syncTail = "\x00\x00\xff\xff"
	// An empty stored block with BFINAL set, so the stream ends cleanly and the
	// flate reader does not report an unexpected EOF.
	const finalBlock = "\x01\x00\x00\xff\xff"

	source := io.MultiReader(r, strings.NewReader(syncTail+finalBlock))
	return flate.NewReader(source)
}
- func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
- tw := &truncWriter{w: w}
- fw, err := flate.NewWriter(tw, 3)
- return &flateWrapper{fw: fw, tw: tw}, err
- }
// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.WriteCloser. The withheld bytes stay in p, so the owner
// can inspect them once the stream is complete (flateWrapper.Close compares
// them against 0x00 0x00 0xff 0xff).
type truncWriter struct {
	w io.WriteCloser // destination for every byte except the withheld tail
	n int            // number of bytes of p filled so far (at most len(p))
	p [4]byte        // most recent bytes of the stream, not yet written to w
}

// Write appends p to the stream. Everything except the last len(w.p) bytes
// seen so far is forwarded to w.w.
//
// On success it returns len(p), as io.Writer requires, even though up to four
// of those bytes are still held in w.p rather than written to w.w. If w.w
// fails, the returned count covers only the bytes of p that were buffered
// while filling w.p, plus those forwarded before the failure.
func (w *truncWriter) Write(p []byte) (int, error) {
	total := len(p)
	if total == 0 {
		// Nothing to buffer or forward; avoid zero-length writes to w.w.
		return 0, nil
	}

	n := 0
	// Until the buffer holds four bytes, nothing can be forwarded yet.
	if w.n < len(w.p) {
		n = copy(w.p[w.n:], p)
		p = p[n:]
		w.n += n
		if len(p) == 0 {
			return n, nil
		}
	}

	// The buffer is full. The last m bytes of p will replace the m oldest
	// buffered bytes.
	m := len(p)
	if m > len(w.p) {
		m = len(w.p)
	}

	// The m oldest buffered bytes precede all of p in the stream, so emit
	// them first.
	if _, err := w.w.Write(w.p[:m]); err != nil {
		return n, err
	}

	// Slide the buffer left and store p's last m bytes as the new tail.
	copy(w.p[:], w.p[m:])
	copy(w.p[len(w.p)-m:], p[len(p)-m:])

	// Everything in p before its last m bytes can now be forwarded.
	nn, err := w.w.Write(p[:len(p)-m])
	if err != nil {
		return n + nn, err
	}
	return total, nil
}
- type flateWrapper struct {
- fw *flate.Writer
- tw *truncWriter
- }
- func (w *flateWrapper) Write(p []byte) (int, error) {
- return w.fw.Write(p)
- }
- func (w *flateWrapper) Close() error {
- err1 := w.fw.Flush()
- if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
- return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
- }
- err2 := w.tw.w.Close()
- if err1 != nil {
- return err1
- }
- return err2
- }
|