buffer.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. // Copyright 2016 The CMux Authors. All rights reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  12. // implied. See the License for the specific language governing
  13. // permissions and limitations under the License.
  14. package cmux
  15. import (
  16. "bytes"
  17. "io"
  18. )
  19. // bufferedReader is an optimized implementation of io.Reader that behaves like
  20. // ```
  21. // io.MultiReader(bytes.NewReader(buffer.Bytes()), io.TeeReader(source, buffer))
  22. // ```
  23. // without allocating.
  24. type bufferedReader struct {
  25. source io.Reader
  26. buffer bytes.Buffer
  27. bufferRead int
  28. bufferSize int
  29. sniffing bool
  30. lastErr error
  31. }
  32. func (s *bufferedReader) Read(p []byte) (int, error) {
  33. if s.bufferSize > s.bufferRead {
  34. // If we have already read something from the buffer before, we return the
  35. // same data and the last error if any. We need to immediately return,
  36. // otherwise we may block for ever, if we try to be smart and call
  37. // source.Read() seeking a little bit of more data.
  38. bn := copy(p, s.buffer.Bytes()[s.bufferRead:s.bufferSize])
  39. s.bufferRead += bn
  40. return bn, s.lastErr
  41. } else if !s.sniffing && s.buffer.Cap() != 0 {
  42. // We don't need the buffer anymore.
  43. // Reset it to release the internal slice.
  44. s.buffer = bytes.Buffer{}
  45. }
  46. // If there is nothing more to return in the sniffed buffer, read from the
  47. // source.
  48. sn, sErr := s.source.Read(p)
  49. if sn > 0 && s.sniffing {
  50. s.lastErr = sErr
  51. if wn, wErr := s.buffer.Write(p[:sn]); wErr != nil {
  52. return wn, wErr
  53. }
  54. }
  55. return sn, sErr
  56. }
  57. func (s *bufferedReader) reset(snif bool) {
  58. s.sniffing = snif
  59. s.bufferRead = 0
  60. s.bufferSize = s.buffer.Len()
  61. }