reverse_test.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. /*
  2. Copyright 2014 CoreOS, Inc.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package proxy
  14. import (
  15. "bytes"
  16. "errors"
  17. "io/ioutil"
  18. "net/http"
  19. "net/http/httptest"
  20. "net/url"
  21. "reflect"
  22. "testing"
  23. )
  24. type staticRoundTripper struct {
  25. res *http.Response
  26. err error
  27. }
  28. func (srt *staticRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
  29. return srt.res, srt.err
  30. }
  31. func TestReverseProxyServe(t *testing.T) {
  32. u := url.URL{Scheme: "http", Host: "192.0.2.3:4040"}
  33. tests := []struct {
  34. eps []*endpoint
  35. rt http.RoundTripper
  36. want int
  37. }{
  38. // no endpoints available so no requests are even made
  39. {
  40. eps: []*endpoint{},
  41. rt: &staticRoundTripper{
  42. res: &http.Response{
  43. StatusCode: http.StatusCreated,
  44. Body: ioutil.NopCloser(&bytes.Reader{}),
  45. },
  46. },
  47. want: http.StatusServiceUnavailable,
  48. },
  49. // error is returned from one endpoint that should be available
  50. {
  51. eps: []*endpoint{&endpoint{URL: u, Available: true}},
  52. rt: &staticRoundTripper{err: errors.New("what a bad trip")},
  53. want: http.StatusBadGateway,
  54. },
  55. // endpoint is available and returns success
  56. {
  57. eps: []*endpoint{&endpoint{URL: u, Available: true}},
  58. rt: &staticRoundTripper{
  59. res: &http.Response{
  60. StatusCode: http.StatusCreated,
  61. Body: ioutil.NopCloser(&bytes.Reader{}),
  62. Header: map[string][]string{"Content-Type": []string{"application/json"}},
  63. },
  64. },
  65. want: http.StatusCreated,
  66. },
  67. }
  68. for i, tt := range tests {
  69. rp := reverseProxy{
  70. director: &director{ep: tt.eps},
  71. transport: tt.rt,
  72. }
  73. req, _ := http.NewRequest("GET", "http://192.0.2.2:4001", nil)
  74. rr := httptest.NewRecorder()
  75. rp.ServeHTTP(rr, req)
  76. if rr.Code != tt.want {
  77. t.Errorf("#%d: unexpected HTTP status code: want = %d, got = %d", i, tt.want, rr.Code)
  78. }
  79. if gct := rr.Header().Get("Content-Type"); gct != "application/json" {
  80. t.Errorf("#%d: Content-Type = %s, want %s", i, gct, "application/json")
  81. }
  82. }
  83. }
  84. func TestRedirectRequest(t *testing.T) {
  85. loc := url.URL{
  86. Scheme: "http",
  87. Host: "bar.example.com",
  88. }
  89. req := &http.Request{
  90. Method: "GET",
  91. Host: "foo.example.com",
  92. URL: &url.URL{
  93. Host: "foo.example.com",
  94. Path: "/v2/keys/baz",
  95. },
  96. }
  97. redirectRequest(req, loc)
  98. want := &http.Request{
  99. Method: "GET",
  100. // this field must not change
  101. Host: "foo.example.com",
  102. URL: &url.URL{
  103. // the Scheme field is updated to that of the provided URL
  104. Scheme: "http",
  105. // the Host field is updated to that of the provided URL
  106. Host: "bar.example.com",
  107. Path: "/v2/keys/baz",
  108. },
  109. }
  110. if !reflect.DeepEqual(want, req) {
  111. t.Fatalf("HTTP request does not match expected criteria: want=%#v got=%#v", want, req)
  112. }
  113. }
  114. func TestMaybeSetForwardedFor(t *testing.T) {
  115. tests := []struct {
  116. raddr string
  117. fwdFor string
  118. want string
  119. }{
  120. {"192.0.2.3:8002", "", "192.0.2.3"},
  121. {"192.0.2.3:8002", "192.0.2.2", "192.0.2.2, 192.0.2.3"},
  122. {"192.0.2.3:8002", "192.0.2.1, 192.0.2.2", "192.0.2.1, 192.0.2.2, 192.0.2.3"},
  123. {"example.com:8002", "", "example.com"},
  124. // While these cases look valid, golang net/http will not let it happen
  125. // The RemoteAddr field will always be a valid host:port
  126. {":8002", "", ""},
  127. {"192.0.2.3", "", ""},
  128. // blatantly invalid host w/o a port
  129. {"12", "", ""},
  130. {"12", "192.0.2.3", "192.0.2.3"},
  131. }
  132. for i, tt := range tests {
  133. req := &http.Request{
  134. RemoteAddr: tt.raddr,
  135. Header: make(http.Header),
  136. }
  137. if tt.fwdFor != "" {
  138. req.Header.Set("X-Forwarded-For", tt.fwdFor)
  139. }
  140. maybeSetForwardedFor(req)
  141. got := req.Header.Get("X-Forwarded-For")
  142. if tt.want != got {
  143. t.Errorf("#%d: incorrect header: want = %q, got = %q", i, tt.want, got)
  144. }
  145. }
  146. }
  147. func TestRemoveSingleHopHeaders(t *testing.T) {
  148. hdr := http.Header(map[string][]string{
  149. // single-hop headers that should be removed
  150. "Connection": []string{"close"},
  151. "Keep-Alive": []string{"foo"},
  152. "Proxy-Authenticate": []string{"Basic realm=example.com"},
  153. "Proxy-Authorization": []string{"foo"},
  154. "Te": []string{"deflate,gzip"},
  155. "Trailers": []string{"ETag"},
  156. "Transfer-Encoding": []string{"chunked"},
  157. "Upgrade": []string{"WebSocket"},
  158. // headers that should persist
  159. "Accept": []string{"application/json"},
  160. "X-Foo": []string{"Bar"},
  161. })
  162. removeSingleHopHeaders(&hdr)
  163. want := http.Header(map[string][]string{
  164. "Accept": []string{"application/json"},
  165. "X-Foo": []string{"Bar"},
  166. })
  167. if !reflect.DeepEqual(want, hdr) {
  168. t.Fatalf("unexpected result: want = %#v, got = %#v", want, hdr)
  169. }
  170. }
  171. func TestCopyHeader(t *testing.T) {
  172. tests := []struct {
  173. src http.Header
  174. dst http.Header
  175. want http.Header
  176. }{
  177. {
  178. src: http.Header(map[string][]string{
  179. "Foo": []string{"bar", "baz"},
  180. }),
  181. dst: http.Header(map[string][]string{}),
  182. want: http.Header(map[string][]string{
  183. "Foo": []string{"bar", "baz"},
  184. }),
  185. },
  186. {
  187. src: http.Header(map[string][]string{
  188. "Foo": []string{"bar"},
  189. "Ping": []string{"pong"},
  190. }),
  191. dst: http.Header(map[string][]string{}),
  192. want: http.Header(map[string][]string{
  193. "Foo": []string{"bar"},
  194. "Ping": []string{"pong"},
  195. }),
  196. },
  197. {
  198. src: http.Header(map[string][]string{
  199. "Foo": []string{"bar", "baz"},
  200. }),
  201. dst: http.Header(map[string][]string{
  202. "Foo": []string{"qux"},
  203. }),
  204. want: http.Header(map[string][]string{
  205. "Foo": []string{"qux", "bar", "baz"},
  206. }),
  207. },
  208. }
  209. for i, tt := range tests {
  210. copyHeader(tt.dst, tt.src)
  211. if !reflect.DeepEqual(tt.dst, tt.want) {
  212. t.Errorf("#%d: unexpected headers: want = %v, got = %v", i, tt.want, tt.dst)
  213. }
  214. }
  215. }