client_server_test.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "crypto/tls"
  7. "crypto/x509"
  8. "encoding/base64"
  9. "io"
  10. "io/ioutil"
  11. "net"
  12. "net/http"
  13. "net/http/httptest"
  14. "net/url"
  15. "reflect"
  16. "strings"
  17. "testing"
  18. "time"
  19. )
  20. var cstUpgrader = Upgrader{
  21. Subprotocols: []string{"p0", "p1"},
  22. ReadBufferSize: 1024,
  23. WriteBufferSize: 1024,
  24. Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
  25. http.Error(w, reason.Error(), status)
  26. },
  27. }
  28. var cstDialer = Dialer{
  29. Subprotocols: []string{"p1", "p2"},
  30. ReadBufferSize: 1024,
  31. WriteBufferSize: 1024,
  32. }
  33. type cstHandler struct{ *testing.T }
  34. type cstServer struct {
  35. *httptest.Server
  36. URL string
  37. }
  38. const (
  39. cstPath = "/a/b"
  40. cstRawQuery = "x=y"
  41. cstRequestURI = cstPath + "?" + cstRawQuery
  42. )
  43. func newServer(t *testing.T) *cstServer {
  44. var s cstServer
  45. s.Server = httptest.NewServer(cstHandler{t})
  46. s.Server.URL += cstRequestURI
  47. s.URL = makeWsProto(s.Server.URL)
  48. return &s
  49. }
  50. func newTLSServer(t *testing.T) *cstServer {
  51. var s cstServer
  52. s.Server = httptest.NewTLSServer(cstHandler{t})
  53. s.Server.URL += cstRequestURI
  54. s.URL = makeWsProto(s.Server.URL)
  55. return &s
  56. }
  57. func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  58. if r.URL.Path != cstPath {
  59. t.Logf("path=%v, want %v", r.URL.Path, cstPath)
  60. http.Error(w, "bad path", 400)
  61. return
  62. }
  63. if r.URL.RawQuery != cstRawQuery {
  64. t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
  65. http.Error(w, "bad path", 400)
  66. return
  67. }
  68. subprotos := Subprotocols(r)
  69. if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
  70. t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
  71. http.Error(w, "bad protocol", 400)
  72. return
  73. }
  74. ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
  75. if err != nil {
  76. t.Logf("Upgrade: %v", err)
  77. return
  78. }
  79. defer ws.Close()
  80. if ws.Subprotocol() != "p1" {
  81. t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol())
  82. ws.Close()
  83. return
  84. }
  85. op, rd, err := ws.NextReader()
  86. if err != nil {
  87. t.Logf("NextReader: %v", err)
  88. return
  89. }
  90. wr, err := ws.NextWriter(op)
  91. if err != nil {
  92. t.Logf("NextWriter: %v", err)
  93. return
  94. }
  95. if _, err = io.Copy(wr, rd); err != nil {
  96. t.Logf("NextWriter: %v", err)
  97. return
  98. }
  99. if err := wr.Close(); err != nil {
  100. t.Logf("Close: %v", err)
  101. return
  102. }
  103. }
  104. func makeWsProto(s string) string {
  105. return "ws" + strings.TrimPrefix(s, "http")
  106. }
  107. func sendRecv(t *testing.T, ws *Conn) {
  108. const message = "Hello World!"
  109. if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
  110. t.Fatalf("SetWriteDeadline: %v", err)
  111. }
  112. if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
  113. t.Fatalf("WriteMessage: %v", err)
  114. }
  115. if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
  116. t.Fatalf("SetReadDeadline: %v", err)
  117. }
  118. _, p, err := ws.ReadMessage()
  119. if err != nil {
  120. t.Fatalf("ReadMessage: %v", err)
  121. }
  122. if string(p) != message {
  123. t.Fatalf("message=%s, want %s", p, message)
  124. }
  125. }
  126. func TestProxyDial(t *testing.T) {
  127. s := newServer(t)
  128. defer s.Close()
  129. surl, _ := url.Parse(s.URL)
  130. cstDialer.Proxy = http.ProxyURL(surl)
  131. connect := false
  132. origHandler := s.Server.Config.Handler
  133. // Capture the request Host header.
  134. s.Server.Config.Handler = http.HandlerFunc(
  135. func(w http.ResponseWriter, r *http.Request) {
  136. if r.Method == "CONNECT" {
  137. connect = true
  138. w.WriteHeader(200)
  139. return
  140. }
  141. if !connect {
  142. t.Log("connect not recieved")
  143. http.Error(w, "connect not recieved", 405)
  144. return
  145. }
  146. origHandler.ServeHTTP(w, r)
  147. })
  148. ws, _, err := cstDialer.Dial(s.URL, nil)
  149. if err != nil {
  150. t.Fatalf("Dial: %v", err)
  151. }
  152. defer ws.Close()
  153. sendRecv(t, ws)
  154. cstDialer.Proxy = http.ProxyFromEnvironment
  155. }
  156. func TestProxyAuthorizationDial(t *testing.T) {
  157. s := newServer(t)
  158. defer s.Close()
  159. surl, _ := url.Parse(s.URL)
  160. surl.User = url.UserPassword("username", "password")
  161. cstDialer.Proxy = http.ProxyURL(surl)
  162. connect := false
  163. origHandler := s.Server.Config.Handler
  164. // Capture the request Host header.
  165. s.Server.Config.Handler = http.HandlerFunc(
  166. func(w http.ResponseWriter, r *http.Request) {
  167. proxyAuth := r.Header.Get("Proxy-Authorization")
  168. expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
  169. if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth {
  170. connect = true
  171. w.WriteHeader(200)
  172. return
  173. }
  174. if !connect {
  175. t.Log("connect with proxy authorization not recieved")
  176. http.Error(w, "connect with proxy authorization not recieved", 405)
  177. return
  178. }
  179. origHandler.ServeHTTP(w, r)
  180. })
  181. ws, _, err := cstDialer.Dial(s.URL, nil)
  182. if err != nil {
  183. t.Fatalf("Dial: %v", err)
  184. }
  185. defer ws.Close()
  186. sendRecv(t, ws)
  187. cstDialer.Proxy = http.ProxyFromEnvironment
  188. }
  189. func TestDial(t *testing.T) {
  190. s := newServer(t)
  191. defer s.Close()
  192. ws, _, err := cstDialer.Dial(s.URL, nil)
  193. if err != nil {
  194. t.Fatalf("Dial: %v", err)
  195. }
  196. defer ws.Close()
  197. sendRecv(t, ws)
  198. }
  199. func TestDialTLS(t *testing.T) {
  200. s := newTLSServer(t)
  201. defer s.Close()
  202. certs := x509.NewCertPool()
  203. for _, c := range s.TLS.Certificates {
  204. roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
  205. if err != nil {
  206. t.Fatalf("error parsing server's root cert: %v", err)
  207. }
  208. for _, root := range roots {
  209. certs.AddCert(root)
  210. }
  211. }
  212. u, _ := url.Parse(s.URL)
  213. d := cstDialer
  214. d.NetDial = func(network, addr string) (net.Conn, error) { return net.Dial(network, u.Host) }
  215. d.TLSClientConfig = &tls.Config{RootCAs: certs}
  216. ws, _, err := d.Dial("wss://example.com"+cstRequestURI, nil)
  217. if err != nil {
  218. t.Fatalf("Dial: %v", err)
  219. }
  220. defer ws.Close()
  221. sendRecv(t, ws)
  222. }
  223. func xTestDialTLSBadCert(t *testing.T) {
  224. // This test is deactivated because of noisy logging from the net/http package.
  225. s := newTLSServer(t)
  226. defer s.Close()
  227. ws, _, err := cstDialer.Dial(s.URL, nil)
  228. if err == nil {
  229. ws.Close()
  230. t.Fatalf("Dial: nil")
  231. }
  232. }
  233. func xTestDialTLSNoVerify(t *testing.T) {
  234. s := newTLSServer(t)
  235. defer s.Close()
  236. d := cstDialer
  237. d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
  238. ws, _, err := d.Dial(s.URL, nil)
  239. if err != nil {
  240. t.Fatalf("Dial: %v", err)
  241. }
  242. defer ws.Close()
  243. sendRecv(t, ws)
  244. }
  245. func TestDialTimeout(t *testing.T) {
  246. s := newServer(t)
  247. defer s.Close()
  248. d := cstDialer
  249. d.HandshakeTimeout = -1
  250. ws, _, err := d.Dial(s.URL, nil)
  251. if err == nil {
  252. ws.Close()
  253. t.Fatalf("Dial: nil")
  254. }
  255. }
  256. func TestDialBadScheme(t *testing.T) {
  257. s := newServer(t)
  258. defer s.Close()
  259. ws, _, err := cstDialer.Dial(s.Server.URL, nil)
  260. if err == nil {
  261. ws.Close()
  262. t.Fatalf("Dial: nil")
  263. }
  264. }
  265. func TestDialBadOrigin(t *testing.T) {
  266. s := newServer(t)
  267. defer s.Close()
  268. ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
  269. if err == nil {
  270. ws.Close()
  271. t.Fatalf("Dial: nil")
  272. }
  273. if resp == nil {
  274. t.Fatalf("resp=nil, err=%v", err)
  275. }
  276. if resp.StatusCode != http.StatusForbidden {
  277. t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden)
  278. }
  279. }
  280. func TestDialBadHeader(t *testing.T) {
  281. s := newServer(t)
  282. defer s.Close()
  283. for _, k := range []string{"Upgrade",
  284. "Connection",
  285. "Sec-Websocket-Key",
  286. "Sec-Websocket-Version",
  287. "Sec-Websocket-Protocol"} {
  288. h := http.Header{}
  289. h.Set(k, "bad")
  290. ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
  291. if err == nil {
  292. ws.Close()
  293. t.Errorf("Dial with header %s returned nil", k)
  294. }
  295. }
  296. }
  297. func TestBadMethod(t *testing.T) {
  298. s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  299. ws, err := cstUpgrader.Upgrade(w, r, nil)
  300. if err == nil {
  301. t.Errorf("handshake succeeded, expect fail")
  302. ws.Close()
  303. }
  304. }))
  305. defer s.Close()
  306. resp, err := http.PostForm(s.URL, url.Values{})
  307. if err != nil {
  308. t.Fatalf("PostForm returned error %v", err)
  309. }
  310. resp.Body.Close()
  311. if resp.StatusCode != http.StatusMethodNotAllowed {
  312. t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
  313. }
  314. }
  315. func TestHandshake(t *testing.T) {
  316. s := newServer(t)
  317. defer s.Close()
  318. ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}})
  319. if err != nil {
  320. t.Fatalf("Dial: %v", err)
  321. }
  322. defer ws.Close()
  323. var sessionID string
  324. for _, c := range resp.Cookies() {
  325. if c.Name == "sessionID" {
  326. sessionID = c.Value
  327. }
  328. }
  329. if sessionID != "1234" {
  330. t.Error("Set-Cookie not received from the server.")
  331. }
  332. if ws.Subprotocol() != "p1" {
  333. t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
  334. }
  335. sendRecv(t, ws)
  336. }
  337. func TestRespOnBadHandshake(t *testing.T) {
  338. const expectedStatus = http.StatusGone
  339. const expectedBody = "This is the response body."
  340. s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  341. w.WriteHeader(expectedStatus)
  342. io.WriteString(w, expectedBody)
  343. }))
  344. defer s.Close()
  345. ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
  346. if err == nil {
  347. ws.Close()
  348. t.Fatalf("Dial: nil")
  349. }
  350. if resp == nil {
  351. t.Fatalf("resp=nil, err=%v", err)
  352. }
  353. if resp.StatusCode != expectedStatus {
  354. t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
  355. }
  356. p, err := ioutil.ReadAll(resp.Body)
  357. if err != nil {
  358. t.Fatalf("ReadFull(resp.Body) returned error %v", err)
  359. }
  360. if string(p) != expectedBody {
  361. t.Errorf("resp.Body=%s, want %s", p, expectedBody)
  362. }
  363. }
  364. // TestHostHeader confirms that the host header provided in the call to Dial is
  365. // sent to the server.
  366. func TestHostHeader(t *testing.T) {
  367. s := newServer(t)
  368. defer s.Close()
  369. specifiedHost := make(chan string, 1)
  370. origHandler := s.Server.Config.Handler
  371. // Capture the request Host header.
  372. s.Server.Config.Handler = http.HandlerFunc(
  373. func(w http.ResponseWriter, r *http.Request) {
  374. specifiedHost <- r.Host
  375. origHandler.ServeHTTP(w, r)
  376. })
  377. ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
  378. if err != nil {
  379. t.Fatalf("Dial: %v", err)
  380. }
  381. defer ws.Close()
  382. if gotHost := <-specifiedHost; gotHost != "testhost" {
  383. t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
  384. }
  385. sendRecv(t, ws)
  386. }