handler.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. package runtime
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "net/textproto"
  7. "context"
  8. "github.com/golang/protobuf/proto"
  9. "github.com/golang/protobuf/ptypes/any"
  10. "github.com/grpc-ecosystem/grpc-gateway/runtime/internal"
  11. "google.golang.org/grpc/codes"
  12. "google.golang.org/grpc/grpclog"
  13. "google.golang.org/grpc/status"
  14. )
  15. // ForwardResponseStream forwards the stream from gRPC server to REST client.
  16. func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
  17. f, ok := w.(http.Flusher)
  18. if !ok {
  19. grpclog.Printf("Flush not supported in %T", w)
  20. http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
  21. return
  22. }
  23. md, ok := ServerMetadataFromContext(ctx)
  24. if !ok {
  25. grpclog.Printf("Failed to extract ServerMetadata from context")
  26. http.Error(w, "unexpected error", http.StatusInternalServerError)
  27. return
  28. }
  29. handleForwardResponseServerMetadata(w, mux, md)
  30. w.Header().Set("Transfer-Encoding", "chunked")
  31. w.Header().Set("Content-Type", marshaler.ContentType())
  32. if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
  33. HTTPError(ctx, mux, marshaler, w, req, err)
  34. return
  35. }
  36. var delimiter []byte
  37. if d, ok := marshaler.(Delimited); ok {
  38. delimiter = d.Delimiter()
  39. } else {
  40. delimiter = []byte("\n")
  41. }
  42. var wroteHeader bool
  43. for {
  44. resp, err := recv()
  45. if err == io.EOF {
  46. return
  47. }
  48. if err != nil {
  49. handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
  50. return
  51. }
  52. if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
  53. handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
  54. return
  55. }
  56. buf, err := marshaler.Marshal(streamChunk(resp, nil))
  57. if err != nil {
  58. grpclog.Printf("Failed to marshal response chunk: %v", err)
  59. handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
  60. return
  61. }
  62. if _, err = w.Write(buf); err != nil {
  63. grpclog.Printf("Failed to send response chunk: %v", err)
  64. return
  65. }
  66. wroteHeader = true
  67. if _, err = w.Write(delimiter); err != nil {
  68. grpclog.Printf("Failed to send delimiter chunk: %v", err)
  69. return
  70. }
  71. f.Flush()
  72. }
  73. }
  74. func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
  75. for k, vs := range md.HeaderMD {
  76. if h, ok := mux.outgoingHeaderMatcher(k); ok {
  77. for _, v := range vs {
  78. w.Header().Add(h, v)
  79. }
  80. }
  81. }
  82. }
  83. func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
  84. for k := range md.TrailerMD {
  85. tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
  86. w.Header().Add("Trailer", tKey)
  87. }
  88. }
  89. func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
  90. for k, vs := range md.TrailerMD {
  91. tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
  92. for _, v := range vs {
  93. w.Header().Add(tKey, v)
  94. }
  95. }
  96. }
  97. // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
  98. func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
  99. md, ok := ServerMetadataFromContext(ctx)
  100. if !ok {
  101. grpclog.Printf("Failed to extract ServerMetadata from context")
  102. }
  103. handleForwardResponseServerMetadata(w, mux, md)
  104. handleForwardResponseTrailerHeader(w, md)
  105. w.Header().Set("Content-Type", marshaler.ContentType())
  106. if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
  107. HTTPError(ctx, mux, marshaler, w, req, err)
  108. return
  109. }
  110. buf, err := marshaler.Marshal(resp)
  111. if err != nil {
  112. grpclog.Printf("Marshal error: %v", err)
  113. HTTPError(ctx, mux, marshaler, w, req, err)
  114. return
  115. }
  116. if _, err = w.Write(buf); err != nil {
  117. grpclog.Printf("Failed to write response: %v", err)
  118. }
  119. handleForwardResponseTrailer(w, md)
  120. }
  121. func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
  122. if len(opts) == 0 {
  123. return nil
  124. }
  125. for _, opt := range opts {
  126. if err := opt(ctx, w, resp); err != nil {
  127. grpclog.Printf("Error handling ForwardResponseOptions: %v", err)
  128. return err
  129. }
  130. }
  131. return nil
  132. }
  133. func handleForwardResponseStreamError(wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, err error) {
  134. buf, merr := marshaler.Marshal(streamChunk(nil, err))
  135. if merr != nil {
  136. grpclog.Printf("Failed to marshal an error: %v", merr)
  137. return
  138. }
  139. if !wroteHeader {
  140. s, ok := status.FromError(err)
  141. if !ok {
  142. s = status.New(codes.Unknown, err.Error())
  143. }
  144. w.WriteHeader(HTTPStatusFromCode(s.Code()))
  145. }
  146. if _, werr := w.Write(buf); werr != nil {
  147. grpclog.Printf("Failed to notify error to client: %v", werr)
  148. return
  149. }
  150. }
  151. func streamChunk(result proto.Message, err error) map[string]proto.Message {
  152. if err != nil {
  153. grpcCode := codes.Unknown
  154. grpcMessage := err.Error()
  155. var grpcDetails []*any.Any
  156. if s, ok := status.FromError(err); ok {
  157. grpcCode = s.Code()
  158. grpcMessage = s.Message()
  159. grpcDetails = s.Proto().GetDetails()
  160. }
  161. httpCode := HTTPStatusFromCode(grpcCode)
  162. return map[string]proto.Message{
  163. "error": &internal.StreamError{
  164. GrpcCode: int32(grpcCode),
  165. HttpCode: int32(httpCode),
  166. Message: grpcMessage,
  167. HttpStatus: http.StatusText(httpCode),
  168. Details: grpcDetails,
  169. },
  170. }
  171. }
  172. if result == nil {
  173. return streamChunk(nil, fmt.Errorf("empty response"))
  174. }
  175. return map[string]proto.Message{"result": result}
  176. }