context.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. package runtime
  2. import (
  3. "fmt"
  4. "net"
  5. "net/http"
  6. "strconv"
  7. "strings"
  8. "time"
  9. "golang.org/x/net/context"
  10. "google.golang.org/grpc"
  11. "google.golang.org/grpc/codes"
  12. "google.golang.org/grpc/grpclog"
  13. "google.golang.org/grpc/metadata"
  14. )
  const (
  	// MetadataHeaderPrefix is prepended to HTTP headers in order to convert
  	// them to gRPC metadata for incoming requests processed by grpc-gateway.
  	MetadataHeaderPrefix = "Grpc-Metadata-"

  	// MetadataTrailerPrefix is prepended to gRPC metadata as it is converted
  	// to HTTP headers in a response handled by grpc-gateway.
  	MetadataTrailerPrefix = "Grpc-Trailer-"

  	// metadataGrpcTimeout is the inbound HTTP header whose value is decoded
  	// into the deadline of the outgoing gRPC call.
  	metadataGrpcTimeout = "Grpc-Timeout"

  	// Proxy headers forwarded to the gRPC backend as metadata.
  	xForwardedFor  = "X-Forwarded-For"
  	xForwardedHost = "X-Forwarded-Host"
  )

  // DefaultContextTimeout is used for the gRPC call's context.WithTimeout
  // whenever no Grpc-Timeout inbound header is present. When it is 0 (the
  // default), the context passed on carries no timeout.
  var DefaultContextTimeout time.Duration
  29. /*
  30. AnnotateContext adds context information such as metadata from the request.
  31. At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For",
  32. except that the forwarded destination is not another HTTP service but rather
  33. a gRPC service.
  34. */
  35. func AnnotateContext(ctx context.Context, req *http.Request) (context.Context, error) {
  36. var pairs []string
  37. timeout := DefaultContextTimeout
  38. if tm := req.Header.Get(metadataGrpcTimeout); tm != "" {
  39. var err error
  40. timeout, err = timeoutDecode(tm)
  41. if err != nil {
  42. return nil, grpc.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm)
  43. }
  44. }
  45. for key, vals := range req.Header {
  46. for _, val := range vals {
  47. if key == "Authorization" {
  48. pairs = append(pairs, "authorization", val)
  49. continue
  50. }
  51. if strings.HasPrefix(key, MetadataHeaderPrefix) {
  52. pairs = append(pairs, key[len(MetadataHeaderPrefix):], val)
  53. }
  54. }
  55. }
  56. if host := req.Header.Get(xForwardedHost); host != "" {
  57. pairs = append(pairs, strings.ToLower(xForwardedHost), host)
  58. } else if req.Host != "" {
  59. pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
  60. }
  61. if addr := req.RemoteAddr; addr != "" {
  62. if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
  63. if fwd := req.Header.Get(xForwardedFor); fwd == "" {
  64. pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP)
  65. } else {
  66. pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP))
  67. }
  68. } else {
  69. grpclog.Printf("invalid remote addr: %s", addr)
  70. }
  71. }
  72. if timeout != 0 {
  73. ctx, _ = context.WithTimeout(ctx, timeout)
  74. }
  75. if len(pairs) == 0 {
  76. return ctx, nil
  77. }
  78. return metadata.NewContext(ctx, metadata.Pairs(pairs...)), nil
  79. }
  80. // ServerMetadata consists of metadata sent from gRPC server.
  81. type ServerMetadata struct {
  82. HeaderMD metadata.MD
  83. TrailerMD metadata.MD
  84. }
  85. type serverMetadataKey struct{}
  86. // NewServerMetadataContext creates a new context with ServerMetadata
  87. func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context {
  88. return context.WithValue(ctx, serverMetadataKey{}, md)
  89. }
  90. // ServerMetadataFromContext returns the ServerMetadata in ctx
  91. func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) {
  92. md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata)
  93. return
  94. }
  // timeoutDecode parses a Grpc-Timeout header value. The value is an
  // unsigned decimal integer followed by one unit character (H, M, S, m, u
  // or n). For example, "100m" is 100 milliseconds.
  //
  // It returns an error when:
  //   - s is shorter than two bytes;
  //   - the unit is not recognized;
  //   - the numeric part is not a plain run of digits (a leading '+' or '-' is
  //     rejected, so a negative timeout cannot yield an already-expired
  //     deadline);
  //   - the resulting duration would overflow time.Duration.
  func timeoutDecode(s string) (time.Duration, error) {
  	// Largest representable time.Duration, used for the overflow check below.
  	const maxDuration time.Duration = 1<<63 - 1

  	size := len(s)
  	if size < 2 {
  		return 0, fmt.Errorf("timeout string is too short: %q", s)
  	}
  	unit, ok := timeoutUnitToDuration(s[size-1])
  	if !ok {
  		return 0, fmt.Errorf("timeout unit is not recognized: %q", s)
  	}
  	// ParseUint, unlike ParseInt, refuses sign prefixes, so "-5S" and "+5S"
  	// fail here.
  	value, err := strconv.ParseUint(s[:size-1], 10, 64)
  	if err != nil {
  		return 0, err
  	}
  	// Without this check, unit*value could silently wrap to a bogus
  	// (possibly negative) duration.
  	if value > uint64(maxDuration/unit) {
  		return 0, fmt.Errorf("timeout value is out of range: %q", s)
  	}
  	return unit * time.Duration(value), nil
  }

  // timeoutUnitToDuration maps a Grpc-Timeout unit character to the duration
  // of one such unit. ok is false (and d is 0) for unrecognized characters.
  func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) {
  	switch u {
  	case 'H':
  		return time.Hour, true
  	case 'M':
  		return time.Minute, true
  	case 'S':
  		return time.Second, true
  	case 'm':
  		return time.Millisecond, true
  	case 'u':
  		return time.Microsecond, true
  	case 'n':
  		return time.Nanosecond, true
  	default:
  		return 0, false
  	}
  }