conn_test.go 11 KB


  1. // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "bufio"
  7. "bytes"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "io/ioutil"
  12. "net"
  13. "reflect"
  14. "testing"
  15. "testing/iotest"
  16. "time"
  17. )
  18. var _ net.Error = errWriteTimeout
  // fakeNetConn is a stand-in net.Conn for tests. Reads go to the embedded
// io.Reader and writes go to the embedded io.Writer; every other method is a
// no-op that returns nil.
type fakeNetConn struct {
	io.Reader
	io.Writer
}

// Close is a no-op; it does not close the embedded reader or writer.
func (fakeNetConn) Close() error { return nil }

// LocalAddr always reports a nil address.
func (fakeNetConn) LocalAddr() net.Addr { return nil }

// RemoteAddr always reports a nil address.
func (fakeNetConn) RemoteAddr() net.Addr { return nil }

// SetDeadline ignores its argument.
func (fakeNetConn) SetDeadline(time.Time) error { return nil }

// SetReadDeadline ignores its argument.
func (fakeNetConn) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline ignores its argument.
func (fakeNetConn) SetWriteDeadline(time.Time) error { return nil }
  29. func TestFraming(t *testing.T) {
  30. frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
  31. var readChunkers = []struct {
  32. name string
  33. f func(io.Reader) io.Reader
  34. }{
  35. {"half", iotest.HalfReader},
  36. {"one", iotest.OneByteReader},
  37. {"asis", func(r io.Reader) io.Reader { return r }},
  38. }
  39. writeBuf := make([]byte, 65537)
  40. for i := range writeBuf {
  41. writeBuf[i] = byte(i)
  42. }
  43. for _, compress := range []bool{false, true} {
  44. for _, isServer := range []bool{true, false} {
  45. for _, chunker := range readChunkers {
  46. var connBuf bytes.Buffer
  47. wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
  48. rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
  49. if compress {
  50. wc.newCompressionWriter = compressNoContextTakeover
  51. rc.newDecompressionReader = decompressNoContextTakeover
  52. }
  53. for _, n := range frameSizes {
  54. for _, iocopy := range []bool{true, false} {
  55. name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d c:%v", compress, isServer, chunker.name, n, iocopy)
  56. w, err := wc.NextWriter(TextMessage)
  57. if err != nil {
  58. t.Errorf("%s: wc.NextWriter() returned %v", name, err)
  59. continue
  60. }
  61. var nn int
  62. if iocopy {
  63. var n64 int64
  64. n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n]))
  65. nn = int(n64)
  66. } else {
  67. nn, err = w.Write(writeBuf[:n])
  68. }
  69. if err != nil || nn != n {
  70. t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
  71. continue
  72. }
  73. err = w.Close()
  74. if err != nil {
  75. t.Errorf("%s: w.Close() returned %v", name, err)
  76. continue
  77. }
  78. opCode, r, err := rc.NextReader()
  79. if err != nil || opCode != TextMessage {
  80. t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
  81. continue
  82. }
  83. rbuf, err := ioutil.ReadAll(r)
  84. if err != nil {
  85. t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
  86. continue
  87. }
  88. if len(rbuf) != n {
  89. t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
  90. continue
  91. }
  92. for i, b := range rbuf {
  93. if byte(i) != b {
  94. t.Errorf("%s: bad byte at offset %d", name, i)
  95. break
  96. }
  97. }
  98. }
  99. }
  100. }
  101. }
  102. }
  103. }
  104. func TestControl(t *testing.T) {
  105. const message = "this is a ping/pong messsage"
  106. for _, isServer := range []bool{true, false} {
  107. for _, isWriteControl := range []bool{true, false} {
  108. name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
  109. var connBuf bytes.Buffer
  110. wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
  111. rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024)
  112. if isWriteControl {
  113. wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
  114. } else {
  115. w, err := wc.NextWriter(PongMessage)
  116. if err != nil {
  117. t.Errorf("%s: wc.NextWriter() returned %v", name, err)
  118. continue
  119. }
  120. if _, err := w.Write([]byte(message)); err != nil {
  121. t.Errorf("%s: w.Write() returned %v", name, err)
  122. continue
  123. }
  124. if err := w.Close(); err != nil {
  125. t.Errorf("%s: w.Close() returned %v", name, err)
  126. continue
  127. }
  128. var actualMessage string
  129. rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
  130. rc.NextReader()
  131. if actualMessage != message {
  132. t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
  133. continue
  134. }
  135. }
  136. }
  137. }
  138. }
  139. func TestCloseBeforeFinalFrame(t *testing.T) {
  140. const bufSize = 512
  141. expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
  142. var b1, b2 bytes.Buffer
  143. wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
  144. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
  145. w, _ := wc.NextWriter(BinaryMessage)
  146. w.Write(make([]byte, bufSize+bufSize/2))
  147. wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
  148. w.Close()
  149. op, r, err := rc.NextReader()
  150. if op != BinaryMessage || err != nil {
  151. t.Fatalf("NextReader() returned %d, %v", op, err)
  152. }
  153. _, err = io.Copy(ioutil.Discard, r)
  154. if !reflect.DeepEqual(err, expectedErr) {
  155. t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
  156. }
  157. _, _, err = rc.NextReader()
  158. if !reflect.DeepEqual(err, expectedErr) {
  159. t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
  160. }
  161. }
  162. func TestEOFWithinFrame(t *testing.T) {
  163. const bufSize = 64
  164. for n := 0; ; n++ {
  165. var b bytes.Buffer
  166. wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
  167. rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
  168. w, _ := wc.NextWriter(BinaryMessage)
  169. w.Write(make([]byte, bufSize))
  170. w.Close()
  171. if n >= b.Len() {
  172. break
  173. }
  174. b.Truncate(n)
  175. op, r, err := rc.NextReader()
  176. if err == errUnexpectedEOF {
  177. continue
  178. }
  179. if op != BinaryMessage || err != nil {
  180. t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
  181. }
  182. _, err = io.Copy(ioutil.Discard, r)
  183. if err != errUnexpectedEOF {
  184. t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
  185. }
  186. _, _, err = rc.NextReader()
  187. if err != errUnexpectedEOF {
  188. t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
  189. }
  190. }
  191. }
  192. func TestEOFBeforeFinalFrame(t *testing.T) {
  193. const bufSize = 512
  194. var b1, b2 bytes.Buffer
  195. wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
  196. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
  197. w, _ := wc.NextWriter(BinaryMessage)
  198. w.Write(make([]byte, bufSize+bufSize/2))
  199. op, r, err := rc.NextReader()
  200. if op != BinaryMessage || err != nil {
  201. t.Fatalf("NextReader() returned %d, %v", op, err)
  202. }
  203. _, err = io.Copy(ioutil.Discard, r)
  204. if err != errUnexpectedEOF {
  205. t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
  206. }
  207. _, _, err = rc.NextReader()
  208. if err != errUnexpectedEOF {
  209. t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
  210. }
  211. }
  212. func TestReadLimit(t *testing.T) {
  213. const readLimit = 512
  214. message := make([]byte, readLimit+1)
  215. var b1, b2 bytes.Buffer
  216. wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2)
  217. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
  218. rc.SetReadLimit(readLimit)
  219. // Send message at the limit with interleaved pong.
  220. w, _ := wc.NextWriter(BinaryMessage)
  221. w.Write(message[:readLimit-1])
  222. wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
  223. w.Write(message[:1])
  224. w.Close()
  225. // Send message larger than the limit.
  226. wc.WriteMessage(BinaryMessage, message[:readLimit+1])
  227. op, _, err := rc.NextReader()
  228. if op != BinaryMessage || err != nil {
  229. t.Fatalf("1: NextReader() returned %d, %v", op, err)
  230. }
  231. op, r, err := rc.NextReader()
  232. if op != BinaryMessage || err != nil {
  233. t.Fatalf("2: NextReader() returned %d, %v", op, err)
  234. }
  235. _, err = io.Copy(ioutil.Discard, r)
  236. if err != ErrReadLimit {
  237. t.Fatalf("io.Copy() returned %v", err)
  238. }
  239. }
  240. func TestUnderlyingConn(t *testing.T) {
  241. var b1, b2 bytes.Buffer
  242. fc := fakeNetConn{Reader: &b1, Writer: &b2}
  243. c := newConn(fc, true, 1024, 1024)
  244. ul := c.UnderlyingConn()
  245. if ul != fc {
  246. t.Fatalf("Underlying conn is not what it should be.")
  247. }
  248. }
  249. func TestBufioReadBytes(t *testing.T) {
  250. // Test calling bufio.ReadBytes for value longer than read buffer size.
  251. m := make([]byte, 512)
  252. m[len(m)-1] = '\n'
  253. var b1, b2 bytes.Buffer
  254. wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64)
  255. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64)
  256. w, _ := wc.NextWriter(BinaryMessage)
  257. w.Write(m)
  258. w.Close()
  259. op, r, err := rc.NextReader()
  260. if op != BinaryMessage || err != nil {
  261. t.Fatalf("NextReader() returned %d, %v", op, err)
  262. }
  263. br := bufio.NewReader(r)
  264. p, err := br.ReadBytes('\n')
  265. if err != nil {
  266. t.Fatalf("ReadBytes() returned %v", err)
  267. }
  268. if len(p) != len(m) {
  269. t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m))
  270. }
  271. }
  272. var closeErrorTests = []struct {
  273. err error
  274. codes []int
  275. ok bool
  276. }{
  277. {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
  278. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
  279. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
  280. {errors.New("hello"), []int{CloseNormalClosure}, false},
  281. }
  282. func TestCloseError(t *testing.T) {
  283. for _, tt := range closeErrorTests {
  284. ok := IsCloseError(tt.err, tt.codes...)
  285. if ok != tt.ok {
  286. t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
  287. }
  288. }
  289. }
  290. var unexpectedCloseErrorTests = []struct {
  291. err error
  292. codes []int
  293. ok bool
  294. }{
  295. {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
  296. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
  297. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
  298. {errors.New("hello"), []int{CloseNormalClosure}, false},
  299. }
  300. func TestUnexpectedCloseErrors(t *testing.T) {
  301. for _, tt := range unexpectedCloseErrorTests {
  302. ok := IsUnexpectedCloseError(tt.err, tt.codes...)
  303. if ok != tt.ok {
  304. t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
  305. }
  306. }
  307. }
  // blockingWriter is an io.Writer used to hold a write in progress. Write
// closes c1 to announce that it has started, then blocks until c2 is closed.
// Write must be called at most once: a second call would close c1 again and
// panic.
type blockingWriter struct {
	c1, c2 chan struct{}
}

func (bw blockingWriter) Write(p []byte) (int, error) {
	close(bw.c1) // tell the test the write is in progress
	<-bw.c2      // stay blocked until the test releases us
	return len(p), nil
}
  318. func TestConcurrentWritePanic(t *testing.T) {
  319. w := blockingWriter{make(chan struct{}), make(chan struct{})}
  320. c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
  321. go func() {
  322. c.WriteMessage(TextMessage, []byte{})
  323. }()
  324. // wait for goroutine to block in write.
  325. <-w.c1
  326. defer func() {
  327. close(w.c2)
  328. if v := recover(); v != nil {
  329. return
  330. }
  331. }()
  332. c.WriteMessage(TextMessage, []byte{})
  333. t.Fatal("should not get here")
  334. }
  // failingReader is an io.Reader whose every Read returns 0 bytes and io.EOF.
type failingReader struct{}

func (failingReader) Read([]byte) (int, error) {
	return 0, io.EOF
}
  339. func TestFailedConnectionReadPanic(t *testing.T) {
  340. c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024)
  341. defer func() {
  342. if v := recover(); v != nil {
  343. return
  344. }
  345. }()
  346. for i := 0; i < 20000; i++ {
  347. c.ReadMessage()
  348. }
  349. t.Fatal("should not get here")
  350. }