123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303 |
- package runtime
- import (
- "context"
- "fmt"
- "net/http"
- "net/textproto"
- "strings"
- "github.com/golang/protobuf/proto"
- "google.golang.org/grpc/codes"
- "google.golang.org/grpc/metadata"
- "google.golang.org/grpc/status"
- )
// HandlerFunc is the signature of a function registered on a ServeMux for one
// combination of HTTP method and path pattern. It receives the response
// writer, the request, and the path parameters as a string-to-string map.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
- // ErrUnknownURI is the error supplied to a custom ProtoErrorHandlerFunc when
- // a request is received with a URI path that does not match any registered
- // service method.
- //
- // Since gRPC servers return an "Unimplemented" code for requests with an
- // unrecognized URI path, this error also has a gRPC "Unimplemented" code.
- var ErrUnknownURI = status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))
- // ServeMux is a request multiplexer for grpc-gateway.
- // It matches http requests to patterns and invokes the corresponding handler.
- type ServeMux struct {
- // handlers maps HTTP method to a list of handlers.
- handlers map[string][]handler
- forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
- marshalers marshalerRegistry
- incomingHeaderMatcher HeaderMatcherFunc
- outgoingHeaderMatcher HeaderMatcherFunc
- metadataAnnotators []func(context.Context, *http.Request) metadata.MD
- streamErrorHandler StreamErrorHandlerFunc
- protoErrorHandler ProtoErrorHandlerFunc
- disablePathLengthFallback bool
- lastMatchWins bool
- }
- // ServeMuxOption is an option that can be given to a ServeMux on construction.
- type ServeMuxOption func(*ServeMux)
- // WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
- //
- // forwardResponseOption is an option that will be called on the relevant context.Context,
- // http.ResponseWriter, and proto.Message before every forwarded response.
- //
- // The message may be nil in the case where just a header is being sent.
- func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
- return func(serveMux *ServeMux) {
- serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
- }
- }
// HeaderMatcherFunc decides whether a header key is forwarded to or from the
// gRPC context. It returns the key to use on the other side and a boolean
// that reports whether the header should be forwarded at all.
type HeaderMatcherFunc func(string) (string, bool)
- // DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
- // keys (as specified by the IANA) to gRPC context with grpcgateway- prefix. HTTP headers that start with
- // 'Grpc-Metadata-' are mapped to gRPC metadata after removing prefix 'Grpc-Metadata-'.
- func DefaultHeaderMatcher(key string) (string, bool) {
- key = textproto.CanonicalMIMEHeaderKey(key)
- if isPermanentHTTPHeader(key) {
- return MetadataPrefix + key, true
- } else if strings.HasPrefix(key, MetadataHeaderPrefix) {
- return key[len(MetadataHeaderPrefix):], true
- }
- return "", false
- }
- // WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
- //
- // This matcher will be called with each header in http.Request. If matcher returns true, that header will be
- // passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
- func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
- return func(mux *ServeMux) {
- mux.incomingHeaderMatcher = fn
- }
- }
- // WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
- //
- // This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
- // passed to http response returned from gateway. To transform the header before passing to response,
- // matcher should return modified header.
- func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
- return func(mux *ServeMux) {
- mux.outgoingHeaderMatcher = fn
- }
- }
- // WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
- //
- // This can be used by services that need to read from http.Request and modify gRPC context. A common use case
- // is reading token from cookie and adding it in gRPC context.
- func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
- return func(serveMux *ServeMux) {
- serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
- }
- }
- // WithProtoErrorHandler returns a ServeMuxOption for passing metadata to a gRPC context.
- //
- // This can be used to handle an error as general proto message defined by gRPC.
- // The response including body and status is not backward compatible with the default error handler.
- // When this option is used, HTTPError and OtherErrorHandler are overwritten on initialization.
- func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption {
- return func(serveMux *ServeMux) {
- serveMux.protoErrorHandler = fn
- }
- }
- // WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
- func WithDisablePathLengthFallback() ServeMuxOption {
- return func(serveMux *ServeMux) {
- serveMux.disablePathLengthFallback = true
- }
- }
- // WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
- // error handler, which allows for customizing the error trailer for server-streaming
- // calls.
- //
- // For stream errors that occur before any response has been written, the mux's
- // ProtoErrorHandler will be invoked. However, once data has been written, the errors must
- // be handled differently: they must be included in the response body. The response body's
- // final message will include the error details returned by the stream error handler.
- func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
- return func(serveMux *ServeMux) {
- serveMux.streamErrorHandler = fn
- }
- }
- // WithLastMatchWins returns a ServeMuxOption that will enable "last
- // match wins" behavior, where if multiple path patterns match a
- // request path, the last one defined in the .proto file will be used.
- func WithLastMatchWins() ServeMuxOption {
- return func(serveMux *ServeMux) {
- serveMux.lastMatchWins = true
- }
- }
- // NewServeMux returns a new ServeMux whose internal mapping is empty.
- func NewServeMux(opts ...ServeMuxOption) *ServeMux {
- serveMux := &ServeMux{
- handlers: make(map[string][]handler),
- forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
- marshalers: makeMarshalerMIMERegistry(),
- streamErrorHandler: DefaultHTTPStreamErrorHandler,
- }
- for _, opt := range opts {
- opt(serveMux)
- }
- if serveMux.protoErrorHandler != nil {
- HTTPError = serveMux.protoErrorHandler
- // OtherErrorHandler is no longer used when protoErrorHandler is set.
- // Overwritten by a special error handler to return Unknown.
- OtherErrorHandler = func(w http.ResponseWriter, r *http.Request, _ string, _ int) {
- ctx := context.Background()
- _, outboundMarshaler := MarshalerForRequest(serveMux, r)
- sterr := status.Error(codes.Unknown, "unexpected use of OtherErrorHandler")
- serveMux.protoErrorHandler(ctx, serveMux, outboundMarshaler, w, r, sterr)
- }
- }
- if serveMux.incomingHeaderMatcher == nil {
- serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
- }
- if serveMux.outgoingHeaderMatcher == nil {
- serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
- return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
- }
- }
- return serveMux
- }
- // Handle associates "h" to the pair of HTTP method and path pattern.
- func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
- if s.lastMatchWins {
- s.handlers[meth] = append([]handler{handler{pat: pat, h: h}}, s.handlers[meth]...)
- } else {
- s.handlers[meth] = append(s.handlers[meth], handler{pat: pat, h: h})
- }
- }
- // ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path.
- func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- ctx := r.Context()
- path := r.URL.Path
- if !strings.HasPrefix(path, "/") {
- if s.protoErrorHandler != nil {
- _, outboundMarshaler := MarshalerForRequest(s, r)
- sterr := status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest))
- s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
- } else {
- OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
- }
- return
- }
- components := strings.Split(path[1:], "/")
- l := len(components)
- var verb string
- if idx := strings.LastIndex(components[l-1], ":"); idx == 0 {
- if s.protoErrorHandler != nil {
- _, outboundMarshaler := MarshalerForRequest(s, r)
- s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
- } else {
- OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
- }
- return
- } else if idx > 0 {
- c := components[l-1]
- components[l-1], verb = c[:idx], c[idx+1:]
- }
- if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
- r.Method = strings.ToUpper(override)
- if err := r.ParseForm(); err != nil {
- if s.protoErrorHandler != nil {
- _, outboundMarshaler := MarshalerForRequest(s, r)
- sterr := status.Error(codes.InvalidArgument, err.Error())
- s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
- } else {
- OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
- }
- return
- }
- }
- for _, h := range s.handlers[r.Method] {
- pathParams, err := h.pat.Match(components, verb)
- if err != nil {
- continue
- }
- h.h(w, r, pathParams)
- return
- }
- // lookup other methods to handle fallback from GET to POST and
- // to determine if it is MethodNotAllowed or NotFound.
- for m, handlers := range s.handlers {
- if m == r.Method {
- continue
- }
- for _, h := range handlers {
- pathParams, err := h.pat.Match(components, verb)
- if err != nil {
- continue
- }
- // X-HTTP-Method-Override is optional. Always allow fallback to POST.
- if s.isPathLengthFallback(r) {
- if err := r.ParseForm(); err != nil {
- if s.protoErrorHandler != nil {
- _, outboundMarshaler := MarshalerForRequest(s, r)
- sterr := status.Error(codes.InvalidArgument, err.Error())
- s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
- } else {
- OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
- }
- return
- }
- h.h(w, r, pathParams)
- return
- }
- if s.protoErrorHandler != nil {
- _, outboundMarshaler := MarshalerForRequest(s, r)
- s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
- } else {
- OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
- }
- return
- }
- }
- if s.protoErrorHandler != nil {
- _, outboundMarshaler := MarshalerForRequest(s, r)
- s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
- } else {
- OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
- }
- }
- // GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
- func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
- return s.forwardResponseOptions
- }
- func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
- return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
- }
- type handler struct {
- pat Pattern
- h HandlerFunc
- }
|