websocket_test.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "bytes"
  7. "fmt"
  8. "io"
  9. "log"
  10. "net"
  11. "net/http"
  12. "net/http/httptest"
  13. "net/url"
  14. "strings"
  15. "sync"
  16. "testing"
  17. "time"
  18. )
  // Shared state for the tests in this file. serverAddr holds the host:port of
// the test HTTP server; once guards startServer so that server is created a
// single time no matter how many tests call once.Do(startServer).
var (
	serverAddr string
	once       sync.Once
)
  21. func echoServer(ws *Conn) { io.Copy(ws, ws) }
  // Count is the JSON message exchanged with countServer. The server increments
// N and replies with S repeated N times.
type Count struct {
	S string // text that the server repeats
	N int    // counter incremented by the server on every round trip
}
  26. func countServer(ws *Conn) {
  27. for {
  28. var count Count
  29. err := JSON.Receive(ws, &count)
  30. if err != nil {
  31. return
  32. }
  33. count.N++
  34. count.S = strings.Repeat(count.S, count.N)
  35. err = JSON.Send(ws, count)
  36. if err != nil {
  37. return
  38. }
  39. }
  40. }
  41. func subProtocolHandshake(config *Config, req *http.Request) error {
  42. for _, proto := range config.Protocol {
  43. if proto == "chat" {
  44. config.Protocol = []string{proto}
  45. return nil
  46. }
  47. }
  48. return ErrBadWebSocketProtocol
  49. }
  50. func subProtoServer(ws *Conn) {
  51. for _, proto := range ws.Config().Protocol {
  52. io.WriteString(ws, proto)
  53. }
  54. }
  55. func startServer() {
  56. http.Handle("/echo", Handler(echoServer))
  57. http.Handle("/count", Handler(countServer))
  58. subproto := Server{
  59. Handshake: subProtocolHandshake,
  60. Handler: Handler(subProtoServer),
  61. }
  62. http.Handle("/subproto", subproto)
  63. server := httptest.NewServer(nil)
  64. serverAddr = server.Listener.Addr().String()
  65. log.Print("Test WebSocket server listening on ", serverAddr)
  66. }
  67. func newConfig(t *testing.T, path string) *Config {
  68. config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
  69. return config
  70. }
  71. func TestEcho(t *testing.T) {
  72. once.Do(startServer)
  73. // websocket.Dial()
  74. client, err := net.Dial("tcp", serverAddr)
  75. if err != nil {
  76. t.Fatal("dialing", err)
  77. }
  78. conn, err := NewClient(newConfig(t, "/echo"), client)
  79. if err != nil {
  80. t.Errorf("WebSocket handshake error: %v", err)
  81. return
  82. }
  83. msg := []byte("hello, world\n")
  84. if _, err := conn.Write(msg); err != nil {
  85. t.Errorf("Write: %v", err)
  86. }
  87. var actual_msg = make([]byte, 512)
  88. n, err := conn.Read(actual_msg)
  89. if err != nil {
  90. t.Errorf("Read: %v", err)
  91. }
  92. actual_msg = actual_msg[0:n]
  93. if !bytes.Equal(msg, actual_msg) {
  94. t.Errorf("Echo: expected %q got %q", msg, actual_msg)
  95. }
  96. conn.Close()
  97. }
  98. func TestAddr(t *testing.T) {
  99. once.Do(startServer)
  100. // websocket.Dial()
  101. client, err := net.Dial("tcp", serverAddr)
  102. if err != nil {
  103. t.Fatal("dialing", err)
  104. }
  105. conn, err := NewClient(newConfig(t, "/echo"), client)
  106. if err != nil {
  107. t.Errorf("WebSocket handshake error: %v", err)
  108. return
  109. }
  110. ra := conn.RemoteAddr().String()
  111. if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
  112. t.Errorf("Bad remote addr: %v", ra)
  113. }
  114. la := conn.LocalAddr().String()
  115. if !strings.HasPrefix(la, "http://") {
  116. t.Errorf("Bad local addr: %v", la)
  117. }
  118. conn.Close()
  119. }
  120. func TestCount(t *testing.T) {
  121. once.Do(startServer)
  122. // websocket.Dial()
  123. client, err := net.Dial("tcp", serverAddr)
  124. if err != nil {
  125. t.Fatal("dialing", err)
  126. }
  127. conn, err := NewClient(newConfig(t, "/count"), client)
  128. if err != nil {
  129. t.Errorf("WebSocket handshake error: %v", err)
  130. return
  131. }
  132. var count Count
  133. count.S = "hello"
  134. if err := JSON.Send(conn, count); err != nil {
  135. t.Errorf("Write: %v", err)
  136. }
  137. if err := JSON.Receive(conn, &count); err != nil {
  138. t.Errorf("Read: %v", err)
  139. }
  140. if count.N != 1 {
  141. t.Errorf("count: expected %d got %d", 1, count.N)
  142. }
  143. if count.S != "hello" {
  144. t.Errorf("count: expected %q got %q", "hello", count.S)
  145. }
  146. if err := JSON.Send(conn, count); err != nil {
  147. t.Errorf("Write: %v", err)
  148. }
  149. if err := JSON.Receive(conn, &count); err != nil {
  150. t.Errorf("Read: %v", err)
  151. }
  152. if count.N != 2 {
  153. t.Errorf("count: expected %d got %d", 2, count.N)
  154. }
  155. if count.S != "hellohello" {
  156. t.Errorf("count: expected %q got %q", "hellohello", count.S)
  157. }
  158. conn.Close()
  159. }
  160. func TestWithQuery(t *testing.T) {
  161. once.Do(startServer)
  162. client, err := net.Dial("tcp", serverAddr)
  163. if err != nil {
  164. t.Fatal("dialing", err)
  165. }
  166. config := newConfig(t, "/echo")
  167. config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
  168. if err != nil {
  169. t.Fatal("location url", err)
  170. }
  171. ws, err := NewClient(config, client)
  172. if err != nil {
  173. t.Errorf("WebSocket handshake: %v", err)
  174. return
  175. }
  176. ws.Close()
  177. }
  178. func testWithProtocol(t *testing.T, subproto []string) (string, error) {
  179. once.Do(startServer)
  180. client, err := net.Dial("tcp", serverAddr)
  181. if err != nil {
  182. t.Fatal("dialing", err)
  183. }
  184. config := newConfig(t, "/subproto")
  185. config.Protocol = subproto
  186. ws, err := NewClient(config, client)
  187. if err != nil {
  188. return "", err
  189. }
  190. msg := make([]byte, 16)
  191. n, err := ws.Read(msg)
  192. if err != nil {
  193. return "", err
  194. }
  195. ws.Close()
  196. return string(msg[:n]), nil
  197. }
  198. func TestWithProtocol(t *testing.T) {
  199. proto, err := testWithProtocol(t, []string{"chat"})
  200. if err != nil {
  201. t.Errorf("SubProto: unexpected error: %v", err)
  202. }
  203. if proto != "chat" {
  204. t.Errorf("SubProto: expected %q, got %q", "chat", proto)
  205. }
  206. }
  207. func TestWithTwoProtocol(t *testing.T) {
  208. proto, err := testWithProtocol(t, []string{"test", "chat"})
  209. if err != nil {
  210. t.Errorf("SubProto: unexpected error: %v", err)
  211. }
  212. if proto != "chat" {
  213. t.Errorf("SubProto: expected %q, got %q", "chat", proto)
  214. }
  215. }
  216. func TestWithBadProtocol(t *testing.T) {
  217. _, err := testWithProtocol(t, []string{"test"})
  218. if err != ErrBadStatus {
  219. t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
  220. }
  221. }
  222. func TestHTTP(t *testing.T) {
  223. once.Do(startServer)
  224. // If the client did not send a handshake that matches the protocol
  225. // specification, the server MUST return an HTTP response with an
  226. // appropriate error code (such as 400 Bad Request)
  227. resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
  228. if err != nil {
  229. t.Errorf("Get: error %#v", err)
  230. return
  231. }
  232. if resp == nil {
  233. t.Error("Get: resp is null")
  234. return
  235. }
  236. if resp.StatusCode != http.StatusBadRequest {
  237. t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
  238. }
  239. }
  240. func TestTrailingSpaces(t *testing.T) {
  241. // http://code.google.com/p/go/issues/detail?id=955
  242. // The last runs of this create keys with trailing spaces that should not be
  243. // generated by the client.
  244. once.Do(startServer)
  245. config := newConfig(t, "/echo")
  246. for i := 0; i < 30; i++ {
  247. // body
  248. ws, err := DialConfig(config)
  249. if err != nil {
  250. t.Errorf("Dial #%d failed: %v", i, err)
  251. break
  252. }
  253. ws.Close()
  254. }
  255. }
  256. func TestDialConfigBadVersion(t *testing.T) {
  257. once.Do(startServer)
  258. config := newConfig(t, "/echo")
  259. config.Version = 1234
  260. _, err := DialConfig(config)
  261. if dialerr, ok := err.(*DialError); ok {
  262. if dialerr.Err != ErrBadProtocolVersion {
  263. t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
  264. }
  265. }
  266. }
  267. func TestSmallBuffer(t *testing.T) {
  268. // http://code.google.com/p/go/issues/detail?id=1145
  269. // Read should be able to handle reading a fragment of a frame.
  270. once.Do(startServer)
  271. // websocket.Dial()
  272. client, err := net.Dial("tcp", serverAddr)
  273. if err != nil {
  274. t.Fatal("dialing", err)
  275. }
  276. conn, err := NewClient(newConfig(t, "/echo"), client)
  277. if err != nil {
  278. t.Errorf("WebSocket handshake error: %v", err)
  279. return
  280. }
  281. msg := []byte("hello, world\n")
  282. if _, err := conn.Write(msg); err != nil {
  283. t.Errorf("Write: %v", err)
  284. }
  285. var small_msg = make([]byte, 8)
  286. n, err := conn.Read(small_msg)
  287. if err != nil {
  288. t.Errorf("Read: %v", err)
  289. }
  290. if !bytes.Equal(msg[:len(small_msg)], small_msg) {
  291. t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
  292. }
  293. var second_msg = make([]byte, len(msg))
  294. n, err = conn.Read(second_msg)
  295. if err != nil {
  296. t.Errorf("Read: %v", err)
  297. }
  298. second_msg = second_msg[0:n]
  299. if !bytes.Equal(msg[len(small_msg):], second_msg) {
  300. t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
  301. }
  302. conn.Close()
  303. }
  // parseAuthorityTests maps an input URL to the authority string that
// parseAuthority is expected to return for it.
var parseAuthorityTests = []struct {
	in  *url.URL
	out string
}{
	// Known WebSocket schemes: the default port is filled in when missing.
	{in: &url.URL{Scheme: "ws", Host: "www.google.com"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "wss", Host: "www.google.com"}, out: "www.google.com:443"},
	{in: &url.URL{Scheme: "ws", Host: "www.google.com:80"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "wss", Host: "www.google.com:443"}, out: "www.google.com:443"},

	// Inputs parseAuthority does not treat as WebSocket URLs. It only looks at
	// the scheme when it recognizes it, so Host is returned unchanged.
	{in: &url.URL{Scheme: "http", Host: "www.google.com"}, out: "www.google.com"},
	{in: &url.URL{Scheme: "http", Host: "www.google.com:80"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "asdf", Host: "127.0.0.1"}, out: "127.0.0.1"},
	{in: &url.URL{Scheme: "asdf", Host: "www.google.com"}, out: "www.google.com"},
}
  367. func TestParseAuthority(t *testing.T) {
  368. for _, tt := range parseAuthorityTests {
  369. out := parseAuthority(tt.in)
  370. if out != tt.out {
  371. t.Errorf("got %v; want %v", out, tt.out)
  372. }
  373. }
  374. }
  // closerConn wraps a net.Conn and counts how many times Close is called, so a
// test can verify that the wrapped connection really was closed.
type closerConn struct {
	net.Conn
	closed int // number of Close calls so far
}

// Close records the call, then closes the wrapped connection and returns the
// error it reports.
func (c *closerConn) Close() error {
	c.closed++
	inner := c.Conn
	return inner.Close()
}
  383. func TestClose(t *testing.T) {
  384. once.Do(startServer)
  385. conn, err := net.Dial("tcp", serverAddr)
  386. if err != nil {
  387. t.Fatal("dialing", err)
  388. }
  389. cc := closerConn{Conn: conn}
  390. client, err := NewClient(newConfig(t, "/echo"), &cc)
  391. if err != nil {
  392. t.Fatalf("WebSocket handshake: %v", err)
  393. }
  394. // set the deadline to ten minutes ago, which will have expired by the time
  395. // client.Close sends the close status frame.
  396. conn.SetDeadline(time.Now().Add(-10 * time.Minute))
  397. if err := client.Close(); err == nil {
  398. t.Errorf("ws.Close(): expected error, got %v", err)
  399. }
  400. if cc.closed < 1 {
  401. t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
  402. }
  403. }