propagator_test.go 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. package trace
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "testing"
  6. "github.com/stretchr/testify/assert"
  7. "google.golang.org/grpc/metadata"
  8. )
  9. func TestHttpPropagator_Extract(t *testing.T) {
  10. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  11. req.Header.Set(traceIdKey, "trace")
  12. req.Header.Set(spanIdKey, "span")
  13. carrier, err := Extract(HttpFormat, req.Header)
  14. assert.Nil(t, err)
  15. assert.Equal(t, "trace", carrier.Get(traceIdKey))
  16. assert.Equal(t, "span", carrier.Get(spanIdKey))
  17. _, err = Extract(HttpFormat, req)
  18. assert.Equal(t, ErrInvalidCarrier, err)
  19. }
  20. func TestHttpPropagator_Inject(t *testing.T) {
  21. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  22. req.Header.Set(traceIdKey, "trace")
  23. req.Header.Set(spanIdKey, "span")
  24. carrier, err := Inject(HttpFormat, req.Header)
  25. assert.Nil(t, err)
  26. assert.Equal(t, "trace", carrier.Get(traceIdKey))
  27. assert.Equal(t, "span", carrier.Get(spanIdKey))
  28. _, err = Inject(HttpFormat, req)
  29. assert.Equal(t, ErrInvalidCarrier, err)
  30. }
  31. func TestGrpcPropagator_Extract(t *testing.T) {
  32. md := metadata.New(map[string]string{
  33. traceIdKey: "trace",
  34. spanIdKey: "span",
  35. })
  36. carrier, err := Extract(GrpcFormat, md)
  37. assert.Nil(t, err)
  38. assert.Equal(t, "trace", carrier.Get(traceIdKey))
  39. assert.Equal(t, "span", carrier.Get(spanIdKey))
  40. _, err = Extract(GrpcFormat, 1)
  41. assert.Equal(t, ErrInvalidCarrier, err)
  42. _, err = Extract(nil, 1)
  43. assert.Equal(t, ErrInvalidCarrier, err)
  44. }
  45. func TestGrpcPropagator_Inject(t *testing.T) {
  46. md := metadata.New(map[string]string{
  47. traceIdKey: "trace",
  48. spanIdKey: "span",
  49. })
  50. carrier, err := Inject(GrpcFormat, md)
  51. assert.Nil(t, err)
  52. assert.Equal(t, "trace", carrier.Get(traceIdKey))
  53. assert.Equal(t, "span", carrier.Get(spanIdKey))
  54. _, err = Inject(GrpcFormat, 1)
  55. assert.Equal(t, ErrInvalidCarrier, err)
  56. _, err = Inject(nil, 1)
  57. assert.Equal(t, ErrInvalidCarrier, err)
  58. }