conn_test.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. // Copyright 2013 Gary Burd. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "bytes"
  7. "fmt"
  8. "io"
  9. "io/ioutil"
  10. "net"
  11. "testing"
  12. "testing/iotest"
  13. "time"
  14. )
  15. type fakeNetConn struct {
  16. io.Reader
  17. io.Writer
  18. }
  19. func (c fakeNetConn) Close() error { return nil }
  20. func (c fakeNetConn) LocalAddr() net.Addr { return nil }
  21. func (c fakeNetConn) RemoteAddr() net.Addr { return nil }
  22. func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
  23. func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
  24. func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
  25. func TestFraming(t *testing.T) {
  26. frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
  27. var readChunkers = []struct {
  28. name string
  29. f func(io.Reader) io.Reader
  30. }{
  31. {"half", iotest.HalfReader},
  32. {"one", iotest.OneByteReader},
  33. {"asis", func(r io.Reader) io.Reader { return r }},
  34. }
  35. writeBuf := make([]byte, 65537)
  36. for i := range writeBuf {
  37. writeBuf[i] = byte(i)
  38. }
  39. for _, isServer := range []bool{true, false} {
  40. for _, chunker := range readChunkers {
  41. var connBuf bytes.Buffer
  42. wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
  43. rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
  44. for _, n := range frameSizes {
  45. for _, iocopy := range []bool{true, false} {
  46. name := fmt.Sprintf("s:%v, r:%s, n:%d c:%v", isServer, chunker.name, n, iocopy)
  47. w, err := wc.NextWriter(TextMessage)
  48. if err != nil {
  49. t.Errorf("%s: wc.NextWriter() returned %v", name, err)
  50. continue
  51. }
  52. var nn int
  53. if iocopy {
  54. var n64 int64
  55. n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n]))
  56. nn = int(n64)
  57. } else {
  58. nn, err = w.Write(writeBuf[:n])
  59. }
  60. if err != nil || nn != n {
  61. t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
  62. continue
  63. }
  64. err = w.Close()
  65. if err != nil {
  66. t.Errorf("%s: w.Close() returned %v", name, err)
  67. continue
  68. }
  69. opCode, r, err := rc.NextReader()
  70. if err != nil || opCode != TextMessage {
  71. t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
  72. continue
  73. }
  74. rbuf, err := ioutil.ReadAll(r)
  75. if err != nil {
  76. t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
  77. continue
  78. }
  79. if len(rbuf) != n {
  80. t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
  81. continue
  82. }
  83. for i, b := range rbuf {
  84. if byte(i) != b {
  85. t.Errorf("%s: bad byte at offset %d", name, i)
  86. break
  87. }
  88. }
  89. }
  90. }
  91. }
  92. }
  93. }
  94. func TestReadLimit(t *testing.T) {
  95. const readLimit = 512
  96. message := make([]byte, readLimit+1)
  97. var b1, b2 bytes.Buffer
  98. wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2)
  99. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
  100. rc.SetReadLimit(readLimit)
  101. // Send message at the limit with interleaved pong.
  102. w, _ := wc.NextWriter(BinaryMessage)
  103. w.Write(message[:readLimit-1])
  104. wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
  105. w.Write(message[:1])
  106. w.Close()
  107. // Send message larger than the limit.
  108. wc.WriteMessage(BinaryMessage, message[:readLimit+1])
  109. op, _, err := rc.NextReader()
  110. if op != BinaryMessage || err != nil {
  111. t.Fatalf("1: NextReader() returned %d, %v", op, err)
  112. }
  113. op, r, err := rc.NextReader()
  114. if op != BinaryMessage || err != nil {
  115. t.Fatalf("2: NextReader() returned %d, %v", op, err)
  116. }
  117. _, err = io.Copy(ioutil.Discard, r)
  118. if err != ErrReadLimit {
  119. t.Fatalf("io.Copy() returned %v", err)
  120. }
  121. }