handler.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. package runtime
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/textproto"
  8. "context"
  9. "github.com/golang/protobuf/proto"
  10. "github.com/grpc-ecosystem/grpc-gateway/internal"
  11. "google.golang.org/grpc/grpclog"
  12. )
  13. var errEmptyResponse = errors.New("empty response")
  14. // ForwardResponseStream forwards the stream from gRPC server to REST client.
  15. func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
  16. f, ok := w.(http.Flusher)
  17. if !ok {
  18. grpclog.Infof("Flush not supported in %T", w)
  19. http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
  20. return
  21. }
  22. md, ok := ServerMetadataFromContext(ctx)
  23. if !ok {
  24. grpclog.Infof("Failed to extract ServerMetadata from context")
  25. http.Error(w, "unexpected error", http.StatusInternalServerError)
  26. return
  27. }
  28. handleForwardResponseServerMetadata(w, mux, md)
  29. w.Header().Set("Transfer-Encoding", "chunked")
  30. w.Header().Set("Content-Type", marshaler.ContentType())
  31. if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
  32. HTTPError(ctx, mux, marshaler, w, req, err)
  33. return
  34. }
  35. var delimiter []byte
  36. if d, ok := marshaler.(Delimited); ok {
  37. delimiter = d.Delimiter()
  38. } else {
  39. delimiter = []byte("\n")
  40. }
  41. var wroteHeader bool
  42. for {
  43. resp, err := recv()
  44. if err == io.EOF {
  45. return
  46. }
  47. if err != nil {
  48. handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
  49. return
  50. }
  51. if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
  52. handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
  53. return
  54. }
  55. buf, err := marshaler.Marshal(streamChunk(ctx, resp, mux.streamErrorHandler))
  56. if err != nil {
  57. grpclog.Infof("Failed to marshal response chunk: %v", err)
  58. handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
  59. return
  60. }
  61. if _, err = w.Write(buf); err != nil {
  62. grpclog.Infof("Failed to send response chunk: %v", err)
  63. return
  64. }
  65. wroteHeader = true
  66. if _, err = w.Write(delimiter); err != nil {
  67. grpclog.Infof("Failed to send delimiter chunk: %v", err)
  68. return
  69. }
  70. f.Flush()
  71. }
  72. }
  73. func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
  74. for k, vs := range md.HeaderMD {
  75. if h, ok := mux.outgoingHeaderMatcher(k); ok {
  76. for _, v := range vs {
  77. w.Header().Add(h, v)
  78. }
  79. }
  80. }
  81. }
  82. func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
  83. for k := range md.TrailerMD {
  84. tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
  85. w.Header().Add("Trailer", tKey)
  86. }
  87. }
  88. func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
  89. for k, vs := range md.TrailerMD {
  90. tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
  91. for _, v := range vs {
  92. w.Header().Add(tKey, v)
  93. }
  94. }
  95. }
// responseBody is implemented by response messages that expose a specific
// field to be marshaled as the response body instead of the whole message.
// The method is generated for the response struct from the value of
// `response_body` in the `google.api.HttpRule`.
type responseBody interface {
	// XXX_ResponseBody returns the value to marshal into the response body.
	XXX_ResponseBody() interface{}
}
  101. // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
  102. func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
  103. md, ok := ServerMetadataFromContext(ctx)
  104. if !ok {
  105. grpclog.Infof("Failed to extract ServerMetadata from context")
  106. }
  107. handleForwardResponseServerMetadata(w, mux, md)
  108. handleForwardResponseTrailerHeader(w, md)
  109. contentType := marshaler.ContentType()
  110. // Check marshaler on run time in order to keep backwards compatability
  111. // An interface param needs to be added to the ContentType() function on
  112. // the Marshal interface to be able to remove this check
  113. if httpBodyMarshaler, ok := marshaler.(*HTTPBodyMarshaler); ok {
  114. contentType = httpBodyMarshaler.ContentTypeFromMessage(resp)
  115. }
  116. w.Header().Set("Content-Type", contentType)
  117. if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
  118. HTTPError(ctx, mux, marshaler, w, req, err)
  119. return
  120. }
  121. var buf []byte
  122. var err error
  123. if rb, ok := resp.(responseBody); ok {
  124. buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
  125. } else {
  126. buf, err = marshaler.Marshal(resp)
  127. }
  128. if err != nil {
  129. grpclog.Infof("Marshal error: %v", err)
  130. HTTPError(ctx, mux, marshaler, w, req, err)
  131. return
  132. }
  133. if _, err = w.Write(buf); err != nil {
  134. grpclog.Infof("Failed to write response: %v", err)
  135. }
  136. handleForwardResponseTrailer(w, md)
  137. }
  138. func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
  139. if len(opts) == 0 {
  140. return nil
  141. }
  142. for _, opt := range opts {
  143. if err := opt(ctx, w, resp); err != nil {
  144. grpclog.Infof("Error handling ForwardResponseOptions: %v", err)
  145. return err
  146. }
  147. }
  148. return nil
  149. }
  150. func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) {
  151. serr := streamError(ctx, mux.streamErrorHandler, err)
  152. if !wroteHeader {
  153. w.WriteHeader(int(serr.HttpCode))
  154. }
  155. buf, merr := marshaler.Marshal(errorChunk(serr))
  156. if merr != nil {
  157. grpclog.Infof("Failed to marshal an error: %v", merr)
  158. return
  159. }
  160. if _, werr := w.Write(buf); werr != nil {
  161. grpclog.Infof("Failed to notify error to client: %v", werr)
  162. return
  163. }
  164. }
  165. // streamChunk returns a chunk in a response stream for the given result. The
  166. // given errHandler is used to render an error chunk if result is nil.
  167. func streamChunk(ctx context.Context, result proto.Message, errHandler StreamErrorHandlerFunc) map[string]proto.Message {
  168. if result == nil {
  169. return errorChunk(streamError(ctx, errHandler, errEmptyResponse))
  170. }
  171. return map[string]proto.Message{"result": result}
  172. }
  173. // streamError returns the payload for the final message in a response stream
  174. // that represents the given err.
  175. func streamError(ctx context.Context, errHandler StreamErrorHandlerFunc, err error) *StreamError {
  176. serr := errHandler(ctx, err)
  177. if serr != nil {
  178. return serr
  179. }
  180. // TODO: log about misbehaving stream error handler?
  181. return DefaultHTTPStreamErrorHandler(ctx, err)
  182. }
  183. func errorChunk(err *StreamError) map[string]proto.Message {
  184. return map[string]proto.Message{"error": (*internal.StreamError)(err)}
  185. }