|
|
@@ -1,12 +1,16 @@
|
|
|
package runtime
|
|
|
|
|
|
import (
|
|
|
+ "fmt"
|
|
|
"net/http"
|
|
|
+ "net/textproto"
|
|
|
"strings"
|
|
|
|
|
|
- "golang.org/x/net/context"
|
|
|
-
|
|
|
"github.com/golang/protobuf/proto"
|
|
|
+ "golang.org/x/net/context"
|
|
|
+ "google.golang.org/grpc/codes"
|
|
|
+ "google.golang.org/grpc/metadata"
|
|
|
+ "google.golang.org/grpc/status"
|
|
|
)
|
|
|
|
|
|
// A HandlerFunc handles a specific pair of path pattern and HTTP method.
|
|
|
@@ -19,6 +23,10 @@ type ServeMux struct {
|
|
|
handlers map[string][]handler
|
|
|
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
|
|
|
marshalers marshalerRegistry
|
|
|
+ incomingHeaderMatcher HeaderMatcherFunc
|
|
|
+ outgoingHeaderMatcher HeaderMatcherFunc
|
|
|
+ metadataAnnotator func(context.Context, *http.Request) metadata.MD
|
|
|
+ protoErrorHandler ProtoErrorHandlerFunc
|
|
|
}
|
|
|
|
|
|
// ServeMuxOption is an option that can be given to a ServeMux on construction.
|
|
|
@@ -36,6 +44,64 @@ func WithForwardResponseOption(forwardResponseOption func(context.Context, http.
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context.
|
|
|
+type HeaderMatcherFunc func(string) (string, bool)
|
|
|
+
|
|
|
+// DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
|
|
|
+// keys (as specified by the IANA) to gRPC context with grpcgateway- prefix. HTTP headers that start with
|
|
|
+// 'Grpc-Metadata-' are mapped to gRPC metadata after removing prefix 'Grpc-Metadata-'.
|
|
|
+func DefaultHeaderMatcher(key string) (string, bool) {
|
|
|
+ key = textproto.CanonicalMIMEHeaderKey(key)
|
|
|
+ if isPermanentHTTPHeader(key) {
|
|
|
+ return MetadataPrefix + key, true
|
|
|
+ } else if strings.HasPrefix(key, MetadataHeaderPrefix) {
|
|
|
+ return key[len(MetadataHeaderPrefix):], true
|
|
|
+ }
|
|
|
+ return "", false
|
|
|
+}
|
|
|
+
|
|
|
+// WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
|
|
|
+//
|
|
|
+// This matcher will be called with each header in http.Request. If matcher returns true, that header will be
|
|
|
+// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
|
|
|
+func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
|
|
|
+ return func(mux *ServeMux) {
|
|
|
+ mux.incomingHeaderMatcher = fn
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
|
|
|
+//
|
|
|
+// This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
|
|
|
+// passed to http response returned from gateway. To transform the header before passing to response,
|
|
|
+// matcher should return modified header.
|
|
|
+func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
|
|
|
+ return func(mux *ServeMux) {
|
|
|
+ mux.outgoingHeaderMatcher = fn
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
|
|
|
+//
|
|
|
+// This can be used by services that need to read from http.Request and modify gRPC context. A common use case
|
|
|
+// is reading token from cookie and adding it in gRPC context.
|
|
|
+func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
|
|
|
+ return func(serveMux *ServeMux) {
|
|
|
+ serveMux.metadataAnnotator = annotator
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// WithProtoErrorHandler returns a ServeMuxOption for passing metadata to a gRPC context.
|
|
|
+//
|
|
|
+// This can be used to handle an error as general proto message defined by gRPC.
|
|
|
+// The response including body and status is not backward compatible with the default error handler.
|
|
|
+// When this option is used, HTTPError and OtherErrorHandler are overwritten on initialization.
|
|
|
+func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption {
|
|
|
+ return func(serveMux *ServeMux) {
|
|
|
+ serveMux.protoErrorHandler = fn
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// NewServeMux returns a new ServeMux whose internal mapping is empty.
|
|
|
func NewServeMux(opts ...ServeMuxOption) *ServeMux {
|
|
|
serveMux := &ServeMux{
|
|
|
@@ -47,6 +113,29 @@ func NewServeMux(opts ...ServeMuxOption) *ServeMux {
|
|
|
for _, opt := range opts {
|
|
|
opt(serveMux)
|
|
|
}
|
|
|
+
|
|
|
+ if serveMux.protoErrorHandler != nil {
|
|
|
+ HTTPError = serveMux.protoErrorHandler
|
|
|
+ // OtherErrorHandler is no longer used when protoErrorHandler is set.
|
|
|
+ // Overwritten by a special error handler to return Unknown.
|
|
|
+ OtherErrorHandler = func(w http.ResponseWriter, r *http.Request, _ string, _ int) {
|
|
|
+ ctx := context.Background()
|
|
|
+ _, outboundMarshaler := MarshalerForRequest(serveMux, r)
|
|
|
+ sterr := status.Error(codes.Unknown, "unexpected use of OtherErrorHandler")
|
|
|
+ serveMux.protoErrorHandler(ctx, serveMux, outboundMarshaler, w, r, sterr)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if serveMux.incomingHeaderMatcher == nil {
|
|
|
+ serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
|
|
|
+ }
|
|
|
+
|
|
|
+ if serveMux.outgoingHeaderMatcher == nil {
|
|
|
+ serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
|
|
|
+ return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
return serveMux
|
|
|
}
|
|
|
|
|
|
@@ -57,9 +146,17 @@ func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
|
|
|
|
|
|
// ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path.
|
|
|
func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
|
+ ctx := r.Context()
|
|
|
+
|
|
|
path := r.URL.Path
|
|
|
if !strings.HasPrefix(path, "/") {
|
|
|
- OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
|
|
+ if s.protoErrorHandler != nil {
|
|
|
+ _, outboundMarshaler := MarshalerForRequest(s, r)
|
|
|
+ sterr := status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest))
|
|
|
+ s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
|
|
|
+ } else {
|
|
|
+ OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
|
|
+ }
|
|
|
return
|
|
|
}
|
|
|
|
|
|
@@ -67,7 +164,13 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
|
l := len(components)
|
|
|
var verb string
|
|
|
if idx := strings.LastIndex(components[l-1], ":"); idx == 0 {
|
|
|
- OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
|
|
|
+ if s.protoErrorHandler != nil {
|
|
|
+ _, outboundMarshaler := MarshalerForRequest(s, r)
|
|
|
+ sterr := status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))
|
|
|
+ s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
|
|
|
+ } else {
|
|
|
+ OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
|
|
|
+ }
|
|
|
return
|
|
|
} else if idx > 0 {
|
|
|
c := components[l-1]
|
|
|
@@ -77,7 +180,13 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
|
if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && isPathLengthFallback(r) {
|
|
|
r.Method = strings.ToUpper(override)
|
|
|
if err := r.ParseForm(); err != nil {
|
|
|
- OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
|
|
|
+ if s.protoErrorHandler != nil {
|
|
|
+ _, outboundMarshaler := MarshalerForRequest(s, r)
|
|
|
+ sterr := status.Error(codes.InvalidArgument, err.Error())
|
|
|
+ s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
|
|
|
+ } else {
|
|
|
+ OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
|
|
|
+ }
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
@@ -104,17 +213,36 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
|
// X-HTTP-Method-Override is optional. Always allow fallback to POST.
|
|
|
if isPathLengthFallback(r) {
|
|
|
if err := r.ParseForm(); err != nil {
|
|
|
- OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
|
|
|
+ if s.protoErrorHandler != nil {
|
|
|
+ _, outboundMarshaler := MarshalerForRequest(s, r)
|
|
|
+ sterr := status.Error(codes.InvalidArgument, err.Error())
|
|
|
+ s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
|
|
|
+ } else {
|
|
|
+ OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
|
|
|
+ }
|
|
|
return
|
|
|
}
|
|
|
h.h(w, r, pathParams)
|
|
|
return
|
|
|
}
|
|
|
- OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
|
|
+ if s.protoErrorHandler != nil {
|
|
|
+ _, outboundMarshaler := MarshalerForRequest(s, r)
|
|
|
+ sterr := status.Error(codes.Unimplemented, http.StatusText(http.StatusMethodNotAllowed))
|
|
|
+ s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
|
|
|
+ } else {
|
|
|
+ OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
|
|
+ }
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
- OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
|
|
|
+
|
|
|
+ if s.protoErrorHandler != nil {
|
|
|
+ _, outboundMarshaler := MarshalerForRequest(s, r)
|
|
|
+ sterr := status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))
|
|
|
+ s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
|
|
|
+ } else {
|
|
|
+ OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
|