handler.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. package runtime
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "net/textproto"
  7. "github.com/golang/protobuf/proto"
  8. "github.com/grpc-ecosystem/grpc-gateway/runtime/internal"
  9. "golang.org/x/net/context"
  10. "google.golang.org/grpc/codes"
  11. "google.golang.org/grpc/grpclog"
  12. "google.golang.org/grpc/status"
  13. )
  14. // ForwardResponseStream forwards the stream from gRPC server to REST client.
  15. func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
  16. f, ok := w.(http.Flusher)
  17. if !ok {
  18. grpclog.Printf("Flush not supported in %T", w)
  19. http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
  20. return
  21. }
  22. md, ok := ServerMetadataFromContext(ctx)
  23. if !ok {
  24. grpclog.Printf("Failed to extract ServerMetadata from context")
  25. http.Error(w, "unexpected error", http.StatusInternalServerError)
  26. return
  27. }
  28. handleForwardResponseServerMetadata(w, mux, md)
  29. w.Header().Set("Transfer-Encoding", "chunked")
  30. w.Header().Set("Content-Type", marshaler.ContentType())
  31. if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
  32. HTTPError(ctx, mux, marshaler, w, req, err)
  33. return
  34. }
  35. var delimiter []byte
  36. if d, ok := marshaler.(Delimited); ok {
  37. delimiter = d.Delimiter()
  38. } else {
  39. delimiter = []byte("\n")
  40. }
  41. var wroteHeader bool
  42. for {
  43. resp, err := recv()
  44. if err == io.EOF {
  45. return
  46. }
  47. if err != nil {
  48. handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
  49. return
  50. }
  51. if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
  52. handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
  53. return
  54. }
  55. buf, err := marshaler.Marshal(streamChunk(resp, nil))
  56. if err != nil {
  57. grpclog.Printf("Failed to marshal response chunk: %v", err)
  58. handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
  59. return
  60. }
  61. if _, err = w.Write(buf); err != nil {
  62. grpclog.Printf("Failed to send response chunk: %v", err)
  63. return
  64. }
  65. wroteHeader = true
  66. if _, err = w.Write(delimiter); err != nil {
  67. grpclog.Printf("Failed to send delimiter chunk: %v", err)
  68. return
  69. }
  70. f.Flush()
  71. }
  72. }
  73. func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
  74. for k, vs := range md.HeaderMD {
  75. if h, ok := mux.outgoingHeaderMatcher(k); ok {
  76. for _, v := range vs {
  77. w.Header().Add(h, v)
  78. }
  79. }
  80. }
  81. }
  82. func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
  83. for k := range md.TrailerMD {
  84. tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
  85. w.Header().Add("Trailer", tKey)
  86. }
  87. }
  88. func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
  89. for k, vs := range md.TrailerMD {
  90. tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
  91. for _, v := range vs {
  92. w.Header().Add(tKey, v)
  93. }
  94. }
  95. }
  96. // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
  97. func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
  98. md, ok := ServerMetadataFromContext(ctx)
  99. if !ok {
  100. grpclog.Printf("Failed to extract ServerMetadata from context")
  101. }
  102. handleForwardResponseServerMetadata(w, mux, md)
  103. handleForwardResponseTrailerHeader(w, md)
  104. w.Header().Set("Content-Type", marshaler.ContentType())
  105. if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
  106. HTTPError(ctx, mux, marshaler, w, req, err)
  107. return
  108. }
  109. buf, err := marshaler.Marshal(resp)
  110. if err != nil {
  111. grpclog.Printf("Marshal error: %v", err)
  112. HTTPError(ctx, mux, marshaler, w, req, err)
  113. return
  114. }
  115. if _, err = w.Write(buf); err != nil {
  116. grpclog.Printf("Failed to write response: %v", err)
  117. }
  118. handleForwardResponseTrailer(w, md)
  119. }
  120. func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
  121. if len(opts) == 0 {
  122. return nil
  123. }
  124. for _, opt := range opts {
  125. if err := opt(ctx, w, resp); err != nil {
  126. grpclog.Printf("Error handling ForwardResponseOptions: %v", err)
  127. return err
  128. }
  129. }
  130. return nil
  131. }
  132. func handleForwardResponseStreamError(wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, err error) {
  133. buf, merr := marshaler.Marshal(streamChunk(nil, err))
  134. if merr != nil {
  135. grpclog.Printf("Failed to marshal an error: %v", merr)
  136. return
  137. }
  138. if !wroteHeader {
  139. s, ok := status.FromError(err)
  140. if !ok {
  141. s = status.New(codes.Unknown, err.Error())
  142. }
  143. w.WriteHeader(HTTPStatusFromCode(s.Code()))
  144. }
  145. if _, werr := w.Write(buf); werr != nil {
  146. grpclog.Printf("Failed to notify error to client: %v", werr)
  147. return
  148. }
  149. }
  150. func streamChunk(result proto.Message, err error) map[string]proto.Message {
  151. if err != nil {
  152. grpcCode := codes.Unknown
  153. if s, ok := status.FromError(err); ok {
  154. grpcCode = s.Code()
  155. }
  156. httpCode := HTTPStatusFromCode(grpcCode)
  157. return map[string]proto.Message{
  158. "error": &internal.StreamError{
  159. GrpcCode: int32(grpcCode),
  160. HttpCode: int32(httpCode),
  161. Message: err.Error(),
  162. HttpStatus: http.StatusText(httpCode),
  163. },
  164. }
  165. }
  166. if result == nil {
  167. return streamChunk(nil, fmt.Errorf("empty response"))
  168. }
  169. return map[string]proto.Message{"result": result}
  170. }