websocket_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607
  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "bytes"
  7. "fmt"
  8. "io"
  9. "log"
  10. "net"
  11. "net/http"
  12. "net/http/httptest"
  13. "net/url"
  14. "reflect"
  15. "runtime"
  16. "strings"
  17. "sync"
  18. "testing"
  19. "time"
  20. )
  // Shared state for the lazily started test server. serverAddr holds the
  // host:port recorded by startServer, and once ensures startServer runs at
  // most once no matter how many tests request it.
  var (
  	serverAddr string
  	once       sync.Once
  )
  23. func echoServer(ws *Conn) {
  24. defer ws.Close()
  25. io.Copy(ws, ws)
  26. }
  // Count is the JSON message exchanged with countServer.
  type Count struct {
  	S string // text that countServer repeats N times in its reply
  	N int    // counter that countServer increments on every message
  }
  31. func countServer(ws *Conn) {
  32. defer ws.Close()
  33. for {
  34. var count Count
  35. err := JSON.Receive(ws, &count)
  36. if err != nil {
  37. return
  38. }
  39. count.N++
  40. count.S = strings.Repeat(count.S, count.N)
  41. err = JSON.Send(ws, count)
  42. if err != nil {
  43. return
  44. }
  45. }
  46. }
  47. type testCtrlAndDataHandler struct {
  48. hybiFrameHandler
  49. }
  50. func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) {
  51. h.hybiFrameHandler.conn.wio.Lock()
  52. defer h.hybiFrameHandler.conn.wio.Unlock()
  53. w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame)
  54. if err != nil {
  55. return 0, err
  56. }
  57. n, err := w.Write(b)
  58. w.Close()
  59. return n, err
  60. }
  61. func ctrlAndDataServer(ws *Conn) {
  62. defer ws.Close()
  63. h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
  64. ws.frameHandler = h
  65. go func() {
  66. for i := 0; ; i++ {
  67. var b []byte
  68. if i%2 != 0 { // with or without payload
  69. b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i))
  70. }
  71. if _, err := h.WritePing(b); err != nil {
  72. break
  73. }
  74. if _, err := h.WritePong(b); err != nil { // unsolicited pong
  75. break
  76. }
  77. time.Sleep(10 * time.Millisecond)
  78. }
  79. }()
  80. b := make([]byte, 128)
  81. for {
  82. n, err := ws.Read(b)
  83. if err != nil {
  84. break
  85. }
  86. if _, err := ws.Write(b[:n]); err != nil {
  87. break
  88. }
  89. }
  90. }
  91. func subProtocolHandshake(config *Config, req *http.Request) error {
  92. for _, proto := range config.Protocol {
  93. if proto == "chat" {
  94. config.Protocol = []string{proto}
  95. return nil
  96. }
  97. }
  98. return ErrBadWebSocketProtocol
  99. }
  100. func subProtoServer(ws *Conn) {
  101. for _, proto := range ws.Config().Protocol {
  102. io.WriteString(ws, proto)
  103. }
  104. }
  105. func startServer() {
  106. http.Handle("/echo", Handler(echoServer))
  107. http.Handle("/count", Handler(countServer))
  108. http.Handle("/ctrldata", Handler(ctrlAndDataServer))
  109. subproto := Server{
  110. Handshake: subProtocolHandshake,
  111. Handler: Handler(subProtoServer),
  112. }
  113. http.Handle("/subproto", subproto)
  114. server := httptest.NewServer(nil)
  115. serverAddr = server.Listener.Addr().String()
  116. log.Print("Test WebSocket server listening on ", serverAddr)
  117. }
  118. func newConfig(t *testing.T, path string) *Config {
  119. config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
  120. return config
  121. }
  122. func TestEcho(t *testing.T) {
  123. once.Do(startServer)
  124. // websocket.Dial()
  125. client, err := net.Dial("tcp", serverAddr)
  126. if err != nil {
  127. t.Fatal("dialing", err)
  128. }
  129. conn, err := NewClient(newConfig(t, "/echo"), client)
  130. if err != nil {
  131. t.Errorf("WebSocket handshake error: %v", err)
  132. return
  133. }
  134. msg := []byte("hello, world\n")
  135. if _, err := conn.Write(msg); err != nil {
  136. t.Errorf("Write: %v", err)
  137. }
  138. var actual_msg = make([]byte, 512)
  139. n, err := conn.Read(actual_msg)
  140. if err != nil {
  141. t.Errorf("Read: %v", err)
  142. }
  143. actual_msg = actual_msg[0:n]
  144. if !bytes.Equal(msg, actual_msg) {
  145. t.Errorf("Echo: expected %q got %q", msg, actual_msg)
  146. }
  147. conn.Close()
  148. }
  149. func TestAddr(t *testing.T) {
  150. once.Do(startServer)
  151. // websocket.Dial()
  152. client, err := net.Dial("tcp", serverAddr)
  153. if err != nil {
  154. t.Fatal("dialing", err)
  155. }
  156. conn, err := NewClient(newConfig(t, "/echo"), client)
  157. if err != nil {
  158. t.Errorf("WebSocket handshake error: %v", err)
  159. return
  160. }
  161. ra := conn.RemoteAddr().String()
  162. if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
  163. t.Errorf("Bad remote addr: %v", ra)
  164. }
  165. la := conn.LocalAddr().String()
  166. if !strings.HasPrefix(la, "http://") {
  167. t.Errorf("Bad local addr: %v", la)
  168. }
  169. conn.Close()
  170. }
  171. func TestCount(t *testing.T) {
  172. once.Do(startServer)
  173. // websocket.Dial()
  174. client, err := net.Dial("tcp", serverAddr)
  175. if err != nil {
  176. t.Fatal("dialing", err)
  177. }
  178. conn, err := NewClient(newConfig(t, "/count"), client)
  179. if err != nil {
  180. t.Errorf("WebSocket handshake error: %v", err)
  181. return
  182. }
  183. var count Count
  184. count.S = "hello"
  185. if err := JSON.Send(conn, count); err != nil {
  186. t.Errorf("Write: %v", err)
  187. }
  188. if err := JSON.Receive(conn, &count); err != nil {
  189. t.Errorf("Read: %v", err)
  190. }
  191. if count.N != 1 {
  192. t.Errorf("count: expected %d got %d", 1, count.N)
  193. }
  194. if count.S != "hello" {
  195. t.Errorf("count: expected %q got %q", "hello", count.S)
  196. }
  197. if err := JSON.Send(conn, count); err != nil {
  198. t.Errorf("Write: %v", err)
  199. }
  200. if err := JSON.Receive(conn, &count); err != nil {
  201. t.Errorf("Read: %v", err)
  202. }
  203. if count.N != 2 {
  204. t.Errorf("count: expected %d got %d", 2, count.N)
  205. }
  206. if count.S != "hellohello" {
  207. t.Errorf("count: expected %q got %q", "hellohello", count.S)
  208. }
  209. conn.Close()
  210. }
  211. func TestWithQuery(t *testing.T) {
  212. once.Do(startServer)
  213. client, err := net.Dial("tcp", serverAddr)
  214. if err != nil {
  215. t.Fatal("dialing", err)
  216. }
  217. config := newConfig(t, "/echo")
  218. config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
  219. if err != nil {
  220. t.Fatal("location url", err)
  221. }
  222. ws, err := NewClient(config, client)
  223. if err != nil {
  224. t.Errorf("WebSocket handshake: %v", err)
  225. return
  226. }
  227. ws.Close()
  228. }
  229. func testWithProtocol(t *testing.T, subproto []string) (string, error) {
  230. once.Do(startServer)
  231. client, err := net.Dial("tcp", serverAddr)
  232. if err != nil {
  233. t.Fatal("dialing", err)
  234. }
  235. config := newConfig(t, "/subproto")
  236. config.Protocol = subproto
  237. ws, err := NewClient(config, client)
  238. if err != nil {
  239. return "", err
  240. }
  241. msg := make([]byte, 16)
  242. n, err := ws.Read(msg)
  243. if err != nil {
  244. return "", err
  245. }
  246. ws.Close()
  247. return string(msg[:n]), nil
  248. }
  249. func TestWithProtocol(t *testing.T) {
  250. proto, err := testWithProtocol(t, []string{"chat"})
  251. if err != nil {
  252. t.Errorf("SubProto: unexpected error: %v", err)
  253. }
  254. if proto != "chat" {
  255. t.Errorf("SubProto: expected %q, got %q", "chat", proto)
  256. }
  257. }
  258. func TestWithTwoProtocol(t *testing.T) {
  259. proto, err := testWithProtocol(t, []string{"test", "chat"})
  260. if err != nil {
  261. t.Errorf("SubProto: unexpected error: %v", err)
  262. }
  263. if proto != "chat" {
  264. t.Errorf("SubProto: expected %q, got %q", "chat", proto)
  265. }
  266. }
  267. func TestWithBadProtocol(t *testing.T) {
  268. _, err := testWithProtocol(t, []string{"test"})
  269. if err != ErrBadStatus {
  270. t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
  271. }
  272. }
  273. func TestHTTP(t *testing.T) {
  274. once.Do(startServer)
  275. // If the client did not send a handshake that matches the protocol
  276. // specification, the server MUST return an HTTP response with an
  277. // appropriate error code (such as 400 Bad Request)
  278. resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
  279. if err != nil {
  280. t.Errorf("Get: error %#v", err)
  281. return
  282. }
  283. if resp == nil {
  284. t.Error("Get: resp is null")
  285. return
  286. }
  287. if resp.StatusCode != http.StatusBadRequest {
  288. t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
  289. }
  290. }
  291. func TestTrailingSpaces(t *testing.T) {
  292. // http://code.google.com/p/go/issues/detail?id=955
  293. // The last runs of this create keys with trailing spaces that should not be
  294. // generated by the client.
  295. once.Do(startServer)
  296. config := newConfig(t, "/echo")
  297. for i := 0; i < 30; i++ {
  298. // body
  299. ws, err := DialConfig(config)
  300. if err != nil {
  301. t.Errorf("Dial #%d failed: %v", i, err)
  302. break
  303. }
  304. ws.Close()
  305. }
  306. }
  307. func TestDialConfigBadVersion(t *testing.T) {
  308. once.Do(startServer)
  309. config := newConfig(t, "/echo")
  310. config.Version = 1234
  311. _, err := DialConfig(config)
  312. if dialerr, ok := err.(*DialError); ok {
  313. if dialerr.Err != ErrBadProtocolVersion {
  314. t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
  315. }
  316. }
  317. }
  318. func TestDialConfigWithDialer(t *testing.T) {
  319. once.Do(startServer)
  320. config := newConfig(t, "/echo")
  321. config.Dialer = &net.Dialer{
  322. Deadline: time.Now().Add(-time.Minute),
  323. }
  324. _, err := DialConfig(config)
  325. dialerr, ok := err.(*DialError)
  326. if !ok {
  327. t.Fatalf("DialError expected, got %#v", err)
  328. }
  329. neterr, ok := dialerr.Err.(*net.OpError)
  330. if !ok {
  331. t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
  332. }
  333. if !neterr.Timeout() {
  334. t.Fatalf("expected timeout error, got %#v", neterr)
  335. }
  336. }
  337. func TestSmallBuffer(t *testing.T) {
  338. // http://code.google.com/p/go/issues/detail?id=1145
  339. // Read should be able to handle reading a fragment of a frame.
  340. once.Do(startServer)
  341. // websocket.Dial()
  342. client, err := net.Dial("tcp", serverAddr)
  343. if err != nil {
  344. t.Fatal("dialing", err)
  345. }
  346. conn, err := NewClient(newConfig(t, "/echo"), client)
  347. if err != nil {
  348. t.Errorf("WebSocket handshake error: %v", err)
  349. return
  350. }
  351. msg := []byte("hello, world\n")
  352. if _, err := conn.Write(msg); err != nil {
  353. t.Errorf("Write: %v", err)
  354. }
  355. var small_msg = make([]byte, 8)
  356. n, err := conn.Read(small_msg)
  357. if err != nil {
  358. t.Errorf("Read: %v", err)
  359. }
  360. if !bytes.Equal(msg[:len(small_msg)], small_msg) {
  361. t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
  362. }
  363. var second_msg = make([]byte, len(msg))
  364. n, err = conn.Read(second_msg)
  365. if err != nil {
  366. t.Errorf("Read: %v", err)
  367. }
  368. second_msg = second_msg[0:n]
  369. if !bytes.Equal(msg[len(small_msg):], second_msg) {
  370. t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
  371. }
  372. conn.Close()
  373. }
  // parseAuthorityTests maps a URL to the host:port that parseAuthority is
  // expected to produce for it.
  var parseAuthorityTests = []struct {
  	in  *url.URL
  	out string
  }{
  	// ws and wss get their default ports (80 and 443) when none is given.
  	{in: &url.URL{Scheme: "ws", Host: "www.google.com"}, out: "www.google.com:80"},
  	{in: &url.URL{Scheme: "wss", Host: "www.google.com"}, out: "www.google.com:443"},
  	{in: &url.URL{Scheme: "ws", Host: "www.google.com:80"}, out: "www.google.com:80"},
  	{in: &url.URL{Scheme: "wss", Host: "www.google.com:443"}, out: "www.google.com:443"},

  	// Schemes that are not valid WebSocket schemes: parseAuthority only acts
  	// on schemes it knows, so the host is returned unchanged with no
  	// default port added.
  	{in: &url.URL{Scheme: "http", Host: "www.google.com"}, out: "www.google.com"},
  	{in: &url.URL{Scheme: "http", Host: "www.google.com:80"}, out: "www.google.com:80"},
  	{in: &url.URL{Scheme: "asdf", Host: "127.0.0.1"}, out: "127.0.0.1"},
  	{in: &url.URL{Scheme: "asdf", Host: "www.google.com"}, out: "www.google.com"},
  }
  437. func TestParseAuthority(t *testing.T) {
  438. for _, tt := range parseAuthorityTests {
  439. out := parseAuthority(tt.in)
  440. if out != tt.out {
  441. t.Errorf("got %v; want %v", out, tt.out)
  442. }
  443. }
  444. }
  // closerConn wraps a net.Conn and counts how many times Close is called on
  // it. Every other method is promoted from the embedded Conn.
  type closerConn struct {
  	net.Conn
  	closed int // number of Close calls seen so far
  }

  // Close increments the call count first and then closes the wrapped
  // connection, returning its error.
  func (c *closerConn) Close() error {
  	c.closed++
  	inner := c.Conn
  	return inner.Close()
  }
  453. func TestClose(t *testing.T) {
  454. if runtime.GOOS == "plan9" {
  455. t.Skip("see golang.org/issue/11454")
  456. }
  457. once.Do(startServer)
  458. conn, err := net.Dial("tcp", serverAddr)
  459. if err != nil {
  460. t.Fatal("dialing", err)
  461. }
  462. cc := closerConn{Conn: conn}
  463. client, err := NewClient(newConfig(t, "/echo"), &cc)
  464. if err != nil {
  465. t.Fatalf("WebSocket handshake: %v", err)
  466. }
  467. // set the deadline to ten minutes ago, which will have expired by the time
  468. // client.Close sends the close status frame.
  469. conn.SetDeadline(time.Now().Add(-10 * time.Minute))
  470. if err := client.Close(); err == nil {
  471. t.Errorf("ws.Close(): expected error, got %v", err)
  472. }
  473. if cc.closed < 1 {
  474. t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
  475. }
  476. }
  // originTests pairs an incoming request with the origin URL that Origin is
  // expected to extract from it.
  var originTests = []struct {
  	req    *http.Request
  	origin *url.URL
  }{
  	// An Origin header is parsed into a URL.
  	{
  		req:    &http.Request{Header: http.Header{"Origin": {"http://www.example.com"}}},
  		origin: &url.URL{Scheme: "http", Host: "www.example.com"},
  	},
  	// A request without any headers is expected to yield a nil origin.
  	{req: &http.Request{}},
  }
  496. func TestOrigin(t *testing.T) {
  497. conf := newConfig(t, "/echo")
  498. conf.Version = ProtocolVersionHybi13
  499. for i, tt := range originTests {
  500. origin, err := Origin(conf, tt.req)
  501. if err != nil {
  502. t.Error(err)
  503. continue
  504. }
  505. if !reflect.DeepEqual(origin, tt.origin) {
  506. t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
  507. continue
  508. }
  509. }
  510. }
  511. func TestCtrlAndData(t *testing.T) {
  512. once.Do(startServer)
  513. c, err := net.Dial("tcp", serverAddr)
  514. if err != nil {
  515. t.Fatal(err)
  516. }
  517. ws, err := NewClient(newConfig(t, "/ctrldata"), c)
  518. if err != nil {
  519. t.Fatal(err)
  520. }
  521. defer ws.Close()
  522. h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
  523. ws.frameHandler = h
  524. b := make([]byte, 128)
  525. for i := 0; i < 2; i++ {
  526. data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i))
  527. if _, err := ws.Write(data); err != nil {
  528. t.Fatalf("#%d: %v", i, err)
  529. }
  530. var ctrl []byte
  531. if i%2 != 0 { // with or without payload
  532. ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i))
  533. }
  534. if _, err := h.WritePing(ctrl); err != nil {
  535. t.Fatalf("#%d: %v", i, err)
  536. }
  537. n, err := ws.Read(b)
  538. if err != nil {
  539. t.Fatalf("#%d: %v", i, err)
  540. }
  541. if !bytes.Equal(b[:n], data) {
  542. t.Fatalf("#%d: got %v; want %v", i, b[:n], data)
  543. }
  544. }
  545. }