mux.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. package runtime
import (
	"context"
	"fmt"
	"mime"
	"net/http"
	"net/textproto"
	"strings"

	"github.com/golang/protobuf/proto"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)
  13. // A HandlerFunc handles a specific pair of path pattern and HTTP method.
  14. type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
  15. // ErrUnknownURI is the error supplied to a custom ProtoErrorHandlerFunc when
  16. // a request is received with a URI path that does not match any registered
  17. // service method.
  18. //
  19. // Since gRPC servers return an "Unimplemented" code for requests with an
  20. // unrecognized URI path, this error also has a gRPC "Unimplemented" code.
  21. var ErrUnknownURI = status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))
  22. // ServeMux is a request multiplexer for grpc-gateway.
  23. // It matches http requests to patterns and invokes the corresponding handler.
  24. type ServeMux struct {
  25. // handlers maps HTTP method to a list of handlers.
  26. handlers map[string][]handler
  27. forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
  28. marshalers marshalerRegistry
  29. incomingHeaderMatcher HeaderMatcherFunc
  30. outgoingHeaderMatcher HeaderMatcherFunc
  31. metadataAnnotators []func(context.Context, *http.Request) metadata.MD
  32. streamErrorHandler StreamErrorHandlerFunc
  33. protoErrorHandler ProtoErrorHandlerFunc
  34. disablePathLengthFallback bool
  35. lastMatchWins bool
  36. }
  37. // ServeMuxOption is an option that can be given to a ServeMux on construction.
  38. type ServeMuxOption func(*ServeMux)
  39. // WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
  40. //
  41. // forwardResponseOption is an option that will be called on the relevant context.Context,
  42. // http.ResponseWriter, and proto.Message before every forwarded response.
  43. //
  44. // The message may be nil in the case where just a header is being sent.
  45. func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
  46. return func(serveMux *ServeMux) {
  47. serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
  48. }
  49. }
  50. // HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context.
  51. type HeaderMatcherFunc func(string) (string, bool)
  52. // DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
  53. // keys (as specified by the IANA) to gRPC context with grpcgateway- prefix. HTTP headers that start with
  54. // 'Grpc-Metadata-' are mapped to gRPC metadata after removing prefix 'Grpc-Metadata-'.
  55. func DefaultHeaderMatcher(key string) (string, bool) {
  56. key = textproto.CanonicalMIMEHeaderKey(key)
  57. if isPermanentHTTPHeader(key) {
  58. return MetadataPrefix + key, true
  59. } else if strings.HasPrefix(key, MetadataHeaderPrefix) {
  60. return key[len(MetadataHeaderPrefix):], true
  61. }
  62. return "", false
  63. }
  64. // WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
  65. //
  66. // This matcher will be called with each header in http.Request. If matcher returns true, that header will be
  67. // passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
  68. func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
  69. return func(mux *ServeMux) {
  70. mux.incomingHeaderMatcher = fn
  71. }
  72. }
  73. // WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
  74. //
  75. // This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
  76. // passed to http response returned from gateway. To transform the header before passing to response,
  77. // matcher should return modified header.
  78. func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
  79. return func(mux *ServeMux) {
  80. mux.outgoingHeaderMatcher = fn
  81. }
  82. }
  83. // WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
  84. //
  85. // This can be used by services that need to read from http.Request and modify gRPC context. A common use case
  86. // is reading token from cookie and adding it in gRPC context.
  87. func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
  88. return func(serveMux *ServeMux) {
  89. serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
  90. }
  91. }
  92. // WithProtoErrorHandler returns a ServeMuxOption for passing metadata to a gRPC context.
  93. //
  94. // This can be used to handle an error as general proto message defined by gRPC.
  95. // The response including body and status is not backward compatible with the default error handler.
  96. // When this option is used, HTTPError and OtherErrorHandler are overwritten on initialization.
  97. func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption {
  98. return func(serveMux *ServeMux) {
  99. serveMux.protoErrorHandler = fn
  100. }
  101. }
  102. // WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
  103. func WithDisablePathLengthFallback() ServeMuxOption {
  104. return func(serveMux *ServeMux) {
  105. serveMux.disablePathLengthFallback = true
  106. }
  107. }
  108. // WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
  109. // error handler, which allows for customizing the error trailer for server-streaming
  110. // calls.
  111. //
  112. // For stream errors that occur before any response has been written, the mux's
  113. // ProtoErrorHandler will be invoked. However, once data has been written, the errors must
  114. // be handled differently: they must be included in the response body. The response body's
  115. // final message will include the error details returned by the stream error handler.
  116. func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
  117. return func(serveMux *ServeMux) {
  118. serveMux.streamErrorHandler = fn
  119. }
  120. }
  121. // WithLastMatchWins returns a ServeMuxOption that will enable "last
  122. // match wins" behavior, where if multiple path patterns match a
  123. // request path, the last one defined in the .proto file will be used.
  124. func WithLastMatchWins() ServeMuxOption {
  125. return func(serveMux *ServeMux) {
  126. serveMux.lastMatchWins = true
  127. }
  128. }
  129. // NewServeMux returns a new ServeMux whose internal mapping is empty.
  130. func NewServeMux(opts ...ServeMuxOption) *ServeMux {
  131. serveMux := &ServeMux{
  132. handlers: make(map[string][]handler),
  133. forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
  134. marshalers: makeMarshalerMIMERegistry(),
  135. streamErrorHandler: DefaultHTTPStreamErrorHandler,
  136. }
  137. for _, opt := range opts {
  138. opt(serveMux)
  139. }
  140. if serveMux.protoErrorHandler != nil {
  141. HTTPError = serveMux.protoErrorHandler
  142. // OtherErrorHandler is no longer used when protoErrorHandler is set.
  143. // Overwritten by a special error handler to return Unknown.
  144. OtherErrorHandler = func(w http.ResponseWriter, r *http.Request, _ string, _ int) {
  145. ctx := context.Background()
  146. _, outboundMarshaler := MarshalerForRequest(serveMux, r)
  147. sterr := status.Error(codes.Unknown, "unexpected use of OtherErrorHandler")
  148. serveMux.protoErrorHandler(ctx, serveMux, outboundMarshaler, w, r, sterr)
  149. }
  150. }
  151. if serveMux.incomingHeaderMatcher == nil {
  152. serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
  153. }
  154. if serveMux.outgoingHeaderMatcher == nil {
  155. serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
  156. return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
  157. }
  158. }
  159. return serveMux
  160. }
  161. // Handle associates "h" to the pair of HTTP method and path pattern.
  162. func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
  163. if s.lastMatchWins {
  164. s.handlers[meth] = append([]handler{handler{pat: pat, h: h}}, s.handlers[meth]...)
  165. } else {
  166. s.handlers[meth] = append(s.handlers[meth], handler{pat: pat, h: h})
  167. }
  168. }
  169. // ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path.
  170. func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  171. ctx := r.Context()
  172. path := r.URL.Path
  173. if !strings.HasPrefix(path, "/") {
  174. if s.protoErrorHandler != nil {
  175. _, outboundMarshaler := MarshalerForRequest(s, r)
  176. sterr := status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest))
  177. s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
  178. } else {
  179. OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
  180. }
  181. return
  182. }
  183. components := strings.Split(path[1:], "/")
  184. l := len(components)
  185. var verb string
  186. if idx := strings.LastIndex(components[l-1], ":"); idx == 0 {
  187. if s.protoErrorHandler != nil {
  188. _, outboundMarshaler := MarshalerForRequest(s, r)
  189. s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
  190. } else {
  191. OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
  192. }
  193. return
  194. } else if idx > 0 {
  195. c := components[l-1]
  196. components[l-1], verb = c[:idx], c[idx+1:]
  197. }
  198. if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
  199. r.Method = strings.ToUpper(override)
  200. if err := r.ParseForm(); err != nil {
  201. if s.protoErrorHandler != nil {
  202. _, outboundMarshaler := MarshalerForRequest(s, r)
  203. sterr := status.Error(codes.InvalidArgument, err.Error())
  204. s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
  205. } else {
  206. OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
  207. }
  208. return
  209. }
  210. }
  211. for _, h := range s.handlers[r.Method] {
  212. pathParams, err := h.pat.Match(components, verb)
  213. if err != nil {
  214. continue
  215. }
  216. h.h(w, r, pathParams)
  217. return
  218. }
  219. // lookup other methods to handle fallback from GET to POST and
  220. // to determine if it is MethodNotAllowed or NotFound.
  221. for m, handlers := range s.handlers {
  222. if m == r.Method {
  223. continue
  224. }
  225. for _, h := range handlers {
  226. pathParams, err := h.pat.Match(components, verb)
  227. if err != nil {
  228. continue
  229. }
  230. // X-HTTP-Method-Override is optional. Always allow fallback to POST.
  231. if s.isPathLengthFallback(r) {
  232. if err := r.ParseForm(); err != nil {
  233. if s.protoErrorHandler != nil {
  234. _, outboundMarshaler := MarshalerForRequest(s, r)
  235. sterr := status.Error(codes.InvalidArgument, err.Error())
  236. s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
  237. } else {
  238. OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
  239. }
  240. return
  241. }
  242. h.h(w, r, pathParams)
  243. return
  244. }
  245. if s.protoErrorHandler != nil {
  246. _, outboundMarshaler := MarshalerForRequest(s, r)
  247. s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
  248. } else {
  249. OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
  250. }
  251. return
  252. }
  253. }
  254. if s.protoErrorHandler != nil {
  255. _, outboundMarshaler := MarshalerForRequest(s, r)
  256. s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
  257. } else {
  258. OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
  259. }
  260. }
  261. // GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
  262. func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
  263. return s.forwardResponseOptions
  264. }
  265. func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
  266. return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
  267. }
  268. type handler struct {
  269. pat Pattern
  270. h HandlerFunc
  271. }