handler.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. package runtime
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "net/textproto"
  7. "github.com/golang/protobuf/proto"
  8. "github.com/grpc-ecosystem/grpc-gateway/runtime/internal"
  9. "golang.org/x/net/context"
  10. "google.golang.org/grpc"
  11. "google.golang.org/grpc/grpclog"
  12. )
  13. // ForwardResponseStream forwards the stream from gRPC server to REST client.
  14. func ForwardResponseStream(ctx context.Context, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
  15. f, ok := w.(http.Flusher)
  16. if !ok {
  17. grpclog.Printf("Flush not supported in %T", w)
  18. http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
  19. return
  20. }
  21. md, ok := ServerMetadataFromContext(ctx)
  22. if !ok {
  23. grpclog.Printf("Failed to extract ServerMetadata from context")
  24. http.Error(w, "unexpected error", http.StatusInternalServerError)
  25. return
  26. }
  27. handleForwardResponseServerMetadata(w, md)
  28. w.Header().Set("Transfer-Encoding", "chunked")
  29. w.Header().Set("Content-Type", marshaler.ContentType())
  30. if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
  31. http.Error(w, err.Error(), http.StatusInternalServerError)
  32. return
  33. }
  34. w.WriteHeader(http.StatusOK)
  35. f.Flush()
  36. for {
  37. resp, err := recv()
  38. if err == io.EOF {
  39. return
  40. }
  41. if err != nil {
  42. handleForwardResponseStreamError(marshaler, w, err)
  43. return
  44. }
  45. if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
  46. handleForwardResponseStreamError(marshaler, w, err)
  47. return
  48. }
  49. buf, err := marshaler.Marshal(streamChunk(resp, nil))
  50. if err != nil {
  51. grpclog.Printf("Failed to marshal response chunk: %v", err)
  52. return
  53. }
  54. if _, err = fmt.Fprintf(w, "%s\n", buf); err != nil {
  55. grpclog.Printf("Failed to send response chunk: %v", err)
  56. return
  57. }
  58. f.Flush()
  59. }
  60. }
  61. func handleForwardResponseServerMetadata(w http.ResponseWriter, md ServerMetadata) {
  62. for k, vs := range md.HeaderMD {
  63. hKey := fmt.Sprintf("%s%s", MetadataHeaderPrefix, k)
  64. for i := range vs {
  65. w.Header().Add(hKey, vs[i])
  66. }
  67. }
  68. }
  69. func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
  70. for k := range md.TrailerMD {
  71. tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
  72. w.Header().Add("Trailer", tKey)
  73. }
  74. }
  75. func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
  76. for k, vs := range md.TrailerMD {
  77. tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
  78. for i := range vs {
  79. w.Header().Add(tKey, vs[i])
  80. }
  81. }
  82. }
  83. // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
  84. func ForwardResponseMessage(ctx context.Context, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
  85. md, ok := ServerMetadataFromContext(ctx)
  86. if !ok {
  87. grpclog.Printf("Failed to extract ServerMetadata from context")
  88. }
  89. handleForwardResponseServerMetadata(w, md)
  90. handleForwardResponseTrailerHeader(w, md)
  91. w.Header().Set("Content-Type", marshaler.ContentType())
  92. if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
  93. HTTPError(ctx, marshaler, w, req, err)
  94. return
  95. }
  96. buf, err := marshaler.Marshal(resp)
  97. if err != nil {
  98. grpclog.Printf("Marshal error: %v", err)
  99. HTTPError(ctx, marshaler, w, req, err)
  100. return
  101. }
  102. if _, err = w.Write(buf); err != nil {
  103. grpclog.Printf("Failed to write response: %v", err)
  104. }
  105. handleForwardResponseTrailer(w, md)
  106. }
  107. func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
  108. if len(opts) == 0 {
  109. return nil
  110. }
  111. for _, opt := range opts {
  112. if err := opt(ctx, w, resp); err != nil {
  113. grpclog.Printf("Error handling ForwardResponseOptions: %v", err)
  114. return err
  115. }
  116. }
  117. return nil
  118. }
  119. func handleForwardResponseStreamError(marshaler Marshaler, w http.ResponseWriter, err error) {
  120. buf, merr := marshaler.Marshal(streamChunk(nil, err))
  121. if merr != nil {
  122. grpclog.Printf("Failed to marshal an error: %v", merr)
  123. return
  124. }
  125. if _, werr := fmt.Fprintf(w, "%s\n", buf); werr != nil {
  126. grpclog.Printf("Failed to notify error to client: %v", werr)
  127. return
  128. }
  129. }
  130. func streamChunk(result proto.Message, err error) map[string]proto.Message {
  131. if err != nil {
  132. grpcCode := grpc.Code(err)
  133. httpCode := HTTPStatusFromCode(grpcCode)
  134. return map[string]proto.Message{
  135. "error": &internal.StreamError{
  136. GrpcCode: int32(grpcCode),
  137. HttpCode: int32(httpCode),
  138. Message: err.Error(),
  139. HttpStatus: http.StatusText(httpCode),
  140. },
  141. }
  142. }
  143. if result == nil {
  144. return streamChunk(nil, fmt.Errorf("empty response"))
  145. }
  146. return map[string]proto.Message{"result": result}
  147. }