cmux.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. package cmux
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "net"
  7. "sync"
  8. )
  9. // Matcher matches a connection based on its content.
  10. type Matcher func(io.Reader) bool
  11. // ErrorHandler handles an error and returns whether
  12. // the mux should continue serving the listener.
  13. type ErrorHandler func(error) bool
  14. var _ net.Error = ErrNotMatched{}
  15. // ErrNotMatched is returned whenever a connection is not matched by any of
  16. // the matchers registered in the multiplexer.
  17. type ErrNotMatched struct {
  18. c net.Conn
  19. }
  20. func (e ErrNotMatched) Error() string {
  21. return fmt.Sprintf("mux: connection %v not matched by an matcher",
  22. e.c.RemoteAddr())
  23. }
  24. // Temporary implements the net.Error interface.
  25. func (e ErrNotMatched) Temporary() bool { return true }
  26. // Timeout implements the net.Error interface.
  27. func (e ErrNotMatched) Timeout() bool { return false }
  28. type errListenerClosed string
  29. func (e errListenerClosed) Error() string { return string(e) }
  30. func (e errListenerClosed) Temporary() bool { return false }
  31. func (e errListenerClosed) Timeout() bool { return false }
  32. // ErrListenerClosed is returned from muxListener.Accept when the underlying
  33. // listener is closed.
  34. var ErrListenerClosed = errListenerClosed("mux: listener closed")
  35. // New instantiates a new connection multiplexer.
  36. func New(l net.Listener) CMux {
  37. return &cMux{
  38. root: l,
  39. bufLen: 1024,
  40. errh: func(_ error) bool { return true },
  41. donec: make(chan struct{}),
  42. }
  43. }
  44. // CMux is a multiplexer for network connections.
  45. type CMux interface {
  46. // Match returns a net.Listener that sees (i.e., accepts) only
  47. // the connections matched by at least one of the matcher.
  48. //
  49. // The order used to call Match determines the priority of matchers.
  50. Match(...Matcher) net.Listener
  51. // Serve starts multiplexing the listener. Serve blocks and perhaps
  52. // should be invoked concurrently within a go routine.
  53. Serve() error
  54. // HandleError registers an error handler that handles listener errors.
  55. HandleError(ErrorHandler)
  56. }
  57. type matchersListener struct {
  58. ss []Matcher
  59. l muxListener
  60. }
  61. type cMux struct {
  62. root net.Listener
  63. bufLen int
  64. errh ErrorHandler
  65. donec chan struct{}
  66. sls []matchersListener
  67. }
  68. func (m *cMux) Match(matchers ...Matcher) net.Listener {
  69. ml := muxListener{
  70. Listener: m.root,
  71. connc: make(chan net.Conn, m.bufLen),
  72. }
  73. m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
  74. return ml
  75. }
  76. func (m *cMux) Serve() error {
  77. var wg sync.WaitGroup
  78. defer func() {
  79. close(m.donec)
  80. wg.Wait()
  81. for _, sl := range m.sls {
  82. close(sl.l.connc)
  83. // Drain the connections enqueued for the listener.
  84. for c := range sl.l.connc {
  85. _ = c.Close()
  86. }
  87. }
  88. }()
  89. for {
  90. c, err := m.root.Accept()
  91. if err != nil {
  92. if !m.handleErr(err) {
  93. return err
  94. }
  95. continue
  96. }
  97. wg.Add(1)
  98. go m.serve(c, m.donec, &wg)
  99. }
  100. }
  101. func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
  102. defer wg.Done()
  103. muc := newMuxConn(c)
  104. for _, sl := range m.sls {
  105. for _, s := range sl.ss {
  106. matched := s(muc.getSniffer())
  107. if matched {
  108. select {
  109. case sl.l.connc <- muc:
  110. case <-donec:
  111. _ = c.Close()
  112. }
  113. return
  114. }
  115. }
  116. }
  117. _ = c.Close()
  118. err := ErrNotMatched{c: c}
  119. if !m.handleErr(err) {
  120. _ = m.root.Close()
  121. }
  122. }
  123. func (m *cMux) HandleError(h ErrorHandler) {
  124. m.errh = h
  125. }
  126. func (m *cMux) handleErr(err error) bool {
  127. if !m.errh(err) {
  128. return false
  129. }
  130. if ne, ok := err.(net.Error); ok {
  131. return ne.Temporary()
  132. }
  133. return false
  134. }
  135. type muxListener struct {
  136. net.Listener
  137. connc chan net.Conn
  138. }
  139. func (l muxListener) Accept() (net.Conn, error) {
  140. c, ok := <-l.connc
  141. if !ok {
  142. return nil, ErrListenerClosed
  143. }
  144. return c, nil
  145. }
  146. // MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
  147. type MuxConn struct {
  148. net.Conn
  149. buf bytes.Buffer
  150. sniffer bufferedReader
  151. }
  152. func newMuxConn(c net.Conn) *MuxConn {
  153. return &MuxConn{
  154. Conn: c,
  155. }
  156. }
  157. // From the io.Reader documentation:
  158. //
  159. // When Read encounters an error or end-of-file condition after
  160. // successfully reading n > 0 bytes, it returns the number of
  161. // bytes read. It may return the (non-nil) error from the same call
  162. // or return the error (and n == 0) from a subsequent call.
  163. // An instance of this general case is that a Reader returning
  164. // a non-zero number of bytes at the end of the input stream may
  165. // return either err == EOF or err == nil. The next Read should
  166. // return 0, EOF.
  167. func (m *MuxConn) Read(p []byte) (int, error) {
  168. if n, err := m.buf.Read(p); err != io.EOF {
  169. return n, err
  170. }
  171. return m.Conn.Read(p)
  172. }
  173. func (m *MuxConn) getSniffer() io.Reader {
  174. m.sniffer = bufferedReader{source: m.Conn, buffer: &m.buf, bufferSize: m.buf.Len()}
  175. return &m.sniffer
  176. }