cmux.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. // Copyright 2016 The CMux Authors. All rights reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  12. // implied. See the License for the specific language governing
  13. // permissions and limitations under the License.
  14. package cmux
  15. import (
  16. "fmt"
  17. "io"
  18. "net"
  19. "sync"
  20. "time"
  21. )
  // Callback types used by the multiplexer to classify connections and to
  // react to errors.
  type (
  	// Matcher matches a connection based on its content, which it reads
  	// from the supplied reader.
  	Matcher func(io.Reader) bool

  	// MatchWriter is a matcher that can also write a response on the
  	// connection (say, to perform a handshake) while matching.
  	MatchWriter func(io.Writer, io.Reader) bool

  	// ErrorHandler handles an error and returns whether the mux should
  	// continue serving the listener.
  	ErrorHandler func(error) bool
  )
  // Compile-time check that ErrNotMatched satisfies net.Error.
  var _ net.Error = ErrNotMatched{}

  // ErrNotMatched is returned whenever a connection is not matched by any of
  // the matchers registered in the multiplexer.
  type ErrNotMatched struct {
  	// c is the connection that failed to match; it is nil in the zero value.
  	c net.Conn
  }

  // Error implements the error interface. The message includes the remote
  // address of the unmatched connection, or "<nil>" when the error carries no
  // connection (for example the zero value), so that formatting the error
  // never panics.
  func (e ErrNotMatched) Error() string {
  	var addr net.Addr
  	if e.c != nil {
  		addr = e.c.RemoteAddr()
  	}
  	return fmt.Sprintf("mux: connection %v not matched by any matcher", addr)
  }

  // Temporary implements the net.Error interface. A failed match is always
  // reported as temporary.
  func (e ErrNotMatched) Temporary() bool { return true }

  // Timeout implements the net.Error interface. A failed match is never
  // reported as a timeout.
  func (e ErrNotMatched) Timeout() bool { return false }
  // errListenerClosed is a string-backed error whose value is its own
  // message. It is neither temporary nor a timeout.
  type errListenerClosed string

  // ErrListenerClosed is returned from muxListener.Accept when the underlying
  // listener is closed.
  var ErrListenerClosed errListenerClosed = "mux: listener closed"

  // Error returns the underlying string as the error message.
  func (e errListenerClosed) Error() string {
  	return string(e)
  }

  // Temporary always reports false.
  func (e errListenerClosed) Temporary() bool {
  	return false
  }

  // Timeout always reports false.
  func (e errListenerClosed) Timeout() bool {
  	return false
  }
  // noTimeout is the zero duration. It exists purely so that code dealing with
  // readTimeout reads naturally ("no timeout configured").
  var noTimeout = time.Duration(0)
  52. // New instantiates a new connection multiplexer.
  53. func New(l net.Listener) CMux {
  54. return &cMux{
  55. root: l,
  56. bufLen: 1024,
  57. errh: func(_ error) bool { return true },
  58. donec: make(chan struct{}),
  59. readTimeout: noTimeout,
  60. }
  61. }
  62. // CMux is a multiplexer for network connections.
  63. type CMux interface {
  64. // Match returns a net.Listener that sees (i.e., accepts) only
  65. // the connections matched by at least one of the matcher.
  66. //
  67. // The order used to call Match determines the priority of matchers.
  68. Match(...Matcher) net.Listener
  69. // MatchWithWriters returns a net.Listener that accepts only the
  70. // connections that matched by at least of the matcher writers.
  71. //
  72. // Prefer Matchers over MatchWriters, since the latter can write on the
  73. // connection before the actual handler.
  74. //
  75. // The order used to call Match determines the priority of matchers.
  76. MatchWithWriters(...MatchWriter) net.Listener
  77. // Serve starts multiplexing the listener. Serve blocks and perhaps
  78. // should be invoked concurrently within a go routine.
  79. Serve() error
  80. // HandleError registers an error handler that handles listener errors.
  81. HandleError(ErrorHandler)
  82. // sets a timeout for the read of matchers
  83. SetReadTimeout(time.Duration)
  84. }
  85. type matchersListener struct {
  86. ss []MatchWriter
  87. l muxListener
  88. }
  89. type cMux struct {
  90. root net.Listener
  91. bufLen int
  92. errh ErrorHandler
  93. donec chan struct{}
  94. sls []matchersListener
  95. readTimeout time.Duration
  96. }
  97. func matchersToMatchWriters(matchers []Matcher) []MatchWriter {
  98. mws := make([]MatchWriter, 0, len(matchers))
  99. for _, m := range matchers {
  100. cm := m
  101. mws = append(mws, func(w io.Writer, r io.Reader) bool {
  102. return cm(r)
  103. })
  104. }
  105. return mws
  106. }
  107. func (m *cMux) Match(matchers ...Matcher) net.Listener {
  108. mws := matchersToMatchWriters(matchers)
  109. return m.MatchWithWriters(mws...)
  110. }
  111. func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener {
  112. ml := muxListener{
  113. Listener: m.root,
  114. connc: make(chan net.Conn, m.bufLen),
  115. }
  116. m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
  117. return ml
  118. }
  119. func (m *cMux) SetReadTimeout(t time.Duration) {
  120. m.readTimeout = t
  121. }
  122. func (m *cMux) Serve() error {
  123. var wg sync.WaitGroup
  124. defer func() {
  125. close(m.donec)
  126. wg.Wait()
  127. for _, sl := range m.sls {
  128. close(sl.l.connc)
  129. // Drain the connections enqueued for the listener.
  130. for c := range sl.l.connc {
  131. _ = c.Close()
  132. }
  133. }
  134. }()
  135. for {
  136. c, err := m.root.Accept()
  137. if err != nil {
  138. if !m.handleErr(err) {
  139. return err
  140. }
  141. continue
  142. }
  143. wg.Add(1)
  144. go m.serve(c, m.donec, &wg)
  145. }
  146. }
  147. func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
  148. defer wg.Done()
  149. muc := newMuxConn(c)
  150. if m.readTimeout > noTimeout {
  151. _ = c.SetReadDeadline(time.Now().Add(m.readTimeout))
  152. }
  153. for _, sl := range m.sls {
  154. for _, s := range sl.ss {
  155. matched := s(muc.Conn, muc.startSniffing())
  156. if matched {
  157. muc.doneSniffing()
  158. if m.readTimeout > noTimeout {
  159. _ = c.SetReadDeadline(time.Time{})
  160. }
  161. select {
  162. case sl.l.connc <- muc:
  163. case <-donec:
  164. _ = c.Close()
  165. }
  166. return
  167. }
  168. }
  169. }
  170. _ = c.Close()
  171. err := ErrNotMatched{c: c}
  172. if !m.handleErr(err) {
  173. _ = m.root.Close()
  174. }
  175. }
  176. func (m *cMux) HandleError(h ErrorHandler) {
  177. m.errh = h
  178. }
  179. func (m *cMux) handleErr(err error) bool {
  180. if !m.errh(err) {
  181. return false
  182. }
  183. if ne, ok := err.(net.Error); ok {
  184. return ne.Temporary()
  185. }
  186. return false
  187. }
  188. type muxListener struct {
  189. net.Listener
  190. connc chan net.Conn
  191. }
  192. func (l muxListener) Accept() (net.Conn, error) {
  193. c, ok := <-l.connc
  194. if !ok {
  195. return nil, ErrListenerClosed
  196. }
  197. return c, nil
  198. }
  199. // MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
  200. type MuxConn struct {
  201. net.Conn
  202. buf bufferedReader
  203. }
  204. func newMuxConn(c net.Conn) *MuxConn {
  205. return &MuxConn{
  206. Conn: c,
  207. buf: bufferedReader{source: c},
  208. }
  209. }
  210. // From the io.Reader documentation:
  211. //
  212. // When Read encounters an error or end-of-file condition after
  213. // successfully reading n > 0 bytes, it returns the number of
  214. // bytes read. It may return the (non-nil) error from the same call
  215. // or return the error (and n == 0) from a subsequent call.
  216. // An instance of this general case is that a Reader returning
  217. // a non-zero number of bytes at the end of the input stream may
  218. // return either err == EOF or err == nil. The next Read should
  219. // return 0, EOF.
  220. func (m *MuxConn) Read(p []byte) (int, error) {
  221. return m.buf.Read(p)
  222. }
  223. func (m *MuxConn) startSniffing() io.Reader {
  224. m.buf.reset(true)
  225. return &m.buf
  226. }
  227. func (m *MuxConn) doneSniffing() {
  228. m.buf.reset(false)
  229. }