// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "bytes"
  7. "fmt"
  8. "io"
  9. "log"
  10. "net"
  11. "net/http"
  12. "net/http/httptest"
  13. "net/url"
  14. "reflect"
  15. "runtime"
  16. "strings"
  17. "sync"
  18. "testing"
  19. "time"
  20. )
  21. var serverAddr string
  22. var once sync.Once
  23. func echoServer(ws *Conn) { io.Copy(ws, ws) }
  24. type Count struct {
  25. S string
  26. N int
  27. }
  28. func countServer(ws *Conn) {
  29. for {
  30. var count Count
  31. err := JSON.Receive(ws, &count)
  32. if err != nil {
  33. return
  34. }
  35. count.N++
  36. count.S = strings.Repeat(count.S, count.N)
  37. err = JSON.Send(ws, count)
  38. if err != nil {
  39. return
  40. }
  41. }
  42. }
  43. func subProtocolHandshake(config *Config, req *http.Request) error {
  44. for _, proto := range config.Protocol {
  45. if proto == "chat" {
  46. config.Protocol = []string{proto}
  47. return nil
  48. }
  49. }
  50. return ErrBadWebSocketProtocol
  51. }
  52. func subProtoServer(ws *Conn) {
  53. for _, proto := range ws.Config().Protocol {
  54. io.WriteString(ws, proto)
  55. }
  56. }
  57. func startServer() {
  58. http.Handle("/echo", Handler(echoServer))
  59. http.Handle("/count", Handler(countServer))
  60. subproto := Server{
  61. Handshake: subProtocolHandshake,
  62. Handler: Handler(subProtoServer),
  63. }
  64. http.Handle("/subproto", subproto)
  65. server := httptest.NewServer(nil)
  66. serverAddr = server.Listener.Addr().String()
  67. log.Print("Test WebSocket server listening on ", serverAddr)
  68. }
  69. func newConfig(t *testing.T, path string) *Config {
  70. config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
  71. return config
  72. }
  73. func TestEcho(t *testing.T) {
  74. once.Do(startServer)
  75. // websocket.Dial()
  76. client, err := net.Dial("tcp", serverAddr)
  77. if err != nil {
  78. t.Fatal("dialing", err)
  79. }
  80. conn, err := NewClient(newConfig(t, "/echo"), client)
  81. if err != nil {
  82. t.Errorf("WebSocket handshake error: %v", err)
  83. return
  84. }
  85. msg := []byte("hello, world\n")
  86. if _, err := conn.Write(msg); err != nil {
  87. t.Errorf("Write: %v", err)
  88. }
  89. var actual_msg = make([]byte, 512)
  90. n, err := conn.Read(actual_msg)
  91. if err != nil {
  92. t.Errorf("Read: %v", err)
  93. }
  94. actual_msg = actual_msg[0:n]
  95. if !bytes.Equal(msg, actual_msg) {
  96. t.Errorf("Echo: expected %q got %q", msg, actual_msg)
  97. }
  98. conn.Close()
  99. }
  100. func TestAddr(t *testing.T) {
  101. once.Do(startServer)
  102. // websocket.Dial()
  103. client, err := net.Dial("tcp", serverAddr)
  104. if err != nil {
  105. t.Fatal("dialing", err)
  106. }
  107. conn, err := NewClient(newConfig(t, "/echo"), client)
  108. if err != nil {
  109. t.Errorf("WebSocket handshake error: %v", err)
  110. return
  111. }
  112. ra := conn.RemoteAddr().String()
  113. if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
  114. t.Errorf("Bad remote addr: %v", ra)
  115. }
  116. la := conn.LocalAddr().String()
  117. if !strings.HasPrefix(la, "http://") {
  118. t.Errorf("Bad local addr: %v", la)
  119. }
  120. conn.Close()
  121. }
  122. func TestCount(t *testing.T) {
  123. once.Do(startServer)
  124. // websocket.Dial()
  125. client, err := net.Dial("tcp", serverAddr)
  126. if err != nil {
  127. t.Fatal("dialing", err)
  128. }
  129. conn, err := NewClient(newConfig(t, "/count"), client)
  130. if err != nil {
  131. t.Errorf("WebSocket handshake error: %v", err)
  132. return
  133. }
  134. var count Count
  135. count.S = "hello"
  136. if err := JSON.Send(conn, count); err != nil {
  137. t.Errorf("Write: %v", err)
  138. }
  139. if err := JSON.Receive(conn, &count); err != nil {
  140. t.Errorf("Read: %v", err)
  141. }
  142. if count.N != 1 {
  143. t.Errorf("count: expected %d got %d", 1, count.N)
  144. }
  145. if count.S != "hello" {
  146. t.Errorf("count: expected %q got %q", "hello", count.S)
  147. }
  148. if err := JSON.Send(conn, count); err != nil {
  149. t.Errorf("Write: %v", err)
  150. }
  151. if err := JSON.Receive(conn, &count); err != nil {
  152. t.Errorf("Read: %v", err)
  153. }
  154. if count.N != 2 {
  155. t.Errorf("count: expected %d got %d", 2, count.N)
  156. }
  157. if count.S != "hellohello" {
  158. t.Errorf("count: expected %q got %q", "hellohello", count.S)
  159. }
  160. conn.Close()
  161. }
  162. func TestWithQuery(t *testing.T) {
  163. once.Do(startServer)
  164. client, err := net.Dial("tcp", serverAddr)
  165. if err != nil {
  166. t.Fatal("dialing", err)
  167. }
  168. config := newConfig(t, "/echo")
  169. config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
  170. if err != nil {
  171. t.Fatal("location url", err)
  172. }
  173. ws, err := NewClient(config, client)
  174. if err != nil {
  175. t.Errorf("WebSocket handshake: %v", err)
  176. return
  177. }
  178. ws.Close()
  179. }
  180. func testWithProtocol(t *testing.T, subproto []string) (string, error) {
  181. once.Do(startServer)
  182. client, err := net.Dial("tcp", serverAddr)
  183. if err != nil {
  184. t.Fatal("dialing", err)
  185. }
  186. config := newConfig(t, "/subproto")
  187. config.Protocol = subproto
  188. ws, err := NewClient(config, client)
  189. if err != nil {
  190. return "", err
  191. }
  192. msg := make([]byte, 16)
  193. n, err := ws.Read(msg)
  194. if err != nil {
  195. return "", err
  196. }
  197. ws.Close()
  198. return string(msg[:n]), nil
  199. }
  200. func TestWithProtocol(t *testing.T) {
  201. proto, err := testWithProtocol(t, []string{"chat"})
  202. if err != nil {
  203. t.Errorf("SubProto: unexpected error: %v", err)
  204. }
  205. if proto != "chat" {
  206. t.Errorf("SubProto: expected %q, got %q", "chat", proto)
  207. }
  208. }
  209. func TestWithTwoProtocol(t *testing.T) {
  210. proto, err := testWithProtocol(t, []string{"test", "chat"})
  211. if err != nil {
  212. t.Errorf("SubProto: unexpected error: %v", err)
  213. }
  214. if proto != "chat" {
  215. t.Errorf("SubProto: expected %q, got %q", "chat", proto)
  216. }
  217. }
  218. func TestWithBadProtocol(t *testing.T) {
  219. _, err := testWithProtocol(t, []string{"test"})
  220. if err != ErrBadStatus {
  221. t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
  222. }
  223. }
  224. func TestHTTP(t *testing.T) {
  225. once.Do(startServer)
  226. // If the client did not send a handshake that matches the protocol
  227. // specification, the server MUST return an HTTP response with an
  228. // appropriate error code (such as 400 Bad Request)
  229. resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
  230. if err != nil {
  231. t.Errorf("Get: error %#v", err)
  232. return
  233. }
  234. if resp == nil {
  235. t.Error("Get: resp is null")
  236. return
  237. }
  238. if resp.StatusCode != http.StatusBadRequest {
  239. t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
  240. }
  241. }
  242. func TestTrailingSpaces(t *testing.T) {
  243. // http://code.google.com/p/go/issues/detail?id=955
  244. // The last runs of this create keys with trailing spaces that should not be
  245. // generated by the client.
  246. once.Do(startServer)
  247. config := newConfig(t, "/echo")
  248. for i := 0; i < 30; i++ {
  249. // body
  250. ws, err := DialConfig(config)
  251. if err != nil {
  252. t.Errorf("Dial #%d failed: %v", i, err)
  253. break
  254. }
  255. ws.Close()
  256. }
  257. }
  258. func TestDialConfigBadVersion(t *testing.T) {
  259. once.Do(startServer)
  260. config := newConfig(t, "/echo")
  261. config.Version = 1234
  262. _, err := DialConfig(config)
  263. if dialerr, ok := err.(*DialError); ok {
  264. if dialerr.Err != ErrBadProtocolVersion {
  265. t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
  266. }
  267. }
  268. }
  269. func TestSmallBuffer(t *testing.T) {
  270. // http://code.google.com/p/go/issues/detail?id=1145
  271. // Read should be able to handle reading a fragment of a frame.
  272. once.Do(startServer)
  273. // websocket.Dial()
  274. client, err := net.Dial("tcp", serverAddr)
  275. if err != nil {
  276. t.Fatal("dialing", err)
  277. }
  278. conn, err := NewClient(newConfig(t, "/echo"), client)
  279. if err != nil {
  280. t.Errorf("WebSocket handshake error: %v", err)
  281. return
  282. }
  283. msg := []byte("hello, world\n")
  284. if _, err := conn.Write(msg); err != nil {
  285. t.Errorf("Write: %v", err)
  286. }
  287. var small_msg = make([]byte, 8)
  288. n, err := conn.Read(small_msg)
  289. if err != nil {
  290. t.Errorf("Read: %v", err)
  291. }
  292. if !bytes.Equal(msg[:len(small_msg)], small_msg) {
  293. t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
  294. }
  295. var second_msg = make([]byte, len(msg))
  296. n, err = conn.Read(second_msg)
  297. if err != nil {
  298. t.Errorf("Read: %v", err)
  299. }
  300. second_msg = second_msg[0:n]
  301. if !bytes.Equal(msg[len(small_msg):], second_msg) {
  302. t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
  303. }
  304. conn.Close()
  305. }
  306. var parseAuthorityTests = []struct {
  307. in *url.URL
  308. out string
  309. }{
  310. {
  311. &url.URL{
  312. Scheme: "ws",
  313. Host: "www.google.com",
  314. },
  315. "www.google.com:80",
  316. },
  317. {
  318. &url.URL{
  319. Scheme: "wss",
  320. Host: "www.google.com",
  321. },
  322. "www.google.com:443",
  323. },
  324. {
  325. &url.URL{
  326. Scheme: "ws",
  327. Host: "www.google.com:80",
  328. },
  329. "www.google.com:80",
  330. },
  331. {
  332. &url.URL{
  333. Scheme: "wss",
  334. Host: "www.google.com:443",
  335. },
  336. "www.google.com:443",
  337. },
  338. // some invalid ones for parseAuthority. parseAuthority doesn't
  339. // concern itself with the scheme unless it actually knows about it
  340. {
  341. &url.URL{
  342. Scheme: "http",
  343. Host: "www.google.com",
  344. },
  345. "www.google.com",
  346. },
  347. {
  348. &url.URL{
  349. Scheme: "http",
  350. Host: "www.google.com:80",
  351. },
  352. "www.google.com:80",
  353. },
  354. {
  355. &url.URL{
  356. Scheme: "asdf",
  357. Host: "127.0.0.1",
  358. },
  359. "127.0.0.1",
  360. },
  361. {
  362. &url.URL{
  363. Scheme: "asdf",
  364. Host: "www.google.com",
  365. },
  366. "www.google.com",
  367. },
  368. }
  369. func TestParseAuthority(t *testing.T) {
  370. for _, tt := range parseAuthorityTests {
  371. out := parseAuthority(tt.in)
  372. if out != tt.out {
  373. t.Errorf("got %v; want %v", out, tt.out)
  374. }
  375. }
  376. }
  377. type closerConn struct {
  378. net.Conn
  379. closed int // count of the number of times Close was called
  380. }
  381. func (c *closerConn) Close() error {
  382. c.closed++
  383. return c.Conn.Close()
  384. }
  385. func TestClose(t *testing.T) {
  386. if runtime.GOOS == "plan9" {
  387. t.Skip("see golang.org/issue/11454")
  388. }
  389. once.Do(startServer)
  390. conn, err := net.Dial("tcp", serverAddr)
  391. if err != nil {
  392. t.Fatal("dialing", err)
  393. }
  394. cc := closerConn{Conn: conn}
  395. client, err := NewClient(newConfig(t, "/echo"), &cc)
  396. if err != nil {
  397. t.Fatalf("WebSocket handshake: %v", err)
  398. }
  399. // set the deadline to ten minutes ago, which will have expired by the time
  400. // client.Close sends the close status frame.
  401. conn.SetDeadline(time.Now().Add(-10 * time.Minute))
  402. if err := client.Close(); err == nil {
  403. t.Errorf("ws.Close(): expected error, got %v", err)
  404. }
  405. if cc.closed < 1 {
  406. t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
  407. }
  408. }
  409. var originTests = []struct {
  410. req *http.Request
  411. origin *url.URL
  412. }{
  413. {
  414. req: &http.Request{
  415. Header: http.Header{
  416. "Origin": []string{"http://www.example.com"},
  417. },
  418. },
  419. origin: &url.URL{
  420. Scheme: "http",
  421. Host: "www.example.com",
  422. },
  423. },
  424. {
  425. req: &http.Request{},
  426. },
  427. }
  428. func TestOrigin(t *testing.T) {
  429. conf := newConfig(t, "/echo")
  430. conf.Version = ProtocolVersionHybi13
  431. for i, tt := range originTests {
  432. origin, err := Origin(conf, tt.req)
  433. if err != nil {
  434. t.Error(err)
  435. continue
  436. }
  437. if !reflect.DeepEqual(origin, tt.origin) {
  438. t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
  439. continue
  440. }
  441. }
  442. }