context.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. package runtime
  2. import (
  3. "fmt"
  4. "net"
  5. "net/http"
  6. "strconv"
  7. "strings"
  8. "time"
  9. "golang.org/x/net/context"
  10. "google.golang.org/grpc"
  11. "google.golang.org/grpc/codes"
  12. "google.golang.org/grpc/grpclog"
  13. "google.golang.org/grpc/metadata"
  14. )
  // HTTP header names used when translating an HTTP request into a gRPC call.
// Headers starting with metadataHeaderPrefix are forwarded as gRPC metadata
// with the prefix removed; metadataGrpcTimeout carries a per-request timeout;
// the X-Forwarded-* names are emitted (lower-cased) as metadata keys.
// metadataTrailerPrefix is not referenced in this file — presumably used
// elsewhere in the package.
const (
	metadataHeaderPrefix  = "Grpc-Metadata-"
	metadataTrailerPrefix = "Grpc-Trailer-"
	metadataGrpcTimeout   = "Grpc-Timeout"
	xForwardedFor         = "X-Forwarded-For"
	xForwardedHost        = "X-Forwarded-Host"
)

// DefaultContextTimeout is used for gRPC call context.WithTimeout whenever a
// Grpc-Timeout inbound header isn't present. If the value is 0 the sent
// `context` will not have a timeout.
var DefaultContextTimeout = 0 * time.Second
  25. /*
  26. AnnotateContext adds context information such as metadata from the request.
  27. At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For",
  28. except that the forwarded destination is not another HTTP service but rather
  29. a gRPC service.
  30. */
  31. func AnnotateContext(ctx context.Context, req *http.Request) (context.Context, error) {
  32. var pairs []string
  33. timeout := DefaultContextTimeout
  34. if tm := req.Header.Get(metadataGrpcTimeout); tm != "" {
  35. var err error
  36. timeout, err = timeoutDecode(tm)
  37. if err != nil {
  38. return nil, grpc.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm)
  39. }
  40. }
  41. for key, vals := range req.Header {
  42. for _, val := range vals {
  43. if key == "Authorization" {
  44. pairs = append(pairs, "authorization", val)
  45. continue
  46. }
  47. if strings.HasPrefix(key, metadataHeaderPrefix) {
  48. pairs = append(pairs, key[len(metadataHeaderPrefix):], val)
  49. }
  50. }
  51. }
  52. if host := req.Header.Get(xForwardedHost); host != "" {
  53. pairs = append(pairs, strings.ToLower(xForwardedHost), host)
  54. } else if req.Host != "" {
  55. pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
  56. }
  57. if addr := req.RemoteAddr; addr != "" {
  58. if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
  59. if fwd := req.Header.Get(xForwardedFor); fwd == "" {
  60. pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP)
  61. } else {
  62. pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP))
  63. }
  64. } else {
  65. grpclog.Printf("invalid remote addr: %s", addr)
  66. }
  67. }
  68. if timeout != 0 {
  69. ctx, _ = context.WithTimeout(ctx, timeout)
  70. }
  71. if len(pairs) == 0 {
  72. return ctx, nil
  73. }
  74. return metadata.NewContext(ctx, metadata.Pairs(pairs...)), nil
  75. }
  76. // ServerMetadata consists of metadata sent from gRPC server.
  77. type ServerMetadata struct {
  78. HeaderMD metadata.MD
  79. TrailerMD metadata.MD
  80. }
  81. type serverMetadataKey struct{}
  82. // NewServerMetadataContext creates a new context with ServerMetadata
  83. func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context {
  84. return context.WithValue(ctx, serverMetadataKey{}, md)
  85. }
  86. // ServerMetadataFromContext returns the ServerMetadata in ctx
  87. func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) {
  88. md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata)
  89. return
  90. }
  // timeoutDecode parses a Grpc-Timeout header value such as "100m" or "5S".
//
// The value must be a non-negative decimal integer immediately followed by a
// single unit character (see timeoutUnitToDuration). A leading sign is
// rejected: strconv.ParseInt would otherwise accept "-1S" and produce a
// negative timeout, i.e. an already-expired context. A value whose product
// with its unit does not fit in a time.Duration (e.g. "99999999H", roughly
// 11,000 years) is clamped to the largest representable Duration instead of
// silently wrapping around.
func timeoutDecode(s string) (time.Duration, error) {
	// Largest representable Duration (~292 years); used to clamp overflow.
	const maxDuration = time.Duration(1<<63 - 1)

	size := len(s)
	if size < 2 {
		return 0, fmt.Errorf("timeout string is too short: %q", s)
	}
	unit, ok := timeoutUnitToDuration(s[size-1])
	if !ok {
		return 0, fmt.Errorf("timeout unit is not recognized: %q", s)
	}
	digits := s[:size-1]
	// Only plain digits are valid; reject '+'/'-' before ParseInt sees them.
	if c := digits[0]; c < '0' || c > '9' {
		return 0, fmt.Errorf("timeout value is not a non-negative integer: %q", s)
	}
	n, err := strconv.ParseInt(digits, 10, 64)
	if err != nil {
		return 0, err
	}
	if n > int64(maxDuration/unit) {
		return maxDuration, nil
	}
	return unit * time.Duration(n), nil
}

// timeoutUnitToDuration maps a Grpc-Timeout unit character to the Duration of
// one such unit: 'H' hours, 'M' minutes, 'S' seconds, 'm' milliseconds,
// 'u' microseconds, 'n' nanoseconds. ok is false for any other byte.
func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) {
	switch u {
	case 'H':
		return time.Hour, true
	case 'M':
		return time.Minute, true
	case 'S':
		return time.Second, true
	case 'm':
		return time.Millisecond, true
	case 'u':
		return time.Microsecond, true
	case 'n':
		return time.Nanosecond, true
	}
	return 0, false
}