123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- package cmux
- import (
- "bytes"
- "fmt"
- "io"
- "net"
- "sync"
- )
// Matcher inspects a connection's content, read through the supplied
// io.Reader, and reports whether the connection should be routed to the
// listener the matcher was registered with.
type Matcher func(io.Reader) bool
// ErrorHandler is consulted with each error the multiplexer encounters. It
// returns true when the mux should keep serving the listener and false when
// it should stop.
type ErrorHandler func(error) bool
// Compile-time check that ErrNotMatched satisfies net.Error.
var _ net.Error = ErrNotMatched{}

// ErrNotMatched is the error reported to the error handler whenever a
// connection is not matched by any of the matchers registered in the
// multiplexer.
type ErrNotMatched struct {
	c net.Conn
}

// Error describes the unmatched connection by its remote address. It is safe
// to call on the zero value, which carries no connection.
func (e ErrNotMatched) Error() string {
	// Calling RemoteAddr on a nil net.Conn would panic; a nil net.Addr is
	// formatted by %v as "<nil>" instead.
	var addr net.Addr
	if e.c != nil {
		addr = e.c.RemoteAddr()
	}
	return fmt.Sprintf("mux: connection %v not matched by any matcher", addr)
}

// Temporary implements the net.Error interface and always reports true.
func (e ErrNotMatched) Temporary() bool { return true }

// Timeout implements the net.Error interface and always reports false.
func (e ErrNotMatched) Timeout() bool { return false }
// errListenerClosed is a string-backed error whose value is its message. It
// provides Temporary and Timeout so it can be treated as a net.Error.
type errListenerClosed string

// Error returns the message stored in e.
func (e errListenerClosed) Error() string {
	return string(e)
}

// Temporary reports false so callers treat a closed listener as permanent.
func (e errListenerClosed) Temporary() bool {
	return false
}

// Timeout reports false; closing a listener is not a deadline expiry.
func (e errListenerClosed) Timeout() bool {
	return false
}

// ErrListenerClosed is returned from muxListener.Accept once the listener's
// connection queue has been closed and emptied, which Serve does as it
// returns (for example, after the root listener fails).
var ErrListenerClosed = errListenerClosed("mux: listener closed")
- // New instantiates a new connection multiplexer.
- func New(l net.Listener) CMux {
- return &cMux{
- root: l,
- bufLen: 1024,
- errh: func(_ error) bool { return true },
- donec: make(chan struct{}),
- }
- }
- // CMux is a multiplexer for network connections.
- type CMux interface {
- // Match returns a net.Listener that sees (i.e., accepts) only
- // the connections matched by at least one of the matcher.
- //
- // The order used to call Match determines the priority of matchers.
- Match(...Matcher) net.Listener
- // Serve starts multiplexing the listener. Serve blocks and perhaps
- // should be invoked concurrently within a go routine.
- Serve() error
- // HandleError registers an error handler that handles listener errors.
- HandleError(ErrorHandler)
- }
- type matchersListener struct {
- ss []Matcher
- l muxListener
- }
- type cMux struct {
- root net.Listener
- bufLen int
- errh ErrorHandler
- donec chan struct{}
- sls []matchersListener
- }
- func (m *cMux) Match(matchers ...Matcher) net.Listener {
- ml := muxListener{
- Listener: m.root,
- connc: make(chan net.Conn, m.bufLen),
- }
- m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
- return ml
- }
- func (m *cMux) Serve() error {
- var wg sync.WaitGroup
- defer func() {
- close(m.donec)
- wg.Wait()
- for _, sl := range m.sls {
- close(sl.l.connc)
- // Drain the connections enqueued for the listener.
- for c := range sl.l.connc {
- _ = c.Close()
- }
- }
- }()
- for {
- c, err := m.root.Accept()
- if err != nil {
- if !m.handleErr(err) {
- return err
- }
- continue
- }
- wg.Add(1)
- go m.serve(c, m.donec, &wg)
- }
- }
- func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
- defer wg.Done()
- muc := newMuxConn(c)
- for _, sl := range m.sls {
- for _, s := range sl.ss {
- matched := s(muc.getSniffer())
- if matched {
- select {
- case sl.l.connc <- muc:
- case <-donec:
- _ = c.Close()
- }
- return
- }
- }
- }
- _ = c.Close()
- err := ErrNotMatched{c: c}
- if !m.handleErr(err) {
- _ = m.root.Close()
- }
- }
- func (m *cMux) HandleError(h ErrorHandler) {
- m.errh = h
- }
- func (m *cMux) handleErr(err error) bool {
- if !m.errh(err) {
- return false
- }
- if ne, ok := err.(net.Error); ok {
- return ne.Temporary()
- }
- return false
- }
// muxListener is the net.Listener handed out by Match. Accept takes
// connections from connc. Every other method, such as Close and Addr, is
// promoted from the embedded listener, which Match sets to the root
// listener.
//
// NOTE(review): no Close override is visible here, so closing any
// muxListener closes the shared root listener. Confirm this is intended.
type muxListener struct {
	net.Listener
	// connc is the queue that serve fills with matched connections.
	connc chan net.Conn
}
- func (l muxListener) Accept() (net.Conn, error) {
- c, ok := <-l.connc
- if !ok {
- return nil, ErrListenerClosed
- }
- return c, nil
- }
- // MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
- type MuxConn struct {
- net.Conn
- buf bytes.Buffer
- sniffer bufferedReader
- }
- func newMuxConn(c net.Conn) *MuxConn {
- return &MuxConn{
- Conn: c,
- }
- }
- // From the io.Reader documentation:
- //
- // When Read encounters an error or end-of-file condition after
- // successfully reading n > 0 bytes, it returns the number of
- // bytes read. It may return the (non-nil) error from the same call
- // or return the error (and n == 0) from a subsequent call.
- // An instance of this general case is that a Reader returning
- // a non-zero number of bytes at the end of the input stream may
- // return either err == EOF or err == nil. The next Read should
- // return 0, EOF.
- func (m *MuxConn) Read(p []byte) (int, error) {
- if n, err := m.buf.Read(p); err != io.EOF {
- return n, err
- }
- return m.Conn.Read(p)
- }
- func (m *MuxConn) getSniffer() io.Reader {
- m.sniffer = bufferedReader{source: m.Conn, buffer: &m.buf, bufferSize: m.buf.Len()}
- return &m.sniffer
- }
|