conn_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "bufio"
  7. "bytes"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "io/ioutil"
  12. "net"
  13. "reflect"
  14. "testing"
  15. "testing/iotest"
  16. "time"
  17. )
  18. var _ net.Error = errWriteTimeout
  19. type fakeNetConn struct {
  20. io.Reader
  21. io.Writer
  22. }
  23. func (c fakeNetConn) Close() error { return nil }
  24. func (c fakeNetConn) LocalAddr() net.Addr { return nil }
  25. func (c fakeNetConn) RemoteAddr() net.Addr { return nil }
  26. func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
  27. func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
  28. func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
  29. func TestFraming(t *testing.T) {
  30. frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
  31. var readChunkers = []struct {
  32. name string
  33. f func(io.Reader) io.Reader
  34. }{
  35. {"half", iotest.HalfReader},
  36. {"one", iotest.OneByteReader},
  37. {"asis", func(r io.Reader) io.Reader { return r }},
  38. }
  39. writeBuf := make([]byte, 65537)
  40. for i := range writeBuf {
  41. writeBuf[i] = byte(i)
  42. }
  43. for _, isServer := range []bool{true, false} {
  44. for _, chunker := range readChunkers {
  45. var connBuf bytes.Buffer
  46. wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
  47. rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
  48. for _, n := range frameSizes {
  49. for _, iocopy := range []bool{true, false} {
  50. name := fmt.Sprintf("s:%v, r:%s, n:%d c:%v", isServer, chunker.name, n, iocopy)
  51. w, err := wc.NextWriter(TextMessage)
  52. if err != nil {
  53. t.Errorf("%s: wc.NextWriter() returned %v", name, err)
  54. continue
  55. }
  56. var nn int
  57. if iocopy {
  58. var n64 int64
  59. n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n]))
  60. nn = int(n64)
  61. } else {
  62. nn, err = w.Write(writeBuf[:n])
  63. }
  64. if err != nil || nn != n {
  65. t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
  66. continue
  67. }
  68. err = w.Close()
  69. if err != nil {
  70. t.Errorf("%s: w.Close() returned %v", name, err)
  71. continue
  72. }
  73. opCode, r, err := rc.NextReader()
  74. if err != nil || opCode != TextMessage {
  75. t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
  76. continue
  77. }
  78. rbuf, err := ioutil.ReadAll(r)
  79. if err != nil {
  80. t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
  81. continue
  82. }
  83. if len(rbuf) != n {
  84. t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
  85. continue
  86. }
  87. for i, b := range rbuf {
  88. if byte(i) != b {
  89. t.Errorf("%s: bad byte at offset %d", name, i)
  90. break
  91. }
  92. }
  93. }
  94. }
  95. }
  96. }
  97. }
  98. func TestControl(t *testing.T) {
  99. const message = "this is a ping/pong messsage"
  100. for _, isServer := range []bool{true, false} {
  101. for _, isWriteControl := range []bool{true, false} {
  102. name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
  103. var connBuf bytes.Buffer
  104. wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
  105. rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024)
  106. if isWriteControl {
  107. wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
  108. } else {
  109. w, err := wc.NextWriter(PongMessage)
  110. if err != nil {
  111. t.Errorf("%s: wc.NextWriter() returned %v", name, err)
  112. continue
  113. }
  114. if _, err := w.Write([]byte(message)); err != nil {
  115. t.Errorf("%s: w.Write() returned %v", name, err)
  116. continue
  117. }
  118. if err := w.Close(); err != nil {
  119. t.Errorf("%s: w.Close() returned %v", name, err)
  120. continue
  121. }
  122. var actualMessage string
  123. rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
  124. rc.NextReader()
  125. if actualMessage != message {
  126. t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
  127. continue
  128. }
  129. }
  130. }
  131. }
  132. }
  133. func TestCloseBeforeFinalFrame(t *testing.T) {
  134. const bufSize = 512
  135. expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
  136. var b1, b2 bytes.Buffer
  137. wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
  138. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
  139. w, _ := wc.NextWriter(BinaryMessage)
  140. w.Write(make([]byte, bufSize+bufSize/2))
  141. wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
  142. w.Close()
  143. op, r, err := rc.NextReader()
  144. if op != BinaryMessage || err != nil {
  145. t.Fatalf("NextReader() returned %d, %v", op, err)
  146. }
  147. _, err = io.Copy(ioutil.Discard, r)
  148. if !reflect.DeepEqual(err, expectedErr) {
  149. t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
  150. }
  151. _, _, err = rc.NextReader()
  152. if !reflect.DeepEqual(err, expectedErr) {
  153. t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
  154. }
  155. }
  156. func TestEOFWithinFrame(t *testing.T) {
  157. const bufSize = 64
  158. for n := 0; ; n++ {
  159. var b bytes.Buffer
  160. wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
  161. rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
  162. w, _ := wc.NextWriter(BinaryMessage)
  163. w.Write(make([]byte, bufSize))
  164. w.Close()
  165. if n >= b.Len() {
  166. break
  167. }
  168. b.Truncate(n)
  169. op, r, err := rc.NextReader()
  170. if err == errUnexpectedEOF {
  171. continue
  172. }
  173. if op != BinaryMessage || err != nil {
  174. t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
  175. }
  176. _, err = io.Copy(ioutil.Discard, r)
  177. if err != errUnexpectedEOF {
  178. t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
  179. }
  180. _, _, err = rc.NextReader()
  181. if err != errUnexpectedEOF {
  182. t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
  183. }
  184. }
  185. }
  186. func TestEOFBeforeFinalFrame(t *testing.T) {
  187. const bufSize = 512
  188. var b1, b2 bytes.Buffer
  189. wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
  190. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
  191. w, _ := wc.NextWriter(BinaryMessage)
  192. w.Write(make([]byte, bufSize+bufSize/2))
  193. op, r, err := rc.NextReader()
  194. if op != BinaryMessage || err != nil {
  195. t.Fatalf("NextReader() returned %d, %v", op, err)
  196. }
  197. _, err = io.Copy(ioutil.Discard, r)
  198. if err != errUnexpectedEOF {
  199. t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
  200. }
  201. _, _, err = rc.NextReader()
  202. if err != errUnexpectedEOF {
  203. t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
  204. }
  205. }
  206. func TestReadLimit(t *testing.T) {
  207. const readLimit = 512
  208. message := make([]byte, readLimit+1)
  209. var b1, b2 bytes.Buffer
  210. wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2)
  211. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
  212. rc.SetReadLimit(readLimit)
  213. // Send message at the limit with interleaved pong.
  214. w, _ := wc.NextWriter(BinaryMessage)
  215. w.Write(message[:readLimit-1])
  216. wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
  217. w.Write(message[:1])
  218. w.Close()
  219. // Send message larger than the limit.
  220. wc.WriteMessage(BinaryMessage, message[:readLimit+1])
  221. op, _, err := rc.NextReader()
  222. if op != BinaryMessage || err != nil {
  223. t.Fatalf("1: NextReader() returned %d, %v", op, err)
  224. }
  225. op, r, err := rc.NextReader()
  226. if op != BinaryMessage || err != nil {
  227. t.Fatalf("2: NextReader() returned %d, %v", op, err)
  228. }
  229. _, err = io.Copy(ioutil.Discard, r)
  230. if err != ErrReadLimit {
  231. t.Fatalf("io.Copy() returned %v", err)
  232. }
  233. }
  234. func TestUnderlyingConn(t *testing.T) {
  235. var b1, b2 bytes.Buffer
  236. fc := fakeNetConn{Reader: &b1, Writer: &b2}
  237. c := newConn(fc, true, 1024, 1024)
  238. ul := c.UnderlyingConn()
  239. if ul != fc {
  240. t.Fatalf("Underlying conn is not what it should be.")
  241. }
  242. }
  243. func TestBufioReadBytes(t *testing.T) {
  244. // Test calling bufio.ReadBytes for value longer than read buffer size.
  245. m := make([]byte, 512)
  246. m[len(m)-1] = '\n'
  247. var b1, b2 bytes.Buffer
  248. wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64)
  249. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64)
  250. w, _ := wc.NextWriter(BinaryMessage)
  251. w.Write(m)
  252. w.Close()
  253. op, r, err := rc.NextReader()
  254. if op != BinaryMessage || err != nil {
  255. t.Fatalf("NextReader() returned %d, %v", op, err)
  256. }
  257. br := bufio.NewReader(r)
  258. p, err := br.ReadBytes('\n')
  259. if err != nil {
  260. t.Fatalf("ReadBytes() returned %v", err)
  261. }
  262. if len(p) != len(m) {
  263. t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m))
  264. }
  265. }
  266. var closeErrorTests = []struct {
  267. err error
  268. codes []int
  269. ok bool
  270. }{
  271. {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
  272. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
  273. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
  274. {errors.New("hello"), []int{CloseNormalClosure}, false},
  275. }
  276. func TestCloseError(t *testing.T) {
  277. for _, tt := range closeErrorTests {
  278. ok := IsCloseError(tt.err, tt.codes...)
  279. if ok != tt.ok {
  280. t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
  281. }
  282. }
  283. }
  284. var unexpectedCloseErrorTests = []struct {
  285. err error
  286. codes []int
  287. ok bool
  288. }{
  289. {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
  290. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
  291. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
  292. {errors.New("hello"), []int{CloseNormalClosure}, false},
  293. }
  294. func TestUnexpectedCloseErrors(t *testing.T) {
  295. for _, tt := range unexpectedCloseErrorTests {
  296. ok := IsUnexpectedCloseError(tt.err, tt.codes...)
  297. if ok != tt.ok {
  298. t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
  299. }
  300. }
  301. }
  302. type blockingWriter struct {
  303. c1, c2 chan struct{}
  304. }
  305. func (w blockingWriter) Write(p []byte) (int, error) {
  306. // Allow main to continue
  307. close(w.c1)
  308. // Wait for panic in main
  309. <-w.c2
  310. return len(p), nil
  311. }
  312. func TestConcurrentWritePanic(t *testing.T) {
  313. w := blockingWriter{make(chan struct{}), make(chan struct{})}
  314. c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
  315. go func() {
  316. c.WriteMessage(TextMessage, []byte{})
  317. }()
  318. // wait for goroutine to block in write.
  319. <-w.c1
  320. defer func() {
  321. close(w.c2)
  322. if v := recover(); v != nil {
  323. return
  324. }
  325. }()
  326. c.WriteMessage(TextMessage, []byte{})
  327. t.Fatal("should not get here")
  328. }
  329. type failingReader struct{}
  330. func (r failingReader) Read(p []byte) (int, error) {
  331. return 0, io.EOF
  332. }
  333. func TestFailedConnectionReadPanic(t *testing.T) {
  334. c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024)
  335. defer func() {
  336. if v := recover(); v != nil {
  337. return
  338. }
  339. }()
  340. for i := 0; i < 20000; i++ {
  341. c.ReadMessage()
  342. }
  343. t.Fatal("should not get here")
  344. }