reverse_test.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. package proxy
  2. import (
  3. "bytes"
  4. "errors"
  5. "io/ioutil"
  6. "net/http"
  7. "net/http/httptest"
  8. "net/url"
  9. "reflect"
  10. "testing"
  11. )
  12. type staticRoundTripper struct {
  13. res *http.Response
  14. err error
  15. }
  16. func (srt *staticRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
  17. return srt.res, srt.err
  18. }
  19. func TestReverseProxyServe(t *testing.T) {
  20. u := url.URL{Scheme: "http", Host: "192.0.2.3:4040"}
  21. tests := []struct {
  22. eps []*endpoint
  23. rt http.RoundTripper
  24. want int
  25. }{
  26. // no endpoints available so no requests are even made
  27. {
  28. eps: []*endpoint{},
  29. rt: &staticRoundTripper{
  30. res: &http.Response{
  31. StatusCode: http.StatusCreated,
  32. Body: ioutil.NopCloser(&bytes.Reader{}),
  33. },
  34. },
  35. want: http.StatusServiceUnavailable,
  36. },
  37. // error is returned from one endpoint that should be available
  38. {
  39. eps: []*endpoint{&endpoint{URL: u, Available: true}},
  40. rt: &staticRoundTripper{err: errors.New("what a bad trip")},
  41. want: http.StatusBadGateway,
  42. },
  43. // endpoint is available and returns success
  44. {
  45. eps: []*endpoint{&endpoint{URL: u, Available: true}},
  46. rt: &staticRoundTripper{
  47. res: &http.Response{
  48. StatusCode: http.StatusCreated,
  49. Body: ioutil.NopCloser(&bytes.Reader{}),
  50. },
  51. },
  52. want: http.StatusCreated,
  53. },
  54. }
  55. for i, tt := range tests {
  56. rp := reverseProxy{
  57. director: &director{tt.eps},
  58. transport: tt.rt,
  59. }
  60. req, _ := http.NewRequest("GET", "http://192.0.2.2:4001", nil)
  61. rr := httptest.NewRecorder()
  62. rp.ServeHTTP(rr, req)
  63. if rr.Code != tt.want {
  64. t.Errorf("#%d: unexpected HTTP status code: want = %d, got = %d", i, tt.want, rr.Code)
  65. }
  66. }
  67. }
  68. func TestRedirectRequest(t *testing.T) {
  69. loc := url.URL{
  70. Scheme: "http",
  71. Host: "bar.example.com",
  72. }
  73. req := &http.Request{
  74. Method: "GET",
  75. Host: "foo.example.com",
  76. URL: &url.URL{
  77. Host: "foo.example.com",
  78. Path: "/v2/keys/baz",
  79. },
  80. }
  81. redirectRequest(req, loc)
  82. want := &http.Request{
  83. Method: "GET",
  84. // this field must not change
  85. Host: "foo.example.com",
  86. URL: &url.URL{
  87. // the Scheme field is updated to that of the provided URL
  88. Scheme: "http",
  89. // the Host field is updated to that of the provided URL
  90. Host: "bar.example.com",
  91. Path: "/v2/keys/baz",
  92. },
  93. }
  94. if !reflect.DeepEqual(want, req) {
  95. t.Fatalf("HTTP request does not match expected criteria: want=%#v got=%#v", want, req)
  96. }
  97. }
  98. func TestMaybeSetForwardedFor(t *testing.T) {
  99. tests := []struct {
  100. raddr string
  101. fwdFor string
  102. want string
  103. }{
  104. {"192.0.2.3:8002", "", "192.0.2.3"},
  105. {"192.0.2.3:8002", "192.0.2.2", "192.0.2.2, 192.0.2.3"},
  106. {"192.0.2.3:8002", "192.0.2.1, 192.0.2.2", "192.0.2.1, 192.0.2.2, 192.0.2.3"},
  107. {"example.com:8002", "", "example.com"},
  108. // While these cases look valid, golang net/http will not let it happen
  109. // The RemoteAddr field will always be a valid host:port
  110. {":8002", "", ""},
  111. {"192.0.2.3", "", ""},
  112. // blatantly invalid host w/o a port
  113. {"12", "", ""},
  114. {"12", "192.0.2.3", "192.0.2.3"},
  115. }
  116. for i, tt := range tests {
  117. req := &http.Request{
  118. RemoteAddr: tt.raddr,
  119. Header: make(http.Header),
  120. }
  121. if tt.fwdFor != "" {
  122. req.Header.Set("X-Forwarded-For", tt.fwdFor)
  123. }
  124. maybeSetForwardedFor(req)
  125. got := req.Header.Get("X-Forwarded-For")
  126. if tt.want != got {
  127. t.Errorf("#%d: incorrect header: want = %q, got = %q", i, tt.want, got)
  128. }
  129. }
  130. }
  131. func TestRemoveSingleHopHeaders(t *testing.T) {
  132. hdr := http.Header(map[string][]string{
  133. // single-hop headers that should be removed
  134. "Connection": []string{"close"},
  135. "Keep-Alive": []string{"foo"},
  136. "Proxy-Authenticate": []string{"Basic realm=example.com"},
  137. "Proxy-Authorization": []string{"foo"},
  138. "Te": []string{"deflate,gzip"},
  139. "Trailers": []string{"ETag"},
  140. "Transfer-Encoding": []string{"chunked"},
  141. "Upgrade": []string{"WebSocket"},
  142. // headers that should persist
  143. "Accept": []string{"application/json"},
  144. "X-Foo": []string{"Bar"},
  145. })
  146. removeSingleHopHeaders(&hdr)
  147. want := http.Header(map[string][]string{
  148. "Accept": []string{"application/json"},
  149. "X-Foo": []string{"Bar"},
  150. })
  151. if !reflect.DeepEqual(want, hdr) {
  152. t.Fatalf("unexpected result: want = %#v, got = %#v", want, hdr)
  153. }
  154. }
  155. func TestCopyHeader(t *testing.T) {
  156. tests := []struct {
  157. src http.Header
  158. dst http.Header
  159. want http.Header
  160. }{
  161. {
  162. src: http.Header(map[string][]string{
  163. "Foo": []string{"bar", "baz"},
  164. }),
  165. dst: http.Header(map[string][]string{}),
  166. want: http.Header(map[string][]string{
  167. "Foo": []string{"bar", "baz"},
  168. }),
  169. },
  170. {
  171. src: http.Header(map[string][]string{
  172. "Foo": []string{"bar"},
  173. "Ping": []string{"pong"},
  174. }),
  175. dst: http.Header(map[string][]string{}),
  176. want: http.Header(map[string][]string{
  177. "Foo": []string{"bar"},
  178. "Ping": []string{"pong"},
  179. }),
  180. },
  181. {
  182. src: http.Header(map[string][]string{
  183. "Foo": []string{"bar", "baz"},
  184. }),
  185. dst: http.Header(map[string][]string{
  186. "Foo": []string{"qux"},
  187. }),
  188. want: http.Header(map[string][]string{
  189. "Foo": []string{"qux", "bar", "baz"},
  190. }),
  191. },
  192. }
  193. for i, tt := range tests {
  194. copyHeader(tt.dst, tt.src)
  195. if !reflect.DeepEqual(tt.dst, tt.want) {
  196. t.Errorf("#%d: unexpected headers: want = %v, got = %v", i, tt.want, tt.dst)
  197. }
  198. }
  199. }