websocket_proxy.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. package wsproxy
  2. import (
  3. "bufio"
  4. "io"
  5. "net/http"
  6. "strings"
  7. "github.com/gorilla/websocket"
  8. "github.com/sirupsen/logrus"
  9. "golang.org/x/net/context"
  10. )
  // Package-level defaults that WebsocketProxy copies into each new Proxy
// before applying any Option.
var (
	// MethodOverrideParam names the URL query parameter whose non-empty value
	// replaces the HTTP method of the proxied streaming request.
	//
	// Deprecated: pass WithMethodParamOverride to WebsocketProxy instead.
	MethodOverrideParam = "method"

	// TokenCookieName names the cookie whose value is forwarded to the proxied
	// streaming request as an "Authorization: Bearer <value>" header.
	//
	// Deprecated: pass WithTokenCookieName to WebsocketProxy instead.
	TokenCookieName = "token"
)
  // RequestMutatorFunc is a hook for rewriting the request that is forwarded to
// the wrapped handler. It receives the original upgrade request and the
// prepared outgoing request, and returns the request that should actually be
// forwarded (which may be outgoing itself or a replacement).
type RequestMutatorFunc func(incoming, outgoing *http.Request) *http.Request
  21. // Proxy provides websocket transport upgrade to compatible endpoints.
  22. type Proxy struct {
  23. h http.Handler
  24. logger Logger
  25. methodOverrideParam string
  26. tokenCookieName string
  27. requestMutator RequestMutatorFunc
  28. }
  // Logger is the minimal logging surface the proxy writes to. A *logrus.Logger
// satisfies it and is what WebsocketProxy installs by default.
type Logger interface {
	// Debugln records verbose diagnostic output.
	Debugln(...interface{})
	// Warnln records problems worth an operator's attention.
	Warnln(...interface{})
}
  34. func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  35. if !websocket.IsWebSocketUpgrade(r) {
  36. p.h.ServeHTTP(w, r)
  37. return
  38. }
  39. p.proxy(w, r)
  40. }
  41. // Option allows customization of the proxy.
  42. type Option func(*Proxy)
  43. // WithMethodParamOverride allows specification of the special http parameter that is used in the proxied streaming request.
  44. func WithMethodParamOverride(param string) Option {
  45. return func(p *Proxy) {
  46. p.methodOverrideParam = param
  47. }
  48. }
  49. // WithTokenCookieName allows specification of the cookie that is supplied as an upstream 'Authorization: Bearer' http header.
  50. func WithTokenCookieName(param string) Option {
  51. return func(p *Proxy) {
  52. p.tokenCookieName = param
  53. }
  54. }
  55. // WithRequestMutator allows a custom RequestMutatorFunc to be supplied.
  56. func WithRequestMutator(fn RequestMutatorFunc) Option {
  57. return func(p *Proxy) {
  58. p.requestMutator = fn
  59. }
  60. }
  61. // WithLogger allows a custom FieldLogger to be supplied
  62. func WithLogger(logger Logger) Option {
  63. return func(p *Proxy) {
  64. p.logger = logger
  65. }
  66. }
  67. // WebsocketProxy attempts to expose the underlying handler as a bidi websocket stream with newline-delimited
  68. // JSON as the content encoding.
  69. //
  70. // The HTTP Authorization header is either populated from the Sec-Websocket-Protocol field or by a cookie.
  71. // The cookie name is specified by the TokenCookieName value.
  72. //
  73. // example:
  74. // Sec-Websocket-Protocol: Bearer, foobar
  75. // is converted to:
  76. // Authorization: Bearer foobar
  77. //
  78. // Method can be overwritten with the MethodOverrideParam get parameter in the requested URL
  79. func WebsocketProxy(h http.Handler, opts ...Option) http.Handler {
  80. p := &Proxy{
  81. h: h,
  82. logger: logrus.New(),
  83. methodOverrideParam: MethodOverrideParam,
  84. tokenCookieName: TokenCookieName,
  85. }
  86. for _, o := range opts {
  87. o(p)
  88. }
  89. return p
  90. }
  91. // TODO(tmc): allow modification of upgrader settings?
  92. var upgrader = websocket.Upgrader{
  93. ReadBufferSize: 1024,
  94. WriteBufferSize: 1024,
  95. CheckOrigin: func(r *http.Request) bool { return true },
  96. }
  97. func isClosedConnError(err error) bool {
  98. str := err.Error()
  99. if strings.Contains(str, "use of closed network connection") {
  100. return true
  101. }
  102. return websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway)
  103. }
  104. func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) {
  105. var responseHeader http.Header
  106. // If Sec-WebSocket-Protocol starts with "Bearer", respond in kind.
  107. // TODO(tmc): consider customizability/extension point here.
  108. if strings.HasPrefix(r.Header.Get("Sec-WebSocket-Protocol"), "Bearer") {
  109. responseHeader = http.Header{
  110. "Sec-WebSocket-Protocol": []string{"Bearer"},
  111. }
  112. }
  113. conn, err := upgrader.Upgrade(w, r, responseHeader)
  114. if err != nil {
  115. p.logger.Warnln("error upgrading websocket:", err)
  116. return
  117. }
  118. defer conn.Close()
  119. ctx, cancelFn := context.WithCancel(context.Background())
  120. defer cancelFn()
  121. requestBodyR, requestBodyW := io.Pipe()
  122. request, err := http.NewRequest(r.Method, r.URL.String(), requestBodyR)
  123. if err != nil {
  124. p.logger.Warnln("error preparing request:", err)
  125. return
  126. }
  127. if swsp := r.Header.Get("Sec-WebSocket-Protocol"); swsp != "" {
  128. request.Header.Set("Authorization", strings.Replace(swsp, "Bearer, ", "Bearer ", 1))
  129. }
  130. // If token cookie is present, populate Authorization header from the cookie instead.
  131. if cookie, err := r.Cookie(p.tokenCookieName); err == nil {
  132. request.Header.Set("Authorization", "Bearer "+cookie.Value)
  133. }
  134. if m := r.URL.Query().Get(p.methodOverrideParam); m != "" {
  135. request.Method = m
  136. }
  137. if p.requestMutator != nil {
  138. request = p.requestMutator(r, request)
  139. }
  140. responseBodyR, responseBodyW := io.Pipe()
  141. response := newInMemoryResponseWriter(responseBodyW)
  142. go func() {
  143. <-ctx.Done()
  144. p.logger.Debugln("closing pipes")
  145. requestBodyW.CloseWithError(io.EOF)
  146. responseBodyW.CloseWithError(io.EOF)
  147. response.closed <- true
  148. }()
  149. go func() {
  150. defer cancelFn()
  151. p.h.ServeHTTP(response, request)
  152. }()
  153. // read loop -- take messages from websocket and write to http request
  154. go func() {
  155. defer func() {
  156. cancelFn()
  157. }()
  158. for {
  159. select {
  160. case <-ctx.Done():
  161. p.logger.Debugln("read loop done")
  162. return
  163. default:
  164. }
  165. p.logger.Debugln("[read] reading from socket.")
  166. _, payload, err := conn.ReadMessage()
  167. if err != nil {
  168. if isClosedConnError(err) {
  169. p.logger.Debugln("[read] websocket closed:", err)
  170. return
  171. }
  172. p.logger.Warnln("error reading websocket message:", err)
  173. return
  174. }
  175. p.logger.Debugln("[read] read payload:", string(payload))
  176. p.logger.Debugln("[read] writing to requestBody:")
  177. n, err := requestBodyW.Write(payload)
  178. requestBodyW.Write([]byte("\n"))
  179. p.logger.Debugln("[read] wrote to requestBody", n)
  180. if err != nil {
  181. p.logger.Warnln("[read] error writing message to upstream http server:", err)
  182. return
  183. }
  184. }
  185. }()
  186. // write loop -- take messages from response and write to websocket
  187. scanner := bufio.NewScanner(responseBodyR)
  188. for scanner.Scan() {
  189. if len(scanner.Bytes()) == 0 {
  190. p.logger.Warnln("[write] empty scan", scanner.Err())
  191. continue
  192. }
  193. p.logger.Debugln("[write] scanned", scanner.Text())
  194. if err = conn.WriteMessage(websocket.TextMessage, scanner.Bytes()); err != nil {
  195. p.logger.Warnln("[write] error writing websocket message:", err)
  196. return
  197. }
  198. }
  199. if err := scanner.Err(); err != nil {
  200. p.logger.Warnln("scanner err:", err)
  201. }
  202. }
  // inMemoryResponseWriter is the http.ResponseWriter handed to the proxied
// handler. Body bytes are passed straight through to the wrapped io.Writer.
// Headers and the status code are only recorded in memory and never sent.
type inMemoryResponseWriter struct {
	io.Writer             // destination for response body bytes
	header    http.Header // headers set by the handler; kept in memory only
	code      int         // last status passed to WriteHeader; 0 if never called
	closed    chan bool   // returned by CloseNotify; one-slot buffer so a single signal never blocks
}

// Compile-time checks for the interfaces this type is used as.
var (
	_ http.ResponseWriter = (*inMemoryResponseWriter)(nil)
	_ http.Flusher        = (*inMemoryResponseWriter)(nil)
)

// newInMemoryResponseWriter returns a writer whose body goes to w, with an
// empty header map and a one-slot close-notification channel.
func newInMemoryResponseWriter(w io.Writer) *inMemoryResponseWriter {
	rw := &inMemoryResponseWriter{Writer: w}
	rw.header = make(http.Header)
	rw.closed = make(chan bool, 1)
	return rw
}

// Write forwards b unchanged to the wrapped writer.
func (rw *inMemoryResponseWriter) Write(b []byte) (int, error) {
	return rw.Writer.Write(b)
}

// Header returns the in-memory header map.
func (rw *inMemoryResponseWriter) Header() http.Header { return rw.header }

// WriteHeader records code; nothing is transmitted.
func (rw *inMemoryResponseWriter) WriteHeader(code int) { rw.code = code }

// CloseNotify returns the channel on which a close notification is delivered.
func (rw *inMemoryResponseWriter) CloseNotify() <-chan bool { return rw.closed }

// Flush is a no-op: this type buffers nothing itself, because Write hands bytes
// directly to the wrapped writer.
func (rw *inMemoryResponseWriter) Flush() {}