matchers.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. package cmux
  2. import (
  3. "bufio"
  4. "io"
  5. "io/ioutil"
  6. "net/http"
  7. "strings"
  8. "golang.org/x/net/http2"
  9. "golang.org/x/net/http2/hpack"
  10. )
  11. // Any is a Matcher that matches any connection.
  12. func Any() Matcher {
  13. return func(r io.Reader) bool { return true }
  14. }
  15. // PrefixMatcher returns a matcher that matches a connection if it
  16. // starts with any of the strings in strs.
  17. func PrefixMatcher(strs ...string) Matcher {
  18. pt := newPatriciaTreeString(strs...)
  19. return pt.matchPrefix
  20. }
  // defaultHTTPMethods is the set of request-method prefixes HTTP1Fast always
// accepts; callers can extend it through HTTP1Fast's extMethods argument.
var defaultHTTPMethods = []string{
	"OPTIONS", "GET", "HEAD", "POST",
	"PUT", "DELETE", "TRACE", "CONNECT",
}
  31. // HTTP1Fast only matches the methods in the HTTP request.
  32. //
  33. // This matcher is very optimistic: if it returns true, it does not mean that
  34. // the request is a valid HTTP response. If you want a correct but slower HTTP1
  35. // matcher, use HTTP1 instead.
  36. func HTTP1Fast(extMethods ...string) Matcher {
  37. return PrefixMatcher(append(defaultHTTPMethods, extMethods...)...)
  38. }
  // maxHTTPRead caps how many bytes HTTP1 reads from a connection while looking
// for the request line (4 KiB).
const maxHTTPRead = 4 << 10
  40. // HTTP1 parses the first line or upto 4096 bytes of the request to see if
  41. // the conection contains an HTTP request.
  42. func HTTP1() Matcher {
  43. return func(r io.Reader) bool {
  44. br := bufio.NewReader(&io.LimitedReader{R: r, N: maxHTTPRead})
  45. l, part, err := br.ReadLine()
  46. if err != nil || part {
  47. return false
  48. }
  49. _, _, proto, ok := parseRequestLine(string(l))
  50. if !ok {
  51. return false
  52. }
  53. v, _, ok := http.ParseHTTPVersion(proto)
  54. return ok && v == 1
  55. }
  56. }
  // parseRequestLine splits an HTTP/1 request line of the form
// "METHOD REQUEST-URI PROTO" at its first two spaces. Everything after the
// second space is returned as proto. When the line contains fewer than two
// spaces, all three strings are empty and ok is false. Adapted from net/http.
func parseRequestLine(line string) (method, uri, proto string, ok bool) {
	first := strings.Index(line, " ")
	if first < 0 {
		return "", "", "", false
	}
	rest := line[first+1:]
	second := strings.Index(rest, " ")
	if second < 0 {
		return "", "", "", false
	}
	return line[:first], rest[:second], rest[second+1:], true
}
  67. // HTTP2 parses the frame header of the first frame to detect whether the
  68. // connection is an HTTP2 connection.
  69. func HTTP2() Matcher {
  70. return hasHTTP2Preface
  71. }
  72. // HTTP1HeaderField returns a matcher matching the header fields of the first
  73. // request of an HTTP 1 connection.
  74. func HTTP1HeaderField(name, value string) Matcher {
  75. return func(r io.Reader) bool {
  76. return matchHTTP1Field(r, name, value)
  77. }
  78. }
  79. // HTTP2HeaderField resturns a matcher matching the header fields of the first
  80. // headers frame.
  81. func HTTP2HeaderField(name, value string) Matcher {
  82. return func(r io.Reader) bool {
  83. return matchHTTP2Field(r, name, value)
  84. }
  85. }
  86. func hasHTTP2Preface(r io.Reader) bool {
  87. var b [len(http2.ClientPreface)]byte
  88. if _, err := io.ReadFull(r, b[:]); err != nil {
  89. return false
  90. }
  91. return string(b[:]) == http2.ClientPreface
  92. }
  // matchHTTP1Field parses one HTTP/1 request from r and reports whether its
// header name (canonicalized by http.Header.Get) equals value. A request
// that fails to parse never matches. Note that a missing header reads as "",
// so an empty value matches requests lacking the header.
func matchHTTP1Field(r io.Reader, name, value string) (matched bool) {
	req, err := http.ReadRequest(bufio.NewReader(r))
	if err == nil {
		matched = req.Header.Get(name) == value
	}
	return matched
}
  100. func matchHTTP2Field(r io.Reader, name, value string) (matched bool) {
  101. if !hasHTTP2Preface(r) {
  102. return false
  103. }
  104. framer := http2.NewFramer(ioutil.Discard, r)
  105. hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) {
  106. if hf.Name == name && hf.Value == value {
  107. matched = true
  108. }
  109. })
  110. for {
  111. f, err := framer.ReadFrame()
  112. if err != nil {
  113. return false
  114. }
  115. switch f := f.(type) {
  116. case *http2.HeadersFrame:
  117. if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil {
  118. return false
  119. }
  120. if matched {
  121. return true
  122. }
  123. if f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 {
  124. return false
  125. }
  126. }
  127. }
  128. }