websocket_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"reflect"
	"runtime"
	"strings"
	"sync"
	"testing"
	"time"
)
  20. var serverAddr string
  21. var once sync.Once
  22. func echoServer(ws *Conn) { io.Copy(ws, ws) }
  23. type Count struct {
  24. S string
  25. N int
  26. }
  27. func countServer(ws *Conn) {
  28. for {
  29. var count Count
  30. err := JSON.Receive(ws, &count)
  31. if err != nil {
  32. return
  33. }
  34. count.N++
  35. count.S = strings.Repeat(count.S, count.N)
  36. err = JSON.Send(ws, count)
  37. if err != nil {
  38. return
  39. }
  40. }
  41. }
  42. func subProtocolHandshake(config *Config, req *http.Request) error {
  43. for _, proto := range config.Protocol {
  44. if proto == "chat" {
  45. config.Protocol = []string{proto}
  46. return nil
  47. }
  48. }
  49. return ErrBadWebSocketProtocol
  50. }
  51. func subProtoServer(ws *Conn) {
  52. for _, proto := range ws.Config().Protocol {
  53. io.WriteString(ws, proto)
  54. }
  55. }
  56. func startServer() {
  57. http.Handle("/echo", Handler(echoServer))
  58. http.Handle("/count", Handler(countServer))
  59. subproto := Server{
  60. Handshake: subProtocolHandshake,
  61. Handler: Handler(subProtoServer),
  62. }
  63. http.Handle("/subproto", subproto)
  64. server := httptest.NewServer(nil)
  65. serverAddr = server.Listener.Addr().String()
  66. log.Print("Test WebSocket server listening on ", serverAddr)
  67. }
  68. func newConfig(t *testing.T, path string) *Config {
  69. config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
  70. return config
  71. }
  72. func TestEcho(t *testing.T) {
  73. once.Do(startServer)
  74. // websocket.Dial()
  75. client, err := net.Dial("tcp", serverAddr)
  76. if err != nil {
  77. t.Fatal("dialing", err)
  78. }
  79. conn, err := NewClient(newConfig(t, "/echo"), client)
  80. if err != nil {
  81. t.Errorf("WebSocket handshake error: %v", err)
  82. return
  83. }
  84. msg := []byte("hello, world\n")
  85. if _, err := conn.Write(msg); err != nil {
  86. t.Errorf("Write: %v", err)
  87. }
  88. var actual_msg = make([]byte, 512)
  89. n, err := conn.Read(actual_msg)
  90. if err != nil {
  91. t.Errorf("Read: %v", err)
  92. }
  93. actual_msg = actual_msg[0:n]
  94. if !bytes.Equal(msg, actual_msg) {
  95. t.Errorf("Echo: expected %q got %q", msg, actual_msg)
  96. }
  97. conn.Close()
  98. }
  99. func TestAddr(t *testing.T) {
  100. once.Do(startServer)
  101. // websocket.Dial()
  102. client, err := net.Dial("tcp", serverAddr)
  103. if err != nil {
  104. t.Fatal("dialing", err)
  105. }
  106. conn, err := NewClient(newConfig(t, "/echo"), client)
  107. if err != nil {
  108. t.Errorf("WebSocket handshake error: %v", err)
  109. return
  110. }
  111. ra := conn.RemoteAddr().String()
  112. if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
  113. t.Errorf("Bad remote addr: %v", ra)
  114. }
  115. la := conn.LocalAddr().String()
  116. if !strings.HasPrefix(la, "http://") {
  117. t.Errorf("Bad local addr: %v", la)
  118. }
  119. conn.Close()
  120. }
  121. func TestCount(t *testing.T) {
  122. once.Do(startServer)
  123. // websocket.Dial()
  124. client, err := net.Dial("tcp", serverAddr)
  125. if err != nil {
  126. t.Fatal("dialing", err)
  127. }
  128. conn, err := NewClient(newConfig(t, "/count"), client)
  129. if err != nil {
  130. t.Errorf("WebSocket handshake error: %v", err)
  131. return
  132. }
  133. var count Count
  134. count.S = "hello"
  135. if err := JSON.Send(conn, count); err != nil {
  136. t.Errorf("Write: %v", err)
  137. }
  138. if err := JSON.Receive(conn, &count); err != nil {
  139. t.Errorf("Read: %v", err)
  140. }
  141. if count.N != 1 {
  142. t.Errorf("count: expected %d got %d", 1, count.N)
  143. }
  144. if count.S != "hello" {
  145. t.Errorf("count: expected %q got %q", "hello", count.S)
  146. }
  147. if err := JSON.Send(conn, count); err != nil {
  148. t.Errorf("Write: %v", err)
  149. }
  150. if err := JSON.Receive(conn, &count); err != nil {
  151. t.Errorf("Read: %v", err)
  152. }
  153. if count.N != 2 {
  154. t.Errorf("count: expected %d got %d", 2, count.N)
  155. }
  156. if count.S != "hellohello" {
  157. t.Errorf("count: expected %q got %q", "hellohello", count.S)
  158. }
  159. conn.Close()
  160. }
  161. func TestWithQuery(t *testing.T) {
  162. once.Do(startServer)
  163. client, err := net.Dial("tcp", serverAddr)
  164. if err != nil {
  165. t.Fatal("dialing", err)
  166. }
  167. config := newConfig(t, "/echo")
  168. config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
  169. if err != nil {
  170. t.Fatal("location url", err)
  171. }
  172. ws, err := NewClient(config, client)
  173. if err != nil {
  174. t.Errorf("WebSocket handshake: %v", err)
  175. return
  176. }
  177. ws.Close()
  178. }
  179. func testWithProtocol(t *testing.T, subproto []string) (string, error) {
  180. once.Do(startServer)
  181. client, err := net.Dial("tcp", serverAddr)
  182. if err != nil {
  183. t.Fatal("dialing", err)
  184. }
  185. config := newConfig(t, "/subproto")
  186. config.Protocol = subproto
  187. ws, err := NewClient(config, client)
  188. if err != nil {
  189. return "", err
  190. }
  191. msg := make([]byte, 16)
  192. n, err := ws.Read(msg)
  193. if err != nil {
  194. return "", err
  195. }
  196. ws.Close()
  197. return string(msg[:n]), nil
  198. }
  199. func TestWithProtocol(t *testing.T) {
  200. proto, err := testWithProtocol(t, []string{"chat"})
  201. if err != nil {
  202. t.Errorf("SubProto: unexpected error: %v", err)
  203. }
  204. if proto != "chat" {
  205. t.Errorf("SubProto: expected %q, got %q", "chat", proto)
  206. }
  207. }
  208. func TestWithTwoProtocol(t *testing.T) {
  209. proto, err := testWithProtocol(t, []string{"test", "chat"})
  210. if err != nil {
  211. t.Errorf("SubProto: unexpected error: %v", err)
  212. }
  213. if proto != "chat" {
  214. t.Errorf("SubProto: expected %q, got %q", "chat", proto)
  215. }
  216. }
  217. func TestWithBadProtocol(t *testing.T) {
  218. _, err := testWithProtocol(t, []string{"test"})
  219. if err != ErrBadStatus {
  220. t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
  221. }
  222. }
  223. func TestHTTP(t *testing.T) {
  224. once.Do(startServer)
  225. // If the client did not send a handshake that matches the protocol
  226. // specification, the server MUST return an HTTP response with an
  227. // appropriate error code (such as 400 Bad Request)
  228. resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
  229. if err != nil {
  230. t.Errorf("Get: error %#v", err)
  231. return
  232. }
  233. if resp == nil {
  234. t.Error("Get: resp is null")
  235. return
  236. }
  237. if resp.StatusCode != http.StatusBadRequest {
  238. t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
  239. }
  240. }
  241. func TestTrailingSpaces(t *testing.T) {
  242. // http://code.google.com/p/go/issues/detail?id=955
  243. // The last runs of this create keys with trailing spaces that should not be
  244. // generated by the client.
  245. once.Do(startServer)
  246. config := newConfig(t, "/echo")
  247. for i := 0; i < 30; i++ {
  248. // body
  249. ws, err := DialConfig(config)
  250. if err != nil {
  251. t.Errorf("Dial #%d failed: %v", i, err)
  252. break
  253. }
  254. ws.Close()
  255. }
  256. }
  257. func TestDialConfigBadVersion(t *testing.T) {
  258. once.Do(startServer)
  259. config := newConfig(t, "/echo")
  260. config.Version = 1234
  261. _, err := DialConfig(config)
  262. if dialerr, ok := err.(*DialError); ok {
  263. if dialerr.Err != ErrBadProtocolVersion {
  264. t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
  265. }
  266. }
  267. }
  268. func TestSmallBuffer(t *testing.T) {
  269. // http://code.google.com/p/go/issues/detail?id=1145
  270. // Read should be able to handle reading a fragment of a frame.
  271. once.Do(startServer)
  272. // websocket.Dial()
  273. client, err := net.Dial("tcp", serverAddr)
  274. if err != nil {
  275. t.Fatal("dialing", err)
  276. }
  277. conn, err := NewClient(newConfig(t, "/echo"), client)
  278. if err != nil {
  279. t.Errorf("WebSocket handshake error: %v", err)
  280. return
  281. }
  282. msg := []byte("hello, world\n")
  283. if _, err := conn.Write(msg); err != nil {
  284. t.Errorf("Write: %v", err)
  285. }
  286. var small_msg = make([]byte, 8)
  287. n, err := conn.Read(small_msg)
  288. if err != nil {
  289. t.Errorf("Read: %v", err)
  290. }
  291. if !bytes.Equal(msg[:len(small_msg)], small_msg) {
  292. t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
  293. }
  294. var second_msg = make([]byte, len(msg))
  295. n, err = conn.Read(second_msg)
  296. if err != nil {
  297. t.Errorf("Read: %v", err)
  298. }
  299. second_msg = second_msg[0:n]
  300. if !bytes.Equal(msg[len(small_msg):], second_msg) {
  301. t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
  302. }
  303. conn.Close()
  304. }
  305. var parseAuthorityTests = []struct {
  306. in *url.URL
  307. out string
  308. }{
  309. {
  310. &url.URL{
  311. Scheme: "ws",
  312. Host: "www.google.com",
  313. },
  314. "www.google.com:80",
  315. },
  316. {
  317. &url.URL{
  318. Scheme: "wss",
  319. Host: "www.google.com",
  320. },
  321. "www.google.com:443",
  322. },
  323. {
  324. &url.URL{
  325. Scheme: "ws",
  326. Host: "www.google.com:80",
  327. },
  328. "www.google.com:80",
  329. },
  330. {
  331. &url.URL{
  332. Scheme: "wss",
  333. Host: "www.google.com:443",
  334. },
  335. "www.google.com:443",
  336. },
  337. // some invalid ones for parseAuthority. parseAuthority doesn't
  338. // concern itself with the scheme unless it actually knows about it
  339. {
  340. &url.URL{
  341. Scheme: "http",
  342. Host: "www.google.com",
  343. },
  344. "www.google.com",
  345. },
  346. {
  347. &url.URL{
  348. Scheme: "http",
  349. Host: "www.google.com:80",
  350. },
  351. "www.google.com:80",
  352. },
  353. {
  354. &url.URL{
  355. Scheme: "asdf",
  356. Host: "127.0.0.1",
  357. },
  358. "127.0.0.1",
  359. },
  360. {
  361. &url.URL{
  362. Scheme: "asdf",
  363. Host: "www.google.com",
  364. },
  365. "www.google.com",
  366. },
  367. }
  368. func TestParseAuthority(t *testing.T) {
  369. for _, tt := range parseAuthorityTests {
  370. out := parseAuthority(tt.in)
  371. if out != tt.out {
  372. t.Errorf("got %v; want %v", out, tt.out)
  373. }
  374. }
  375. }
  376. type closerConn struct {
  377. net.Conn
  378. closed int // count of the number of times Close was called
  379. }
  380. func (c *closerConn) Close() error {
  381. c.closed++
  382. return c.Conn.Close()
  383. }
  384. func TestClose(t *testing.T) {
  385. if runtime.GOOS == "plan9" {
  386. t.Skip("see golang.org/issue/11454")
  387. }
  388. once.Do(startServer)
  389. conn, err := net.Dial("tcp", serverAddr)
  390. if err != nil {
  391. t.Fatal("dialing", err)
  392. }
  393. cc := closerConn{Conn: conn}
  394. client, err := NewClient(newConfig(t, "/echo"), &cc)
  395. if err != nil {
  396. t.Fatalf("WebSocket handshake: %v", err)
  397. }
  398. // set the deadline to ten minutes ago, which will have expired by the time
  399. // client.Close sends the close status frame.
  400. conn.SetDeadline(time.Now().Add(-10 * time.Minute))
  401. if err := client.Close(); err == nil {
  402. t.Errorf("ws.Close(): expected error, got %v", err)
  403. }
  404. if cc.closed < 1 {
  405. t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
  406. }
  407. }
  408. var originTests = []struct {
  409. req *http.Request
  410. origin *url.URL
  411. }{
  412. {
  413. req: &http.Request{
  414. Header: http.Header{
  415. "Origin": []string{"http://www.example.com"},
  416. },
  417. },
  418. origin: &url.URL{
  419. Scheme: "http",
  420. Host: "www.example.com",
  421. },
  422. },
  423. {
  424. req: &http.Request{},
  425. },
  426. }
  427. func TestOrigin(t *testing.T) {
  428. conf := newConfig(t, "/echo")
  429. conf.Version = ProtocolVersionHybi13
  430. for i, tt := range originTests {
  431. origin, err := Origin(conf, tt.req)
  432. if err != nil {
  433. t.Error(err)
  434. continue
  435. }
  436. if !reflect.DeepEqual(origin, tt.origin) {
  437. t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
  438. continue
  439. }
  440. }
  441. }