hybi.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
// This file implements the hybi draft version of the WebSocket protocol.
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
  7. import (
  8. "bufio"
  9. "bytes"
  10. "crypto/rand"
  11. "crypto/sha1"
  12. "encoding/base64"
  13. "encoding/binary"
  14. "fmt"
  15. "io"
  16. "io/ioutil"
  17. "net/http"
  18. "net/url"
  19. "strings"
  20. )
  21. const (
  22. websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
  23. closeStatusNormal = 1000
  24. closeStatusGoingAway = 1001
  25. closeStatusProtocolError = 1002
  26. closeStatusUnsupportedData = 1003
  27. closeStatusFrameTooLarge = 1004
  28. closeStatusNoStatusRcvd = 1005
  29. closeStatusAbnormalClosure = 1006
  30. closeStatusBadMessageData = 1007
  31. closeStatusPolicyViolation = 1008
  32. closeStatusTooBigData = 1009
  33. closeStatusExtensionMismatch = 1010
  34. maxControlFramePayloadLength = 125
  35. )
  36. var (
  37. ErrBadMaskingKey = &ProtocolError{"bad masking key"}
  38. ErrBadPongMessage = &ProtocolError{"bad pong message"}
  39. ErrBadClosingStatus = &ProtocolError{"bad closing status"}
  40. ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
  41. ErrNotImplemented = &ProtocolError{"not implemented"}
  42. handshakeHeader = map[string]bool{
  43. "Host": true,
  44. "Upgrade": true,
  45. "Connection": true,
  46. "Sec-Websocket-Key": true,
  47. "Sec-Websocket-Origin": true,
  48. "Sec-Websocket-Version": true,
  49. "Sec-Websocket-Protocol": true,
  50. "Sec-Websocket-Accept": true,
  51. }
  52. )
// A hybiFrameHeader is a frame header as defined in hybi draft.
// It holds the decoded header fields of one frame; for frames parsed by
// NewFrameReader it also keeps the raw header bytes in data.
type hybiFrameHeader struct {
	// Fin is the FIN bit, set on the final fragment of a message.
	Fin bool
	// Rsv holds the RSV1, RSV2 and RSV3 bits, in that order.
	Rsv [3]bool
	// OpCode is the 4-bit opcode (low nibble of the first header byte).
	OpCode byte
	// Length is the payload length in bytes.
	Length int64
	// MaskingKey is the 4-byte masking key, or nil for an unmasked frame.
	MaskingKey []byte
	// data holds the raw header bytes as read by NewFrameReader; frame
	// writers created by NewFrameWriter leave it nil.
	data *bytes.Buffer
}
// A hybiFrameReader is a reader for hybi frame.
// It yields the payload of a single frame, unmasking it when the header
// carries a masking key.
type hybiFrameReader struct {
	// reader supplies the payload bytes; NewFrameReader limits it to
	// header.Length bytes of the underlying connection.
	reader io.Reader
	header hybiFrameHeader
	// pos counts payload bytes unmasked so far and selects the masking-key
	// byte for the next one.
	pos int64
	// length is the total frame size: header bytes plus payload length.
	length int
}
  69. func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
  70. n, err = frame.reader.Read(msg)
  71. if err != nil {
  72. return 0, err
  73. }
  74. if frame.header.MaskingKey != nil {
  75. for i := 0; i < n; i++ {
  76. msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4]
  77. frame.pos++
  78. }
  79. }
  80. return n, err
  81. }
  82. func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode }
  83. func (frame *hybiFrameReader) HeaderReader() io.Reader {
  84. if frame.header.data == nil {
  85. return nil
  86. }
  87. if frame.header.data.Len() == 0 {
  88. return nil
  89. }
  90. return frame.header.data
  91. }
  92. func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil }
  93. func (frame *hybiFrameReader) Len() (n int) { return frame.length }
// A hybiFrameReaderFactory creates a hybiFrameReader for each frame by
// parsing frame headers from the embedded buffered reader.
type hybiFrameReaderFactory struct {
	*bufio.Reader
}
  98. // NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
  99. // See Section 5.2 Base Framing protocol for detail.
  100. // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
  101. func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
  102. hybiFrame := new(hybiFrameReader)
  103. frame = hybiFrame
  104. var header []byte
  105. var b byte
  106. // First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
  107. b, err = buf.ReadByte()
  108. if err != nil {
  109. return
  110. }
  111. header = append(header, b)
  112. hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
  113. for i := 0; i < 3; i++ {
  114. j := uint(6 - i)
  115. hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
  116. }
  117. hybiFrame.header.OpCode = header[0] & 0x0f
  118. // Second byte. Mask/Payload len(7bits)
  119. b, err = buf.ReadByte()
  120. if err != nil {
  121. return
  122. }
  123. header = append(header, b)
  124. mask := (b & 0x80) != 0
  125. b &= 0x7f
  126. lengthFields := 0
  127. switch {
  128. case b <= 125: // Payload length 7bits.
  129. hybiFrame.header.Length = int64(b)
  130. case b == 126: // Payload length 7+16bits
  131. lengthFields = 2
  132. case b == 127: // Payload length 7+64bits
  133. lengthFields = 8
  134. }
  135. for i := 0; i < lengthFields; i++ {
  136. b, err = buf.ReadByte()
  137. if err != nil {
  138. return
  139. }
  140. if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits
  141. b &= 0x7f
  142. }
  143. header = append(header, b)
  144. hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
  145. }
  146. if mask {
  147. // Masking key. 4 bytes.
  148. for i := 0; i < 4; i++ {
  149. b, err = buf.ReadByte()
  150. if err != nil {
  151. return
  152. }
  153. header = append(header, b)
  154. hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
  155. }
  156. }
  157. hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
  158. hybiFrame.header.data = bytes.NewBuffer(header)
  159. hybiFrame.length = len(header) + int(hybiFrame.header.Length)
  160. return
  161. }
// A hybiFrameWriter is a writer for hybi frame.
// Each Write call encodes its argument as one complete frame described by
// header and flushes it to writer.
type hybiFrameWriter struct {
	writer *bufio.Writer
	// header supplies the FIN bit, RSV bits, opcode and optional masking key
	// used for every frame written.
	header *hybiFrameHeader
}
  167. func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
  168. var header []byte
  169. var b byte
  170. if frame.header.Fin {
  171. b |= 0x80
  172. }
  173. for i := 0; i < 3; i++ {
  174. if frame.header.Rsv[i] {
  175. j := uint(6 - i)
  176. b |= 1 << j
  177. }
  178. }
  179. b |= frame.header.OpCode
  180. header = append(header, b)
  181. if frame.header.MaskingKey != nil {
  182. b = 0x80
  183. } else {
  184. b = 0
  185. }
  186. lengthFields := 0
  187. length := len(msg)
  188. switch {
  189. case length <= 125:
  190. b |= byte(length)
  191. case length < 65536:
  192. b |= 126
  193. lengthFields = 2
  194. default:
  195. b |= 127
  196. lengthFields = 8
  197. }
  198. header = append(header, b)
  199. for i := 0; i < lengthFields; i++ {
  200. j := uint((lengthFields - i - 1) * 8)
  201. b = byte((length >> j) & 0xff)
  202. header = append(header, b)
  203. }
  204. if frame.header.MaskingKey != nil {
  205. if len(frame.header.MaskingKey) != 4 {
  206. return 0, ErrBadMaskingKey
  207. }
  208. header = append(header, frame.header.MaskingKey...)
  209. frame.writer.Write(header)
  210. data := make([]byte, length)
  211. for i := range data {
  212. data[i] = msg[i] ^ frame.header.MaskingKey[i%4]
  213. }
  214. frame.writer.Write(data)
  215. err = frame.writer.Flush()
  216. return length, err
  217. }
  218. frame.writer.Write(header)
  219. frame.writer.Write(msg)
  220. err = frame.writer.Flush()
  221. return length, err
  222. }
  223. func (frame *hybiFrameWriter) Close() error { return nil }
// A hybiFrameWriterFactory creates hybiFrameWriters that write to the
// embedded buffered writer.
type hybiFrameWriterFactory struct {
	*bufio.Writer
	// needMaskingKey makes NewFrameWriter give every new writer a freshly
	// generated masking key; newHybiConn sets it when it is given no request
	// (the client side).
	needMaskingKey bool
}
  228. func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
  229. frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
  230. if buf.needMaskingKey {
  231. frameHeader.MaskingKey, err = generateMaskingKey()
  232. if err != nil {
  233. return nil, err
  234. }
  235. }
  236. return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
  237. }
// A hybiFrameHandler validates and dispatches frames read from conn.
type hybiFrameHandler struct {
	conn *Conn
	// payloadType is the opcode of the most recent text or binary frame; it
	// is copied onto subsequent continuation frames.
	payloadType byte
}
  242. func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) {
  243. if handler.conn.IsServerConn() {
  244. // The client MUST mask all frames sent to the server.
  245. if frame.(*hybiFrameReader).header.MaskingKey == nil {
  246. handler.WriteClose(closeStatusProtocolError)
  247. return nil, io.EOF
  248. }
  249. } else {
  250. // The server MUST NOT mask all frames.
  251. if frame.(*hybiFrameReader).header.MaskingKey != nil {
  252. handler.WriteClose(closeStatusProtocolError)
  253. return nil, io.EOF
  254. }
  255. }
  256. if header := frame.HeaderReader(); header != nil {
  257. io.Copy(ioutil.Discard, header)
  258. }
  259. switch frame.PayloadType() {
  260. case ContinuationFrame:
  261. frame.(*hybiFrameReader).header.OpCode = handler.payloadType
  262. case TextFrame, BinaryFrame:
  263. handler.payloadType = frame.PayloadType()
  264. case CloseFrame:
  265. return nil, io.EOF
  266. case PingFrame, PongFrame:
  267. b := make([]byte, maxControlFramePayloadLength)
  268. n, err := io.ReadFull(frame, b)
  269. if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
  270. return nil, err
  271. }
  272. io.Copy(ioutil.Discard, frame)
  273. if frame.PayloadType() == PingFrame {
  274. if _, err := handler.WritePong(b[:n]); err != nil {
  275. return nil, err
  276. }
  277. }
  278. return nil, nil
  279. }
  280. return frame, nil
  281. }
  282. func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
  283. handler.conn.wio.Lock()
  284. defer handler.conn.wio.Unlock()
  285. w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
  286. if err != nil {
  287. return err
  288. }
  289. msg := make([]byte, 2)
  290. binary.BigEndian.PutUint16(msg, uint16(status))
  291. _, err = w.Write(msg)
  292. w.Close()
  293. return err
  294. }
  295. func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
  296. handler.conn.wio.Lock()
  297. defer handler.conn.wio.Unlock()
  298. w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
  299. if err != nil {
  300. return 0, err
  301. }
  302. n, err = w.Write(msg)
  303. w.Close()
  304. return n, err
  305. }
  306. // newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
  307. func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
  308. if buf == nil {
  309. br := bufio.NewReader(rwc)
  310. bw := bufio.NewWriter(rwc)
  311. buf = bufio.NewReadWriter(br, bw)
  312. }
  313. ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
  314. frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
  315. frameWriterFactory: hybiFrameWriterFactory{
  316. buf.Writer, request == nil},
  317. PayloadType: TextFrame,
  318. defaultCloseStatus: closeStatusNormal}
  319. ws.frameHandler = &hybiFrameHandler{conn: ws}
  320. return ws
  321. }
  322. // generateMaskingKey generates a masking key for a frame.
  323. func generateMaskingKey() (maskingKey []byte, err error) {
  324. maskingKey = make([]byte, 4)
  325. if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil {
  326. return
  327. }
  328. return
  329. }
  330. // generateNonce generates a nonce consisting of a randomly selected 16-byte
  331. // value that has been base64-encoded.
  332. func generateNonce() (nonce []byte) {
  333. key := make([]byte, 16)
  334. if _, err := io.ReadFull(rand.Reader, key); err != nil {
  335. panic(err)
  336. }
  337. nonce = make([]byte, 24)
  338. base64.StdEncoding.Encode(nonce, key)
  339. return
  340. }
  341. // getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
  342. // the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
  343. func getNonceAccept(nonce []byte) (expected []byte, err error) {
  344. h := sha1.New()
  345. if _, err = h.Write(nonce); err != nil {
  346. return
  347. }
  348. if _, err = h.Write([]byte(websocketGUID)); err != nil {
  349. return
  350. }
  351. expected = make([]byte, 28)
  352. base64.StdEncoding.Encode(expected, h.Sum(nil))
  353. return
  354. }
  355. // Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
  356. func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
  357. bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
  358. bw.WriteString("Host: " + config.Location.Host + "\r\n")
  359. bw.WriteString("Upgrade: websocket\r\n")
  360. bw.WriteString("Connection: Upgrade\r\n")
  361. nonce := generateNonce()
  362. if config.handshakeData != nil {
  363. nonce = []byte(config.handshakeData["key"])
  364. }
  365. bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
  366. bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
  367. if config.Version != ProtocolVersionHybi13 {
  368. return ErrBadProtocolVersion
  369. }
  370. bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
  371. if len(config.Protocol) > 0 {
  372. bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
  373. }
  374. // TODO(ukai): send Sec-WebSocket-Extensions.
  375. err = config.Header.WriteSubset(bw, handshakeHeader)
  376. if err != nil {
  377. return err
  378. }
  379. bw.WriteString("\r\n")
  380. if err = bw.Flush(); err != nil {
  381. return err
  382. }
  383. resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
  384. if err != nil {
  385. return err
  386. }
  387. if resp.StatusCode != 101 {
  388. return ErrBadStatus
  389. }
  390. if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
  391. strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
  392. return ErrBadUpgrade
  393. }
  394. expectedAccept, err := getNonceAccept(nonce)
  395. if err != nil {
  396. return err
  397. }
  398. if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
  399. return ErrChallengeResponse
  400. }
  401. if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
  402. return ErrUnsupportedExtensions
  403. }
  404. offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
  405. if offeredProtocol != "" {
  406. protocolMatched := false
  407. for i := 0; i < len(config.Protocol); i++ {
  408. if config.Protocol[i] == offeredProtocol {
  409. protocolMatched = true
  410. break
  411. }
  412. }
  413. if !protocolMatched {
  414. return ErrBadWebSocketProtocol
  415. }
  416. config.Protocol = []string{offeredProtocol}
  417. }
  418. return nil
  419. }
  420. // newHybiClientConn creates a client WebSocket connection after handshake.
  421. func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
  422. return newHybiConn(config, buf, rwc, nil)
  423. }
// A hybiServerHandshaker performs a server handshake using hybi draft protocol.
type hybiServerHandshaker struct {
	*Config
	// accept is the Sec-WebSocket-Accept value computed by ReadHandshake and
	// written back by AcceptHandshake.
	accept []byte
}
  429. func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
  430. c.Version = ProtocolVersionHybi13
  431. if req.Method != "GET" {
  432. return http.StatusMethodNotAllowed, ErrBadRequestMethod
  433. }
  434. // HTTP version can be safely ignored.
  435. if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
  436. !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
  437. return http.StatusBadRequest, ErrNotWebSocket
  438. }
  439. key := req.Header.Get("Sec-Websocket-Key")
  440. if key == "" {
  441. return http.StatusBadRequest, ErrChallengeResponse
  442. }
  443. version := req.Header.Get("Sec-Websocket-Version")
  444. switch version {
  445. case "13":
  446. c.Version = ProtocolVersionHybi13
  447. default:
  448. return http.StatusBadRequest, ErrBadWebSocketVersion
  449. }
  450. var scheme string
  451. if req.TLS != nil {
  452. scheme = "wss"
  453. } else {
  454. scheme = "ws"
  455. }
  456. c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
  457. if err != nil {
  458. return http.StatusBadRequest, err
  459. }
  460. protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
  461. if protocol != "" {
  462. protocols := strings.Split(protocol, ",")
  463. for i := 0; i < len(protocols); i++ {
  464. c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
  465. }
  466. }
  467. c.accept, err = getNonceAccept([]byte(key))
  468. if err != nil {
  469. return http.StatusInternalServerError, err
  470. }
  471. return http.StatusSwitchingProtocols, nil
  472. }
  473. // Origin parses the Origin header in req.
  474. // If the Origin header is not set, it returns nil and nil.
  475. func Origin(config *Config, req *http.Request) (*url.URL, error) {
  476. var origin string
  477. switch config.Version {
  478. case ProtocolVersionHybi13:
  479. origin = req.Header.Get("Origin")
  480. }
  481. if origin == "" {
  482. return nil, nil
  483. }
  484. return url.ParseRequestURI(origin)
  485. }
  486. func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
  487. if len(c.Protocol) > 0 {
  488. if len(c.Protocol) != 1 {
  489. // You need choose a Protocol in Handshake func in Server.
  490. return ErrBadWebSocketProtocol
  491. }
  492. }
  493. buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
  494. buf.WriteString("Upgrade: websocket\r\n")
  495. buf.WriteString("Connection: Upgrade\r\n")
  496. buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
  497. if len(c.Protocol) > 0 {
  498. buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
  499. }
  500. // TODO(ukai): send Sec-WebSocket-Extensions.
  501. if c.Header != nil {
  502. err := c.Header.WriteSubset(buf, handshakeHeader)
  503. if err != nil {
  504. return err
  505. }
  506. }
  507. buf.WriteString("\r\n")
  508. return buf.Flush()
  509. }
  510. func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
  511. return newHybiServerConn(c.Config, buf, rwc, request)
  512. }
  513. // newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
  514. func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
  515. return newHybiConn(config, buf, rwc, request)
  516. }