client_server_test.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "crypto/tls"
  7. "crypto/x509"
  8. "io"
  9. "net"
  10. "net/http"
  11. "net/http/httptest"
  12. "net/url"
  13. "reflect"
  14. "testing"
  15. "time"
  16. )
  17. func sendRecv(t *testing.T, ws *Conn) {
  18. const message = "Hello World!"
  19. if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
  20. t.Fatalf("SetWriteDeadline: %v", err)
  21. }
  22. if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
  23. t.Fatalf("WriteMessage: %v", err)
  24. }
  25. if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
  26. t.Fatalf("SetReadDeadline: %v", err)
  27. }
  28. _, p, err := ws.ReadMessage()
  29. if err != nil {
  30. t.Fatalf("ReadMessage: %v", err)
  31. }
  32. if string(p) != message {
  33. t.Fatalf("message=%s, want %s", p, message)
  34. }
  35. }
  // httpToWs converts an httptest server URL into the matching WebSocket URL by
// replacing the leading "http" with "ws": "http://…" becomes "ws://…" and
// "https://…" becomes "wss://…". The first four bytes of u are dropped
// unconditionally, so u is assumed to start with "http"; a shorter string
// panics.
func httpToWs(u string) string {
	const httpScheme = "http"
	rest := u[len(httpScheme):]
	return "ws" + rest
}
  39. var handshakeUpgrader = &Upgrader{
  40. Subprotocols: []string{"p0", "p1"},
  41. ReadBufferSize: 1024,
  42. WriteBufferSize: 1024,
  43. }
  44. var handshakeDialer = &Dialer{
  45. Subprotocols: []string{"p1", "p2"},
  46. ReadBufferSize: 1024,
  47. WriteBufferSize: 1024,
  48. }
  // handshakeHandler is the http.Handler served in TestHandshake. It embeds the
// running test's *testing.T so its ServeHTTP method can log server-side
// problems.
type handshakeHandler struct{ *testing.T }
  52. func (t handshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  53. if r.Method != "GET" {
  54. http.Error(w, "Method not allowed", 405)
  55. t.Logf("method = %s, want GET", r.Method)
  56. return
  57. }
  58. subprotos := Subprotocols(r)
  59. if !reflect.DeepEqual(subprotos, handshakeDialer.Subprotocols) {
  60. http.Error(w, "bad protocol", 400)
  61. t.Logf("Subprotocols = %v, want %v", subprotos, handshakeDialer.Subprotocols)
  62. return
  63. }
  64. ws, err := handshakeUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
  65. if err != nil {
  66. t.Logf("upgrade error: %v", err)
  67. return
  68. }
  69. defer ws.Close()
  70. if ws.Subprotocol() != "p1" {
  71. t.Logf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
  72. return
  73. }
  74. for {
  75. op, r, err := ws.NextReader()
  76. if err != nil {
  77. if err != io.EOF {
  78. t.Logf("NextReader: %v", err)
  79. }
  80. return
  81. }
  82. w, err := ws.NextWriter(op)
  83. if err != nil {
  84. t.Logf("NextWriter: %v", err)
  85. return
  86. }
  87. if _, err = io.Copy(w, r); err != nil {
  88. t.Logf("Copy: %v", err)
  89. return
  90. }
  91. if err := w.Close(); err != nil {
  92. t.Logf("Close: %v", err)
  93. return
  94. }
  95. }
  96. }
  97. func TestHandshake(t *testing.T) {
  98. s := httptest.NewServer(handshakeHandler{t})
  99. defer s.Close()
  100. ws, resp, err := handshakeDialer.Dial(httpToWs(s.URL), http.Header{"Origin": {s.URL}})
  101. if err != nil {
  102. t.Fatalf("Dial: %v", err)
  103. }
  104. defer ws.Close()
  105. var sessionID string
  106. for _, c := range resp.Cookies() {
  107. if c.Name == "sessionID" {
  108. sessionID = c.Value
  109. }
  110. }
  111. if sessionID != "1234" {
  112. t.Error("Set-Cookie not received from the server.")
  113. }
  114. if ws.Subprotocol() != "p1" {
  115. t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
  116. }
  117. sendRecv(t, ws)
  118. }
  // dialHandler is the echo http.Handler used by the Dial tests. It embeds the
// running test's *testing.T so its ServeHTTP method can log server-side
// problems.
type dialHandler struct{ *testing.T }
  122. var dialUpgrader = &Upgrader{
  123. ReadBufferSize: 1024,
  124. WriteBufferSize: 1024,
  125. CheckOrigin: func(r *http.Request) bool { return true },
  126. }
  127. func (t dialHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  128. ws, err := dialUpgrader.Upgrade(w, r, nil)
  129. if err != nil {
  130. t.Logf("upgrade error: %v", err)
  131. return
  132. }
  133. defer ws.Close()
  134. for {
  135. mt, p, err := ws.ReadMessage()
  136. if err != nil {
  137. if err != io.EOF {
  138. t.Logf("ReadMessage: %v", err)
  139. }
  140. return
  141. }
  142. if err := ws.WriteMessage(mt, p); err != nil {
  143. t.Logf("WriteMessage: %v", err)
  144. return
  145. }
  146. }
  147. }
  148. func TestDial(t *testing.T) {
  149. s := httptest.NewServer(dialHandler{t})
  150. defer s.Close()
  151. ws, _, err := DefaultDialer.Dial(httpToWs(s.URL), nil)
  152. if err != nil {
  153. t.Fatalf("Dial() returned error %v", err)
  154. }
  155. defer ws.Close()
  156. sendRecv(t, ws)
  157. }
  158. func TestDialTLS(t *testing.T) {
  159. s := httptest.NewTLSServer(dialHandler{t})
  160. defer s.Close()
  161. certs := x509.NewCertPool()
  162. for _, c := range s.TLS.Certificates {
  163. roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
  164. if err != nil {
  165. t.Fatalf("error parsing server's root cert: %v", err)
  166. }
  167. for _, root := range roots {
  168. certs.AddCert(root)
  169. }
  170. }
  171. u, _ := url.Parse(s.URL)
  172. d := &Dialer{
  173. NetDial: func(network, addr string) (net.Conn, error) { return net.Dial(network, u.Host) },
  174. TLSClientConfig: &tls.Config{RootCAs: certs},
  175. }
  176. ws, _, err := d.Dial("wss://example.com/", nil)
  177. if err != nil {
  178. t.Fatalf("Dial() returned error %v", err)
  179. }
  180. defer ws.Close()
  181. sendRecv(t, ws)
  182. }
  183. func TestDialTLSBadCert(t *testing.T) {
  184. s := httptest.NewTLSServer(dialHandler{t})
  185. defer s.Close()
  186. _, _, err := DefaultDialer.Dial(httpToWs(s.URL), nil)
  187. if err == nil {
  188. t.Fatalf("Dial() did not return error")
  189. }
  190. }
  191. func TestDialTLSNoVerify(t *testing.T) {
  192. s := httptest.NewTLSServer(dialHandler{t})
  193. defer s.Close()
  194. d := &Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
  195. ws, _, err := d.Dial(httpToWs(s.URL), nil)
  196. if err != nil {
  197. t.Fatalf("Dial() returned error %v", err)
  198. }
  199. defer ws.Close()
  200. sendRecv(t, ws)
  201. }
  202. func TestDialTimeout(t *testing.T) {
  203. s := httptest.NewServer(dialHandler{t})
  204. defer s.Close()
  205. d := &Dialer{
  206. HandshakeTimeout: -1,
  207. }
  208. _, _, err := d.Dial(httpToWs(s.URL), nil)
  209. if err == nil {
  210. t.Fatalf("Dial() did not return error")
  211. }
  212. }
  213. func TestDialBadScheme(t *testing.T) {
  214. s := httptest.NewServer(dialHandler{t})
  215. defer s.Close()
  216. _, _, err := DefaultDialer.Dial(s.URL, nil)
  217. if err == nil {
  218. t.Fatalf("Dial() did not return error")
  219. }
  220. }