hybi.go 15 KB


  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
// This file implements the WebSocket protocol as specified by the hybi draft.
  6. // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
  7. import (
  8. "bufio"
  9. "bytes"
  10. "crypto/rand"
  11. "crypto/sha1"
  12. "encoding/base64"
  13. "encoding/binary"
  14. "fmt"
  15. "io"
  16. "io/ioutil"
  17. "net/http"
  18. "net/url"
  19. "strings"
  20. )
  21. const (
  22. websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
  23. closeStatusNormal = 1000
  24. closeStatusGoingAway = 1001
  25. closeStatusProtocolError = 1002
  26. closeStatusUnsupportedData = 1003
  27. closeStatusFrameTooLarge = 1004
  28. closeStatusNoStatusRcvd = 1005
  29. closeStatusAbnormalClosure = 1006
  30. closeStatusBadMessageData = 1007
  31. closeStatusPolicyViolation = 1008
  32. closeStatusTooBigData = 1009
  33. closeStatusExtensionMismatch = 1010
  34. maxControlFramePayloadLength = 125
  35. )
  36. var (
  37. ErrBadMaskingKey = &ProtocolError{"bad masking key"}
  38. ErrBadPongMessage = &ProtocolError{"bad pong message"}
  39. ErrBadClosingStatus = &ProtocolError{"bad closing status"}
  40. ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
  41. ErrNotImplemented = &ProtocolError{"not implemented"}
  42. )
  43. // A hybiFrameHeader is a frame header as defined in hybi draft.
  44. type hybiFrameHeader struct {
  45. Fin bool
  46. Rsv [3]bool
  47. OpCode byte
  48. Length int64
  49. MaskingKey []byte
  50. data *bytes.Buffer
  51. }
  52. // A hybiFrameReader is a reader for hybi frame.
  53. type hybiFrameReader struct {
  54. reader io.Reader
  55. header hybiFrameHeader
  56. pos int64
  57. length int
  58. }
  59. func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
  60. n, err = frame.reader.Read(msg)
  61. if err != nil {
  62. return 0, err
  63. }
  64. if frame.header.MaskingKey != nil {
  65. for i := 0; i < n; i++ {
  66. msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4]
  67. frame.pos++
  68. }
  69. }
  70. return n, err
  71. }
  72. func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode }
  73. func (frame *hybiFrameReader) HeaderReader() io.Reader {
  74. if frame.header.data == nil {
  75. return nil
  76. }
  77. if frame.header.data.Len() == 0 {
  78. return nil
  79. }
  80. return frame.header.data
  81. }
  82. func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil }
  83. func (frame *hybiFrameReader) Len() (n int) { return frame.length }
  84. // A hybiFrameReaderFactory creates new frame reader based on its frame type.
  85. type hybiFrameReaderFactory struct {
  86. *bufio.Reader
  87. }
  88. // NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
  89. // See Section 5.2 Base Framing protocol for detail.
  90. // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
  91. func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
  92. hybiFrame := new(hybiFrameReader)
  93. frame = hybiFrame
  94. var header []byte
  95. var b byte
  96. // First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
  97. b, err = buf.ReadByte()
  98. if err != nil {
  99. return
  100. }
  101. header = append(header, b)
  102. hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
  103. for i := 0; i < 3; i++ {
  104. j := uint(6 - i)
  105. hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
  106. }
  107. hybiFrame.header.OpCode = header[0] & 0x0f
  108. // Second byte. Mask/Payload len(7bits)
  109. b, err = buf.ReadByte()
  110. if err != nil {
  111. return
  112. }
  113. header = append(header, b)
  114. mask := (b & 0x80) != 0
  115. b &= 0x7f
  116. lengthFields := 0
  117. switch {
  118. case b <= 125: // Payload length 7bits.
  119. hybiFrame.header.Length = int64(b)
  120. case b == 126: // Payload length 7+16bits
  121. lengthFields = 2
  122. case b == 127: // Payload length 7+64bits
  123. lengthFields = 8
  124. }
  125. for i := 0; i < lengthFields; i++ {
  126. b, err = buf.ReadByte()
  127. if err != nil {
  128. return
  129. }
  130. header = append(header, b)
  131. hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
  132. }
  133. if mask {
  134. // Masking key. 4 bytes.
  135. for i := 0; i < 4; i++ {
  136. b, err = buf.ReadByte()
  137. if err != nil {
  138. return
  139. }
  140. header = append(header, b)
  141. hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
  142. }
  143. }
  144. hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
  145. hybiFrame.header.data = bytes.NewBuffer(header)
  146. hybiFrame.length = len(header) + int(hybiFrame.header.Length)
  147. return
  148. }
  149. // A HybiFrameWriter is a writer for hybi frame.
  150. type hybiFrameWriter struct {
  151. writer *bufio.Writer
  152. header *hybiFrameHeader
  153. }
  154. func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
  155. var header []byte
  156. var b byte
  157. if frame.header.Fin {
  158. b |= 0x80
  159. }
  160. for i := 0; i < 3; i++ {
  161. if frame.header.Rsv[i] {
  162. j := uint(6 - i)
  163. b |= 1 << j
  164. }
  165. }
  166. b |= frame.header.OpCode
  167. header = append(header, b)
  168. if frame.header.MaskingKey != nil {
  169. b = 0x80
  170. } else {
  171. b = 0
  172. }
  173. lengthFields := 0
  174. length := len(msg)
  175. switch {
  176. case length <= 125:
  177. b |= byte(length)
  178. case length < 65536:
  179. b |= 126
  180. lengthFields = 2
  181. default:
  182. b |= 127
  183. lengthFields = 8
  184. }
  185. header = append(header, b)
  186. for i := 0; i < lengthFields; i++ {
  187. j := uint((lengthFields - i - 1) * 8)
  188. b = byte((length >> j) & 0xff)
  189. header = append(header, b)
  190. }
  191. if frame.header.MaskingKey != nil {
  192. if len(frame.header.MaskingKey) != 4 {
  193. return 0, ErrBadMaskingKey
  194. }
  195. header = append(header, frame.header.MaskingKey...)
  196. frame.writer.Write(header)
  197. var data []byte
  198. for i := 0; i < length; i++ {
  199. data = append(data, msg[i]^frame.header.MaskingKey[i%4])
  200. }
  201. frame.writer.Write(data)
  202. err = frame.writer.Flush()
  203. return length, err
  204. }
  205. frame.writer.Write(header)
  206. frame.writer.Write(msg)
  207. err = frame.writer.Flush()
  208. return length, err
  209. }
  210. func (frame *hybiFrameWriter) Close() error { return nil }
  211. type hybiFrameWriterFactory struct {
  212. *bufio.Writer
  213. needMaskingKey bool
  214. }
  215. func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
  216. frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
  217. if buf.needMaskingKey {
  218. frameHeader.MaskingKey, err = generateMaskingKey()
  219. if err != nil {
  220. return nil, err
  221. }
  222. }
  223. return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
  224. }
  225. type hybiFrameHandler struct {
  226. conn *Conn
  227. payloadType byte
  228. }
  229. func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (r frameReader, err error) {
  230. if handler.conn.IsServerConn() {
  231. // The client MUST mask all frames sent to the server.
  232. if frame.(*hybiFrameReader).header.MaskingKey == nil {
  233. handler.WriteClose(closeStatusProtocolError)
  234. return nil, io.EOF
  235. }
  236. } else {
  237. // The server MUST NOT mask all frames.
  238. if frame.(*hybiFrameReader).header.MaskingKey != nil {
  239. handler.WriteClose(closeStatusProtocolError)
  240. return nil, io.EOF
  241. }
  242. }
  243. if header := frame.HeaderReader(); header != nil {
  244. io.Copy(ioutil.Discard, header)
  245. }
  246. switch frame.PayloadType() {
  247. case ContinuationFrame:
  248. frame.(*hybiFrameReader).header.OpCode = handler.payloadType
  249. case TextFrame, BinaryFrame:
  250. handler.payloadType = frame.PayloadType()
  251. case CloseFrame:
  252. return nil, io.EOF
  253. case PingFrame:
  254. pingMsg := make([]byte, maxControlFramePayloadLength)
  255. n, err := io.ReadFull(frame, pingMsg)
  256. if err != nil && err != io.ErrUnexpectedEOF {
  257. return nil, err
  258. }
  259. io.Copy(ioutil.Discard, frame)
  260. n, err = handler.WritePong(pingMsg[:n])
  261. if err != nil {
  262. return nil, err
  263. }
  264. return nil, nil
  265. case PongFrame:
  266. return nil, ErrNotImplemented
  267. }
  268. return frame, nil
  269. }
  270. func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
  271. handler.conn.wio.Lock()
  272. defer handler.conn.wio.Unlock()
  273. w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
  274. if err != nil {
  275. return err
  276. }
  277. msg := make([]byte, 2)
  278. binary.BigEndian.PutUint16(msg, uint16(status))
  279. _, err = w.Write(msg)
  280. w.Close()
  281. return err
  282. }
  283. func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
  284. handler.conn.wio.Lock()
  285. defer handler.conn.wio.Unlock()
  286. w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
  287. if err != nil {
  288. return 0, err
  289. }
  290. n, err = w.Write(msg)
  291. w.Close()
  292. return n, err
  293. }
  294. // newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
  295. func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
  296. if buf == nil {
  297. br := bufio.NewReader(rwc)
  298. bw := bufio.NewWriter(rwc)
  299. buf = bufio.NewReadWriter(br, bw)
  300. }
  301. ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
  302. frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
  303. frameWriterFactory: hybiFrameWriterFactory{
  304. buf.Writer, request == nil},
  305. PayloadType: TextFrame,
  306. defaultCloseStatus: closeStatusNormal}
  307. ws.frameHandler = &hybiFrameHandler{conn: ws}
  308. return ws
  309. }
  310. // generateMaskingKey generates a masking key for a frame.
  311. func generateMaskingKey() (maskingKey []byte, err error) {
  312. maskingKey = make([]byte, 4)
  313. if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil {
  314. return
  315. }
  316. return
  317. }
  318. // generateNonce generates a nonce consisting of a randomly selected 16-byte
  319. // value that has been base64-encoded.
  320. func generateNonce() (nonce []byte) {
  321. key := make([]byte, 16)
  322. if _, err := io.ReadFull(rand.Reader, key); err != nil {
  323. panic(err)
  324. }
  325. nonce = make([]byte, 24)
  326. base64.StdEncoding.Encode(nonce, key)
  327. return
  328. }
  329. // getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
  330. // the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
  331. func getNonceAccept(nonce []byte) (expected []byte, err error) {
  332. h := sha1.New()
  333. if _, err = h.Write(nonce); err != nil {
  334. return
  335. }
  336. if _, err = h.Write([]byte(websocketGUID)); err != nil {
  337. return
  338. }
  339. expected = make([]byte, 28)
  340. base64.StdEncoding.Encode(expected, h.Sum(nil))
  341. return
  342. }
  343. func isHybiVersion(version int) bool {
  344. switch version {
  345. case ProtocolVersionHybi08, ProtocolVersionHybi13:
  346. return true
  347. default:
  348. }
  349. return false
  350. }
  351. // Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
  352. func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
  353. if !isHybiVersion(config.Version) {
  354. panic("wrong protocol version.")
  355. }
  356. bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
  357. bw.WriteString("Host: " + config.Location.Host + "\r\n")
  358. bw.WriteString("Upgrade: websocket\r\n")
  359. bw.WriteString("Connection: Upgrade\r\n")
  360. nonce := generateNonce()
  361. if config.handshakeData != nil {
  362. nonce = []byte(config.handshakeData["key"])
  363. }
  364. bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
  365. if config.Version == ProtocolVersionHybi13 {
  366. bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
  367. } else if config.Version == ProtocolVersionHybi08 {
  368. bw.WriteString("Sec-WebSocket-Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
  369. }
  370. bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
  371. if len(config.Protocol) > 0 {
  372. bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
  373. }
  374. // TODO(ukai): send extensions.
  375. // TODO(ukai): send cookie if any.
  376. bw.WriteString("\r\n")
  377. if err = bw.Flush(); err != nil {
  378. return err
  379. }
  380. resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
  381. if err != nil {
  382. return err
  383. }
  384. if resp.StatusCode != 101 {
  385. return ErrBadStatus
  386. }
  387. if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
  388. strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
  389. return ErrBadUpgrade
  390. }
  391. expectedAccept, err := getNonceAccept(nonce)
  392. if err != nil {
  393. return err
  394. }
  395. if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
  396. return ErrChallengeResponse
  397. }
  398. if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
  399. return ErrUnsupportedExtensions
  400. }
  401. offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
  402. if offeredProtocol != "" {
  403. protocolMatched := false
  404. for i := 0; i < len(config.Protocol); i++ {
  405. if config.Protocol[i] == offeredProtocol {
  406. protocolMatched = true
  407. break
  408. }
  409. }
  410. if !protocolMatched {
  411. return ErrBadWebSocketProtocol
  412. }
  413. config.Protocol = []string{offeredProtocol}
  414. }
  415. return nil
  416. }
  417. // newHybiClientConn creates a client WebSocket connection after handshake.
  418. func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
  419. return newHybiConn(config, buf, rwc, nil)
  420. }
  421. // A HybiServerHandshaker performs a server handshake using hybi draft protocol.
  422. type hybiServerHandshaker struct {
  423. *Config
  424. accept []byte
  425. }
  426. func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
  427. c.Version = ProtocolVersionHybi13
  428. if req.Method != "GET" {
  429. return http.StatusMethodNotAllowed, ErrBadRequestMethod
  430. }
  431. // HTTP version can be safely ignored.
  432. if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
  433. !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
  434. return http.StatusBadRequest, ErrNotWebSocket
  435. }
  436. key := req.Header.Get("Sec-Websocket-Key")
  437. if key == "" {
  438. return http.StatusBadRequest, ErrChallengeResponse
  439. }
  440. version := req.Header.Get("Sec-Websocket-Version")
  441. var origin string
  442. switch version {
  443. case "13":
  444. c.Version = ProtocolVersionHybi13
  445. origin = req.Header.Get("Origin")
  446. case "8":
  447. c.Version = ProtocolVersionHybi08
  448. origin = req.Header.Get("Sec-Websocket-Origin")
  449. default:
  450. return http.StatusBadRequest, ErrBadWebSocketVersion
  451. }
  452. c.Origin, err = url.ParseRequestURI(origin)
  453. if err != nil {
  454. return http.StatusForbidden, err
  455. }
  456. var scheme string
  457. if req.TLS != nil {
  458. scheme = "wss"
  459. } else {
  460. scheme = "ws"
  461. }
  462. c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
  463. if err != nil {
  464. return http.StatusBadRequest, err
  465. }
  466. protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
  467. protocols := strings.Split(protocol, ",")
  468. for i := 0; i < len(protocols); i++ {
  469. c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
  470. }
  471. c.accept, err = getNonceAccept([]byte(key))
  472. if err != nil {
  473. return http.StatusInternalServerError, err
  474. }
  475. return http.StatusSwitchingProtocols, nil
  476. }
  477. func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
  478. if len(c.Protocol) > 0 {
  479. if len(c.Protocol) != 1 {
  480. return ErrBadWebSocketProtocol
  481. }
  482. }
  483. buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
  484. buf.WriteString("Upgrade: websocket\r\n")
  485. buf.WriteString("Connection: Upgrade\r\n")
  486. buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
  487. if len(c.Protocol) > 0 {
  488. buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
  489. }
  490. // TODO(ukai): support extensions
  491. buf.WriteString("\r\n")
  492. return buf.Flush()
  493. }
  494. func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
  495. return newHybiServerConn(c.Config, buf, rwc, request)
  496. }
  497. // newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
  498. func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
  499. return newHybiConn(config, buf, rwc, request)
  500. }