reverse_test.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. /*
  2. Copyright 2014 CoreOS, Inc.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package proxy
  14. import (
  15. "bytes"
  16. "errors"
  17. "io/ioutil"
  18. "net/http"
  19. "net/http/httptest"
  20. "net/url"
  21. "reflect"
  22. "testing"
  23. )
  // staticRoundTripper is a canned http.RoundTripper for tests: it never
  // inspects the request and answers every call with the same res/err pair.
  type staticRoundTripper struct {
  	res *http.Response // response handed back by every RoundTrip call
  	err error          // error handed back by every RoundTrip call
  }

  // Compile-time check that *staticRoundTripper satisfies http.RoundTripper.
  var _ http.RoundTripper = (*staticRoundTripper)(nil)

  // RoundTrip ignores its argument and returns the stored response and error
  // exactly as they were configured.
  func (s *staticRoundTripper) RoundTrip(_ *http.Request) (res *http.Response, err error) {
  	res, err = s.res, s.err
  	return res, err
  }
  31. func TestReverseProxyServe(t *testing.T) {
  32. u := url.URL{Scheme: "http", Host: "192.0.2.3:4040"}
  33. tests := []struct {
  34. eps []*endpoint
  35. rt http.RoundTripper
  36. want int
  37. }{
  38. // no endpoints available so no requests are even made
  39. {
  40. eps: []*endpoint{},
  41. rt: &staticRoundTripper{
  42. res: &http.Response{
  43. StatusCode: http.StatusCreated,
  44. Body: ioutil.NopCloser(&bytes.Reader{}),
  45. },
  46. },
  47. want: http.StatusServiceUnavailable,
  48. },
  49. // error is returned from one endpoint that should be available
  50. {
  51. eps: []*endpoint{&endpoint{URL: u, Available: true}},
  52. rt: &staticRoundTripper{err: errors.New("what a bad trip")},
  53. want: http.StatusBadGateway,
  54. },
  55. // endpoint is available and returns success
  56. {
  57. eps: []*endpoint{&endpoint{URL: u, Available: true}},
  58. rt: &staticRoundTripper{
  59. res: &http.Response{
  60. StatusCode: http.StatusCreated,
  61. Body: ioutil.NopCloser(&bytes.Reader{}),
  62. },
  63. },
  64. want: http.StatusCreated,
  65. },
  66. }
  67. for i, tt := range tests {
  68. rp := reverseProxy{
  69. director: &director{ep: tt.eps},
  70. transport: tt.rt,
  71. }
  72. req, _ := http.NewRequest("GET", "http://192.0.2.2:4001", nil)
  73. rr := httptest.NewRecorder()
  74. rp.ServeHTTP(rr, req)
  75. if rr.Code != tt.want {
  76. t.Errorf("#%d: unexpected HTTP status code: want = %d, got = %d", i, tt.want, rr.Code)
  77. }
  78. }
  79. }
  80. func TestRedirectRequest(t *testing.T) {
  81. loc := url.URL{
  82. Scheme: "http",
  83. Host: "bar.example.com",
  84. }
  85. req := &http.Request{
  86. Method: "GET",
  87. Host: "foo.example.com",
  88. URL: &url.URL{
  89. Host: "foo.example.com",
  90. Path: "/v2/keys/baz",
  91. },
  92. }
  93. redirectRequest(req, loc)
  94. want := &http.Request{
  95. Method: "GET",
  96. // this field must not change
  97. Host: "foo.example.com",
  98. URL: &url.URL{
  99. // the Scheme field is updated to that of the provided URL
  100. Scheme: "http",
  101. // the Host field is updated to that of the provided URL
  102. Host: "bar.example.com",
  103. Path: "/v2/keys/baz",
  104. },
  105. }
  106. if !reflect.DeepEqual(want, req) {
  107. t.Fatalf("HTTP request does not match expected criteria: want=%#v got=%#v", want, req)
  108. }
  109. }
  110. func TestMaybeSetForwardedFor(t *testing.T) {
  111. tests := []struct {
  112. raddr string
  113. fwdFor string
  114. want string
  115. }{
  116. {"192.0.2.3:8002", "", "192.0.2.3"},
  117. {"192.0.2.3:8002", "192.0.2.2", "192.0.2.2, 192.0.2.3"},
  118. {"192.0.2.3:8002", "192.0.2.1, 192.0.2.2", "192.0.2.1, 192.0.2.2, 192.0.2.3"},
  119. {"example.com:8002", "", "example.com"},
  120. // While these cases look valid, golang net/http will not let it happen
  121. // The RemoteAddr field will always be a valid host:port
  122. {":8002", "", ""},
  123. {"192.0.2.3", "", ""},
  124. // blatantly invalid host w/o a port
  125. {"12", "", ""},
  126. {"12", "192.0.2.3", "192.0.2.3"},
  127. }
  128. for i, tt := range tests {
  129. req := &http.Request{
  130. RemoteAddr: tt.raddr,
  131. Header: make(http.Header),
  132. }
  133. if tt.fwdFor != "" {
  134. req.Header.Set("X-Forwarded-For", tt.fwdFor)
  135. }
  136. maybeSetForwardedFor(req)
  137. got := req.Header.Get("X-Forwarded-For")
  138. if tt.want != got {
  139. t.Errorf("#%d: incorrect header: want = %q, got = %q", i, tt.want, got)
  140. }
  141. }
  142. }
  143. func TestRemoveSingleHopHeaders(t *testing.T) {
  144. hdr := http.Header(map[string][]string{
  145. // single-hop headers that should be removed
  146. "Connection": []string{"close"},
  147. "Keep-Alive": []string{"foo"},
  148. "Proxy-Authenticate": []string{"Basic realm=example.com"},
  149. "Proxy-Authorization": []string{"foo"},
  150. "Te": []string{"deflate,gzip"},
  151. "Trailers": []string{"ETag"},
  152. "Transfer-Encoding": []string{"chunked"},
  153. "Upgrade": []string{"WebSocket"},
  154. // headers that should persist
  155. "Accept": []string{"application/json"},
  156. "X-Foo": []string{"Bar"},
  157. })
  158. removeSingleHopHeaders(&hdr)
  159. want := http.Header(map[string][]string{
  160. "Accept": []string{"application/json"},
  161. "X-Foo": []string{"Bar"},
  162. })
  163. if !reflect.DeepEqual(want, hdr) {
  164. t.Fatalf("unexpected result: want = %#v, got = %#v", want, hdr)
  165. }
  166. }
  167. func TestCopyHeader(t *testing.T) {
  168. tests := []struct {
  169. src http.Header
  170. dst http.Header
  171. want http.Header
  172. }{
  173. {
  174. src: http.Header(map[string][]string{
  175. "Foo": []string{"bar", "baz"},
  176. }),
  177. dst: http.Header(map[string][]string{}),
  178. want: http.Header(map[string][]string{
  179. "Foo": []string{"bar", "baz"},
  180. }),
  181. },
  182. {
  183. src: http.Header(map[string][]string{
  184. "Foo": []string{"bar"},
  185. "Ping": []string{"pong"},
  186. }),
  187. dst: http.Header(map[string][]string{}),
  188. want: http.Header(map[string][]string{
  189. "Foo": []string{"bar"},
  190. "Ping": []string{"pong"},
  191. }),
  192. },
  193. {
  194. src: http.Header(map[string][]string{
  195. "Foo": []string{"bar", "baz"},
  196. }),
  197. dst: http.Header(map[string][]string{
  198. "Foo": []string{"qux"},
  199. }),
  200. want: http.Header(map[string][]string{
  201. "Foo": []string{"qux", "bar", "baz"},
  202. }),
  203. },
  204. }
  205. for i, tt := range tests {
  206. copyHeader(tt.dst, tt.src)
  207. if !reflect.DeepEqual(tt.dst, tt.want) {
  208. t.Errorf("#%d: unexpected headers: want = %v, got = %v", i, tt.want, tt.dst)
  209. }
  210. }
  211. }