conn_test.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "bufio"
  7. "bytes"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "io/ioutil"
  12. "net"
  13. "reflect"
  14. "sync"
  15. "testing"
  16. "testing/iotest"
  17. "time"
  18. )
  19. var _ net.Error = errWriteTimeout
  20. type fakeNetConn struct {
  21. io.Reader
  22. io.Writer
  23. }
  24. func (c fakeNetConn) Close() error { return nil }
  25. func (c fakeNetConn) LocalAddr() net.Addr { return localAddr }
  26. func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr }
  27. func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
  28. func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
  29. func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
  30. type fakeAddr int
  31. var (
  32. localAddr = fakeAddr(1)
  33. remoteAddr = fakeAddr(2)
  34. )
  35. func (a fakeAddr) Network() string {
  36. return "net"
  37. }
  38. func (a fakeAddr) String() string {
  39. return "str"
  40. }
  41. // newTestConn creates a connnection backed by a fake network connection using
  42. // default values for buffering.
  43. func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
  44. return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil)
  45. }
  46. func TestFraming(t *testing.T) {
  47. frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
  48. var readChunkers = []struct {
  49. name string
  50. f func(io.Reader) io.Reader
  51. }{
  52. {"half", iotest.HalfReader},
  53. {"one", iotest.OneByteReader},
  54. {"asis", func(r io.Reader) io.Reader { return r }},
  55. }
  56. writeBuf := make([]byte, 65537)
  57. for i := range writeBuf {
  58. writeBuf[i] = byte(i)
  59. }
  60. var writers = []struct {
  61. name string
  62. f func(w io.Writer, n int) (int, error)
  63. }{
  64. {"iocopy", func(w io.Writer, n int) (int, error) {
  65. nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
  66. return int(nn), err
  67. }},
  68. {"write", func(w io.Writer, n int) (int, error) {
  69. return w.Write(writeBuf[:n])
  70. }},
  71. {"string", func(w io.Writer, n int) (int, error) {
  72. return io.WriteString(w, string(writeBuf[:n]))
  73. }},
  74. }
  75. for _, compress := range []bool{false, true} {
  76. for _, isServer := range []bool{true, false} {
  77. for _, chunker := range readChunkers {
  78. var connBuf bytes.Buffer
  79. wc := newTestConn(nil, &connBuf, isServer)
  80. rc := newTestConn(chunker.f(&connBuf), nil, !isServer)
  81. if compress {
  82. wc.newCompressionWriter = compressNoContextTakeover
  83. rc.newDecompressionReader = decompressNoContextTakeover
  84. }
  85. for _, n := range frameSizes {
  86. for _, writer := range writers {
  87. name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
  88. w, err := wc.NextWriter(TextMessage)
  89. if err != nil {
  90. t.Errorf("%s: wc.NextWriter() returned %v", name, err)
  91. continue
  92. }
  93. nn, err := writer.f(w, n)
  94. if err != nil || nn != n {
  95. t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
  96. continue
  97. }
  98. err = w.Close()
  99. if err != nil {
  100. t.Errorf("%s: w.Close() returned %v", name, err)
  101. continue
  102. }
  103. opCode, r, err := rc.NextReader()
  104. if err != nil || opCode != TextMessage {
  105. t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
  106. continue
  107. }
  108. rbuf, err := ioutil.ReadAll(r)
  109. if err != nil {
  110. t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
  111. continue
  112. }
  113. if len(rbuf) != n {
  114. t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
  115. continue
  116. }
  117. for i, b := range rbuf {
  118. if byte(i) != b {
  119. t.Errorf("%s: bad byte at offset %d", name, i)
  120. break
  121. }
  122. }
  123. }
  124. }
  125. }
  126. }
  127. }
  128. }
  129. func TestControl(t *testing.T) {
  130. const message = "this is a ping/pong messsage"
  131. for _, isServer := range []bool{true, false} {
  132. for _, isWriteControl := range []bool{true, false} {
  133. name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
  134. var connBuf bytes.Buffer
  135. wc := newTestConn(nil, &connBuf, isServer)
  136. rc := newTestConn(&connBuf, nil, !isServer)
  137. if isWriteControl {
  138. wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
  139. } else {
  140. w, err := wc.NextWriter(PongMessage)
  141. if err != nil {
  142. t.Errorf("%s: wc.NextWriter() returned %v", name, err)
  143. continue
  144. }
  145. if _, err := w.Write([]byte(message)); err != nil {
  146. t.Errorf("%s: w.Write() returned %v", name, err)
  147. continue
  148. }
  149. if err := w.Close(); err != nil {
  150. t.Errorf("%s: w.Close() returned %v", name, err)
  151. continue
  152. }
  153. var actualMessage string
  154. rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
  155. rc.NextReader()
  156. if actualMessage != message {
  157. t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
  158. continue
  159. }
  160. }
  161. }
  162. }
  163. }
  164. // simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool.
  165. type simpleBufferPool struct {
  166. v interface{}
  167. }
  168. func (p *simpleBufferPool) Get() interface{} {
  169. v := p.v
  170. p.v = nil
  171. return v
  172. }
  173. func (p *simpleBufferPool) Put(v interface{}) {
  174. p.v = v
  175. }
  176. func TestWriteBufferPool(t *testing.T) {
  177. var buf bytes.Buffer
  178. var pool simpleBufferPool
  179. wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
  180. rc := newTestConn(&buf, nil, false)
  181. if wc.writeBuf != nil {
  182. t.Fatal("writeBuf not nil after create")
  183. }
  184. // Part 1: test NextWriter/Write/Close
  185. w, err := wc.NextWriter(TextMessage)
  186. if err != nil {
  187. t.Fatalf("wc.NextWriter() returned %v", err)
  188. }
  189. if wc.writeBuf == nil {
  190. t.Fatal("writeBuf is nil after NextWriter")
  191. }
  192. writeBufAddr := &wc.writeBuf[0]
  193. const message = "Hello World!"
  194. if _, err := io.WriteString(w, message); err != nil {
  195. t.Fatalf("io.WriteString(w, message) returned %v", err)
  196. }
  197. if err := w.Close(); err != nil {
  198. t.Fatalf("w.Close() returned %v", err)
  199. }
  200. if wc.writeBuf != nil {
  201. t.Fatal("writeBuf not nil after w.Close()")
  202. }
  203. if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
  204. t.Fatal("writeBuf not returned to pool")
  205. }
  206. opCode, p, err := rc.ReadMessage()
  207. if opCode != TextMessage || err != nil {
  208. t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
  209. }
  210. if s := string(p); s != message {
  211. t.Fatalf("message is %s, want %s", s, message)
  212. }
  213. // Part 2: Test WriteMessage.
  214. if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
  215. t.Fatalf("wc.WriteMessage() returned %v", err)
  216. }
  217. if wc.writeBuf != nil {
  218. t.Fatal("writeBuf not nil after wc.WriteMessage()")
  219. }
  220. if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
  221. t.Fatal("writeBuf not returned to pool after WriteMessage")
  222. }
  223. opCode, p, err = rc.ReadMessage()
  224. if opCode != TextMessage || err != nil {
  225. t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
  226. }
  227. if s := string(p); s != message {
  228. t.Fatalf("message is %s, want %s", s, message)
  229. }
  230. }
  231. func TestWriteBufferPoolSync(t *testing.T) {
  232. var buf bytes.Buffer
  233. var pool sync.Pool
  234. wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
  235. rc := newTestConn(&buf, nil, false)
  236. const message = "Hello World!"
  237. for i := 0; i < 3; i++ {
  238. if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
  239. t.Fatalf("wc.WriteMessage() returned %v", err)
  240. }
  241. opCode, p, err := rc.ReadMessage()
  242. if opCode != TextMessage || err != nil {
  243. t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
  244. }
  245. if s := string(p); s != message {
  246. t.Fatalf("message is %s, want %s", s, message)
  247. }
  248. }
  249. }
  250. func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
  251. const bufSize = 512
  252. expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
  253. var b1, b2 bytes.Buffer
  254. wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
  255. rc := newTestConn(&b1, &b2, true)
  256. w, _ := wc.NextWriter(BinaryMessage)
  257. w.Write(make([]byte, bufSize+bufSize/2))
  258. wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
  259. w.Close()
  260. op, r, err := rc.NextReader()
  261. if op != BinaryMessage || err != nil {
  262. t.Fatalf("NextReader() returned %d, %v", op, err)
  263. }
  264. _, err = io.Copy(ioutil.Discard, r)
  265. if !reflect.DeepEqual(err, expectedErr) {
  266. t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
  267. }
  268. _, _, err = rc.NextReader()
  269. if !reflect.DeepEqual(err, expectedErr) {
  270. t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
  271. }
  272. }
  273. func TestEOFWithinFrame(t *testing.T) {
  274. const bufSize = 64
  275. for n := 0; ; n++ {
  276. var b bytes.Buffer
  277. wc := newTestConn(nil, &b, false)
  278. rc := newTestConn(&b, nil, true)
  279. w, _ := wc.NextWriter(BinaryMessage)
  280. w.Write(make([]byte, bufSize))
  281. w.Close()
  282. if n >= b.Len() {
  283. break
  284. }
  285. b.Truncate(n)
  286. op, r, err := rc.NextReader()
  287. if err == errUnexpectedEOF {
  288. continue
  289. }
  290. if op != BinaryMessage || err != nil {
  291. t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
  292. }
  293. _, err = io.Copy(ioutil.Discard, r)
  294. if err != errUnexpectedEOF {
  295. t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
  296. }
  297. _, _, err = rc.NextReader()
  298. if err != errUnexpectedEOF {
  299. t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
  300. }
  301. }
  302. }
  303. func TestEOFBeforeFinalFrame(t *testing.T) {
  304. const bufSize = 512
  305. var b1, b2 bytes.Buffer
  306. wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
  307. rc := newTestConn(&b1, &b2, true)
  308. w, _ := wc.NextWriter(BinaryMessage)
  309. w.Write(make([]byte, bufSize+bufSize/2))
  310. op, r, err := rc.NextReader()
  311. if op != BinaryMessage || err != nil {
  312. t.Fatalf("NextReader() returned %d, %v", op, err)
  313. }
  314. _, err = io.Copy(ioutil.Discard, r)
  315. if err != errUnexpectedEOF {
  316. t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
  317. }
  318. _, _, err = rc.NextReader()
  319. if err != errUnexpectedEOF {
  320. t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
  321. }
  322. }
  323. func TestWriteAfterMessageWriterClose(t *testing.T) {
  324. wc := newTestConn(nil, &bytes.Buffer{}, false)
  325. w, _ := wc.NextWriter(BinaryMessage)
  326. io.WriteString(w, "hello")
  327. if err := w.Close(); err != nil {
  328. t.Fatalf("unxpected error closing message writer, %v", err)
  329. }
  330. if _, err := io.WriteString(w, "world"); err == nil {
  331. t.Fatalf("no error writing after close")
  332. }
  333. w, _ = wc.NextWriter(BinaryMessage)
  334. io.WriteString(w, "hello")
  335. // close w by getting next writer
  336. _, err := wc.NextWriter(BinaryMessage)
  337. if err != nil {
  338. t.Fatalf("unexpected error getting next writer, %v", err)
  339. }
  340. if _, err := io.WriteString(w, "world"); err == nil {
  341. t.Fatalf("no error writing after close")
  342. }
  343. }
  344. func TestReadLimit(t *testing.T) {
  345. const readLimit = 512
  346. message := make([]byte, readLimit+1)
  347. var b1, b2 bytes.Buffer
  348. wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
  349. rc := newTestConn(&b1, &b2, true)
  350. rc.SetReadLimit(readLimit)
  351. // Send message at the limit with interleaved pong.
  352. w, _ := wc.NextWriter(BinaryMessage)
  353. w.Write(message[:readLimit-1])
  354. wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
  355. w.Write(message[:1])
  356. w.Close()
  357. // Send message larger than the limit.
  358. wc.WriteMessage(BinaryMessage, message[:readLimit+1])
  359. op, _, err := rc.NextReader()
  360. if op != BinaryMessage || err != nil {
  361. t.Fatalf("1: NextReader() returned %d, %v", op, err)
  362. }
  363. op, r, err := rc.NextReader()
  364. if op != BinaryMessage || err != nil {
  365. t.Fatalf("2: NextReader() returned %d, %v", op, err)
  366. }
  367. _, err = io.Copy(ioutil.Discard, r)
  368. if err != ErrReadLimit {
  369. t.Fatalf("io.Copy() returned %v", err)
  370. }
  371. }
  372. func TestAddrs(t *testing.T) {
  373. c := newTestConn(nil, nil, true)
  374. if c.LocalAddr() != localAddr {
  375. t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
  376. }
  377. if c.RemoteAddr() != remoteAddr {
  378. t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
  379. }
  380. }
  381. func TestUnderlyingConn(t *testing.T) {
  382. var b1, b2 bytes.Buffer
  383. fc := fakeNetConn{Reader: &b1, Writer: &b2}
  384. c := newConn(fc, true, 1024, 1024, nil, nil, nil)
  385. ul := c.UnderlyingConn()
  386. if ul != fc {
  387. t.Fatalf("Underlying conn is not what it should be.")
  388. }
  389. }
  390. func TestBufioReadBytes(t *testing.T) {
  391. // Test calling bufio.ReadBytes for value longer than read buffer size.
  392. m := make([]byte, 512)
  393. m[len(m)-1] = '\n'
  394. var b1, b2 bytes.Buffer
  395. wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
  396. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
  397. w, _ := wc.NextWriter(BinaryMessage)
  398. w.Write(m)
  399. w.Close()
  400. op, r, err := rc.NextReader()
  401. if op != BinaryMessage || err != nil {
  402. t.Fatalf("NextReader() returned %d, %v", op, err)
  403. }
  404. br := bufio.NewReader(r)
  405. p, err := br.ReadBytes('\n')
  406. if err != nil {
  407. t.Fatalf("ReadBytes() returned %v", err)
  408. }
  409. if len(p) != len(m) {
  410. t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m))
  411. }
  412. }
  413. var closeErrorTests = []struct {
  414. err error
  415. codes []int
  416. ok bool
  417. }{
  418. {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
  419. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
  420. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
  421. {errors.New("hello"), []int{CloseNormalClosure}, false},
  422. }
  423. func TestCloseError(t *testing.T) {
  424. for _, tt := range closeErrorTests {
  425. ok := IsCloseError(tt.err, tt.codes...)
  426. if ok != tt.ok {
  427. t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
  428. }
  429. }
  430. }
  431. var unexpectedCloseErrorTests = []struct {
  432. err error
  433. codes []int
  434. ok bool
  435. }{
  436. {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
  437. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
  438. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
  439. {errors.New("hello"), []int{CloseNormalClosure}, false},
  440. }
  441. func TestUnexpectedCloseErrors(t *testing.T) {
  442. for _, tt := range unexpectedCloseErrorTests {
  443. ok := IsUnexpectedCloseError(tt.err, tt.codes...)
  444. if ok != tt.ok {
  445. t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
  446. }
  447. }
  448. }
  449. type blockingWriter struct {
  450. c1, c2 chan struct{}
  451. }
  452. func (w blockingWriter) Write(p []byte) (int, error) {
  453. // Allow main to continue
  454. close(w.c1)
  455. // Wait for panic in main
  456. <-w.c2
  457. return len(p), nil
  458. }
  459. func TestConcurrentWritePanic(t *testing.T) {
  460. w := blockingWriter{make(chan struct{}), make(chan struct{})}
  461. c := newTestConn(nil, w, false)
  462. go func() {
  463. c.WriteMessage(TextMessage, []byte{})
  464. }()
  465. // wait for goroutine to block in write.
  466. <-w.c1
  467. defer func() {
  468. close(w.c2)
  469. if v := recover(); v != nil {
  470. return
  471. }
  472. }()
  473. c.WriteMessage(TextMessage, []byte{})
  474. t.Fatal("should not get here")
  475. }
  476. type failingReader struct{}
  477. func (r failingReader) Read(p []byte) (int, error) {
  478. return 0, io.EOF
  479. }
  480. func TestFailedConnectionReadPanic(t *testing.T) {
  481. c := newTestConn(failingReader{}, nil, false)
  482. defer func() {
  483. if v := recover(); v != nil {
  484. return
  485. }
  486. }()
  487. for i := 0; i < 20000; i++ {
  488. c.ReadMessage()
  489. }
  490. t.Fatal("should not get here")
  491. }