// client_server_test.go (5.6 KB)

  1. // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "crypto/tls"
  7. "crypto/x509"
  8. "io"
  9. "net"
  10. "net/http"
  11. "net/http/httptest"
  12. "net/url"
  13. "reflect"
  14. "testing"
  15. "time"
  16. )
  17. func sendRecv(t *testing.T, ws *Conn) {
  18. const message = "Hello World!"
  19. if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
  20. t.Fatalf("SetWriteDeadline: %v", err)
  21. }
  22. if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
  23. t.Fatalf("WriteMessage: %v", err)
  24. }
  25. if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
  26. t.Fatalf("SetReadDeadline: %v", err)
  27. }
  28. _, p, err := ws.ReadMessage()
  29. if err != nil {
  30. t.Fatalf("ReadMessage: %v", err)
  31. }
  32. if string(p) != message {
  33. t.Fatalf("message=%s, want %s", p, message)
  34. }
  35. }
  // httpToWs converts an http:// or https:// URL, such as httptest.Server.URL,
// into the matching ws:// or wss:// URL by replacing the leading "http" with
// "ws". Input that does not start with "http" (including the empty string)
// is returned unchanged. The helper therefore neither panics on short strings
// nor produces a mangled URL, and the caller's dial reports the problem.
func httpToWs(u string) string {
	const prefix = "http"
	if len(u) < len(prefix) || u[:len(prefix)] != prefix {
		return u
	}
	return "ws" + u[len(prefix):]
}
  39. var handshakeUpgrader = &Upgrader{
  40. Subprotocols: []string{"p0", "p1"},
  41. ReadBufferSize: 1024,
  42. WriteBufferSize: 1024,
  43. }
  44. var handshakeDialer = &Dialer{
  45. Subprotocols: []string{"p1", "p2"},
  46. ReadBufferSize: 1024,
  47. WriteBufferSize: 1024,
  48. }
  49. type handshakeHandler struct {
  50. *testing.T
  51. }
  52. func (t handshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  53. if r.Method != "GET" {
  54. http.Error(w, "Method not allowed", 405)
  55. t.Logf("method = %s, want GET", r.Method)
  56. return
  57. }
  58. subprotos := Subprotocols(r)
  59. if !reflect.DeepEqual(subprotos, handshakeDialer.Subprotocols) {
  60. http.Error(w, "bad protocol", 400)
  61. t.Logf("Subprotocols = %v, want %v", subprotos, handshakeDialer.Subprotocols)
  62. return
  63. }
  64. ws, err := handshakeUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
  65. if err != nil {
  66. t.Logf("upgrade error: %v", err)
  67. return
  68. }
  69. defer ws.Close()
  70. if ws.Subprotocol() != "p1" {
  71. t.Logf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
  72. return
  73. }
  74. for {
  75. op, r, err := ws.NextReader()
  76. if err != nil {
  77. if err != io.EOF {
  78. t.Logf("NextReader: %v", err)
  79. }
  80. return
  81. }
  82. w, err := ws.NextWriter(op)
  83. if err != nil {
  84. t.Logf("NextWriter: %v", err)
  85. return
  86. }
  87. if _, err = io.Copy(w, r); err != nil {
  88. t.Logf("Copy: %v", err)
  89. return
  90. }
  91. if err := w.Close(); err != nil {
  92. t.Logf("Close: %v", err)
  93. return
  94. }
  95. }
  96. }
  97. func TestHandshake(t *testing.T) {
  98. s := httptest.NewServer(handshakeHandler{t})
  99. defer s.Close()
  100. ws, resp, err := handshakeDialer.Dial(httpToWs(s.URL), http.Header{"Origin": {s.URL}})
  101. if err != nil {
  102. t.Fatalf("Dial: %v", err)
  103. }
  104. defer ws.Close()
  105. var sessionID string
  106. for _, c := range resp.Cookies() {
  107. if c.Name == "sessionID" {
  108. sessionID = c.Value
  109. }
  110. }
  111. if sessionID != "1234" {
  112. t.Error("Set-Cookie not received from the server.")
  113. }
  114. if ws.Subprotocol() != "p1" {
  115. t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
  116. }
  117. sendRecv(t, ws)
  118. }
  119. type dialHandler struct {
  120. *testing.T
  121. }
  122. var dialUpgrader = &Upgrader{
  123. ReadBufferSize: 1024,
  124. WriteBufferSize: 1024,
  125. }
  126. func (t dialHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  127. ws, err := dialUpgrader.Upgrade(w, r, nil)
  128. if err != nil {
  129. t.Logf("upgrade error: %v", err)
  130. return
  131. }
  132. defer ws.Close()
  133. for {
  134. mt, p, err := ws.ReadMessage()
  135. if err != nil {
  136. if err != io.EOF {
  137. t.Logf("ReadMessage: %v", err)
  138. }
  139. return
  140. }
  141. if err := ws.WriteMessage(mt, p); err != nil {
  142. t.Logf("WriteMessage: %v", err)
  143. return
  144. }
  145. }
  146. }
  147. func TestDial(t *testing.T) {
  148. s := httptest.NewServer(dialHandler{t})
  149. defer s.Close()
  150. ws, _, err := DefaultDialer.Dial(httpToWs(s.URL), nil)
  151. if err != nil {
  152. t.Fatalf("Dial() returned error %v", err)
  153. }
  154. defer ws.Close()
  155. sendRecv(t, ws)
  156. }
  157. func TestDialTLS(t *testing.T) {
  158. s := httptest.NewTLSServer(dialHandler{t})
  159. defer s.Close()
  160. certs := x509.NewCertPool()
  161. for _, c := range s.TLS.Certificates {
  162. roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
  163. if err != nil {
  164. t.Fatalf("error parsing server's root cert: %v", err)
  165. }
  166. for _, root := range roots {
  167. certs.AddCert(root)
  168. }
  169. }
  170. u, _ := url.Parse(s.URL)
  171. d := &Dialer{
  172. NetDial: func(network, addr string) (net.Conn, error) { return net.Dial(network, u.Host) },
  173. TLSClientConfig: &tls.Config{RootCAs: certs},
  174. }
  175. ws, _, err := d.Dial("wss://example.com/", nil)
  176. if err != nil {
  177. t.Fatalf("Dial() returned error %v", err)
  178. }
  179. defer ws.Close()
  180. sendRecv(t, ws)
  181. }
  182. func TestDialTLSBadCert(t *testing.T) {
  183. s := httptest.NewTLSServer(dialHandler{t})
  184. defer s.Close()
  185. _, _, err := DefaultDialer.Dial(httpToWs(s.URL), nil)
  186. if err == nil {
  187. t.Fatalf("Dial() did not return error")
  188. }
  189. }
  190. func TestDialTLSNoVerify(t *testing.T) {
  191. s := httptest.NewTLSServer(dialHandler{t})
  192. defer s.Close()
  193. d := &Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
  194. ws, _, err := d.Dial(httpToWs(s.URL), nil)
  195. if err != nil {
  196. t.Fatalf("Dial() returned error %v", err)
  197. }
  198. defer ws.Close()
  199. sendRecv(t, ws)
  200. }
  201. func TestDialTimeout(t *testing.T) {
  202. s := httptest.NewServer(dialHandler{t})
  203. defer s.Close()
  204. d := &Dialer{
  205. HandshakeTimeout: -1,
  206. }
  207. _, _, err := d.Dial(httpToWs(s.URL), nil)
  208. if err == nil {
  209. t.Fatalf("Dial() did not return error")
  210. }
  211. }
  212. func TestDialBadScheme(t *testing.T) {
  213. s := httptest.NewServer(dialHandler{t})
  214. defer s.Close()
  215. _, _, err := DefaultDialer.Dial(s.URL, nil)
  216. if err == nil {
  217. t.Fatalf("Dial() did not return error")
  218. }
  219. }
  220. func TestDialBadOrigin(t *testing.T) {
  221. s := httptest.NewServer(dialHandler{t})
  222. defer s.Close()
  223. _, _, err := DefaultDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
  224. if err == nil {
  225. t.Fatalf("Dial() did not return error")
  226. }
  227. }