client_server_test.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. // Copyright 2013 Gary Burd. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "crypto/tls"
  7. "crypto/x509"
  8. "io"
  9. "net"
  10. "net/http"
  11. "net/http/httptest"
  12. "net/url"
  13. "reflect"
  14. "testing"
  15. "time"
  16. )
  17. type handshakeHandler struct {
  18. *testing.T
  19. }
  20. func (t handshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  21. if r.Method != "GET" {
  22. http.Error(w, "Method not allowed", 405)
  23. t.Logf("method = %s, want GET", r.Method)
  24. return
  25. }
  26. if origin := r.Header.Get("Origin"); origin != "http://"+r.Host {
  27. http.Error(w, "Origin not allowed", 403)
  28. t.Logf("Origin = %s, want %s", origin, r.Host)
  29. return
  30. }
  31. subprotos := Subprotocols(r)
  32. if !reflect.DeepEqual(subprotos, handshakeDialer.Subprotocols) {
  33. http.Error(w, "bad protocol", 400)
  34. t.Logf("Subprotocols = %v, want %v", subprotos, handshakeDialer.Subprotocols)
  35. return
  36. }
  37. ws, err := Upgrade(w, r, http.Header{
  38. "Set-Cookie": {"sessionID=1234"},
  39. "Sec-Websocket-Protocol": {subprotos[0]},
  40. }, 1024, 1024)
  41. if _, ok := err.(HandshakeError); ok {
  42. t.Logf("bad handshake: %v", err)
  43. http.Error(w, "Not a websocket handshake", 400)
  44. return
  45. } else if err != nil {
  46. t.Logf("upgrade error: %v", err)
  47. return
  48. }
  49. defer ws.Close()
  50. if ws.Subprotocol() != subprotos[0] {
  51. t.Logf("ws.Subprotocol() = %s, want %s", ws.Subprotocol(), subprotos[0])
  52. return
  53. }
  54. for {
  55. op, r, err := ws.NextReader()
  56. if err != nil {
  57. if err != io.EOF {
  58. t.Logf("NextReader: %v", err)
  59. }
  60. return
  61. }
  62. w, err := ws.NextWriter(op)
  63. if err != nil {
  64. t.Logf("NextWriter: %v", err)
  65. return
  66. }
  67. if _, err = io.Copy(w, r); err != nil {
  68. t.Logf("Copy: %v", err)
  69. return
  70. }
  71. if err := w.Close(); err != nil {
  72. t.Logf("Close: %v", err)
  73. return
  74. }
  75. }
  76. }
  77. var handshakeDialer = &Dialer{
  78. Subprotocols: []string{"p1", "p2"},
  79. ReadBufferSize: 1024,
  80. WriteBufferSize: 1024,
  81. }
  82. func TestHandshake(t *testing.T) {
  83. s := httptest.NewServer(handshakeHandler{t})
  84. defer s.Close()
  85. ws, resp, err := handshakeDialer.Dial(httpToWs(s.URL), http.Header{"Origin": {s.URL}})
  86. if err != nil {
  87. t.Fatalf("Dial: %v", err)
  88. }
  89. defer ws.Close()
  90. var sessionID string
  91. for _, c := range resp.Cookies() {
  92. if c.Name == "sessionID" {
  93. sessionID = c.Value
  94. }
  95. }
  96. if sessionID != "1234" {
  97. t.Error("Set-Cookie not received from the server.")
  98. }
  99. if ws.Subprotocol() != handshakeDialer.Subprotocols[0] {
  100. t.Errorf("ws.Subprotocol() = %s, want %s", ws.Subprotocol(), handshakeDialer.Subprotocols[0])
  101. }
  102. sendRecv(t, ws)
  103. }
  104. type dialHandler struct {
  105. *testing.T
  106. }
  107. func (t dialHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  108. ws, err := Upgrade(w, r, nil, 1024, 1024)
  109. if _, ok := err.(HandshakeError); ok {
  110. t.Logf("bad handshake: %v", err)
  111. http.Error(w, "Not a websocket handshake", 400)
  112. return
  113. } else if err != nil {
  114. t.Logf("upgrade error: %v", err)
  115. return
  116. }
  117. defer ws.Close()
  118. for {
  119. mt, p, err := ws.ReadMessage()
  120. if err != nil {
  121. if err != io.EOF {
  122. t.Logf("ReadMessage: %v", err)
  123. }
  124. return
  125. }
  126. if err := ws.WriteMessage(mt, p); err != nil {
  127. t.Logf("WriteMessage: %v", err)
  128. return
  129. }
  130. }
  131. }
  132. func sendRecv(t *testing.T, ws *Conn) {
  133. const message = "Hello World!"
  134. if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
  135. t.Fatalf("SetWriteDeadline: %v", err)
  136. }
  137. if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
  138. t.Fatalf("WriteMessage: %v", err)
  139. }
  140. if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
  141. t.Fatalf("SetReadDeadline: %v", err)
  142. }
  143. _, p, err := ws.ReadMessage()
  144. if err != nil {
  145. t.Fatalf("ReadMessage: %v", err)
  146. }
  147. if string(p) != message {
  148. t.Fatalf("message=%s, want %s", p, message)
  149. }
  150. }
  // httpToWs converts an http:// or https:// URL, such as the URL of an
// httptest.Server, into the matching ws:// or wss:// URL by replacing the
// leading "http" with "ws". Inputs that do not start with "http" are returned
// unchanged. The previous blind slice panicked on strings shorter than four
// bytes and mangled other schemes.
func httpToWs(u string) string {
	const prefix = "http"
	if len(u) < len(prefix) || u[:len(prefix)] != prefix {
		return u
	}
	return "ws" + u[len(prefix):]
}
  154. func TestDial(t *testing.T) {
  155. s := httptest.NewServer(dialHandler{t})
  156. defer s.Close()
  157. ws, _, err := DefaultDialer.Dial(httpToWs(s.URL), nil)
  158. if err != nil {
  159. t.Fatalf("Dial() returned error %v", err)
  160. }
  161. defer ws.Close()
  162. sendRecv(t, ws)
  163. }
  164. func TestDialTLS(t *testing.T) {
  165. s := httptest.NewTLSServer(dialHandler{t})
  166. defer s.Close()
  167. certs := x509.NewCertPool()
  168. for _, c := range s.TLS.Certificates {
  169. roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
  170. if err != nil {
  171. t.Fatalf("error parsing server's root cert: %v", err)
  172. }
  173. for _, root := range roots {
  174. certs.AddCert(root)
  175. }
  176. }
  177. u, _ := url.Parse(s.URL)
  178. d := &Dialer{
  179. NetDial: func(network, addr string) (net.Conn, error) { return net.Dial(network, u.Host) },
  180. TLSClientConfig: &tls.Config{RootCAs: certs},
  181. }
  182. ws, _, err := d.Dial("wss://example.com/", nil)
  183. if err != nil {
  184. t.Fatalf("Dial() returned error %v", err)
  185. }
  186. defer ws.Close()
  187. sendRecv(t, ws)
  188. }
  189. func TestDialTLSBadCert(t *testing.T) {
  190. s := httptest.NewTLSServer(dialHandler{t})
  191. defer s.Close()
  192. _, _, err := DefaultDialer.Dial(httpToWs(s.URL), nil)
  193. if err == nil {
  194. t.Fatalf("Dial() did not return error")
  195. }
  196. }
  197. func TestDialTLSNoVerify(t *testing.T) {
  198. s := httptest.NewTLSServer(dialHandler{t})
  199. defer s.Close()
  200. d := &Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
  201. ws, _, err := d.Dial(httpToWs(s.URL), nil)
  202. if err != nil {
  203. t.Fatalf("Dial() returned error %v", err)
  204. }
  205. defer ws.Close()
  206. sendRecv(t, ws)
  207. }
  208. func TestDialTimeout(t *testing.T) {
  209. s := httptest.NewServer(dialHandler{t})
  210. defer s.Close()
  211. d := &Dialer{
  212. HandshakeTimeout: -1,
  213. }
  214. _, _, err := d.Dial(httpToWs(s.URL), nil)
  215. if err == nil {
  216. t.Fatalf("Dial() did not return error")
  217. }
  218. }
  219. func TestDialBadScheme(t *testing.T) {
  220. s := httptest.NewServer(dialHandler{t})
  221. defer s.Close()
  222. _, _, err := DefaultDialer.Dial(s.URL, nil)
  223. if err == nil {
  224. t.Fatalf("Dial() did not return error")
  225. }
  226. }