reverse_test.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. // Copyright 2015 The etcd Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package httpproxy
  15. import (
  16. "bytes"
  17. "errors"
  18. "io/ioutil"
  19. "net/http"
  20. "net/http/httptest"
  21. "net/url"
  22. "reflect"
  23. "testing"
  24. )
  25. type staticRoundTripper struct {
  26. res *http.Response
  27. err error
  28. }
  29. func (srt *staticRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
  30. return srt.res, srt.err
  31. }
  32. func TestReverseProxyServe(t *testing.T) {
  33. u := url.URL{Scheme: "http", Host: "192.0.2.3:4040"}
  34. tests := []struct {
  35. eps []*endpoint
  36. rt http.RoundTripper
  37. want int
  38. }{
  39. // no endpoints available so no requests are even made
  40. {
  41. eps: []*endpoint{},
  42. rt: &staticRoundTripper{
  43. res: &http.Response{
  44. StatusCode: http.StatusCreated,
  45. Body: ioutil.NopCloser(&bytes.Reader{}),
  46. },
  47. },
  48. want: http.StatusServiceUnavailable,
  49. },
  50. // error is returned from one endpoint that should be available
  51. {
  52. eps: []*endpoint{{URL: u, Available: true}},
  53. rt: &staticRoundTripper{err: errors.New("what a bad trip")},
  54. want: http.StatusBadGateway,
  55. },
  56. // endpoint is available and returns success
  57. {
  58. eps: []*endpoint{{URL: u, Available: true}},
  59. rt: &staticRoundTripper{
  60. res: &http.Response{
  61. StatusCode: http.StatusCreated,
  62. Body: ioutil.NopCloser(&bytes.Reader{}),
  63. Header: map[string][]string{"Content-Type": {"application/json"}},
  64. },
  65. },
  66. want: http.StatusCreated,
  67. },
  68. }
  69. for i, tt := range tests {
  70. rp := reverseProxy{
  71. director: &director{ep: tt.eps},
  72. transport: tt.rt,
  73. }
  74. req, _ := http.NewRequest("GET", "http://192.0.2.2:2379", nil)
  75. rr := httptest.NewRecorder()
  76. rp.ServeHTTP(rr, req)
  77. if rr.Code != tt.want {
  78. t.Errorf("#%d: unexpected HTTP status code: want = %d, got = %d", i, tt.want, rr.Code)
  79. }
  80. if gct := rr.Header().Get("Content-Type"); gct != "application/json" {
  81. t.Errorf("#%d: Content-Type = %s, want %s", i, gct, "application/json")
  82. }
  83. }
  84. }
  85. func TestRedirectRequest(t *testing.T) {
  86. loc := url.URL{
  87. Scheme: "http",
  88. Host: "bar.example.com",
  89. }
  90. req := &http.Request{
  91. Method: "GET",
  92. Host: "foo.example.com",
  93. URL: &url.URL{
  94. Host: "foo.example.com",
  95. Path: "/v2/keys/baz",
  96. },
  97. }
  98. redirectRequest(req, loc)
  99. want := &http.Request{
  100. Method: "GET",
  101. // this field must not change
  102. Host: "foo.example.com",
  103. URL: &url.URL{
  104. // the Scheme field is updated to that of the provided URL
  105. Scheme: "http",
  106. // the Host field is updated to that of the provided URL
  107. Host: "bar.example.com",
  108. Path: "/v2/keys/baz",
  109. },
  110. }
  111. if !reflect.DeepEqual(want, req) {
  112. t.Fatalf("HTTP request does not match expected criteria: want=%#v got=%#v", want, req)
  113. }
  114. }
  115. func TestMaybeSetForwardedFor(t *testing.T) {
  116. tests := []struct {
  117. raddr string
  118. fwdFor string
  119. want string
  120. }{
  121. {"192.0.2.3:8002", "", "192.0.2.3"},
  122. {"192.0.2.3:8002", "192.0.2.2", "192.0.2.2, 192.0.2.3"},
  123. {"192.0.2.3:8002", "192.0.2.1, 192.0.2.2", "192.0.2.1, 192.0.2.2, 192.0.2.3"},
  124. {"example.com:8002", "", "example.com"},
  125. // While these cases look valid, golang net/http will not let it happen
  126. // The RemoteAddr field will always be a valid host:port
  127. {":8002", "", ""},
  128. {"192.0.2.3", "", ""},
  129. // blatantly invalid host w/o a port
  130. {"12", "", ""},
  131. {"12", "192.0.2.3", "192.0.2.3"},
  132. }
  133. for i, tt := range tests {
  134. req := &http.Request{
  135. RemoteAddr: tt.raddr,
  136. Header: make(http.Header),
  137. }
  138. if tt.fwdFor != "" {
  139. req.Header.Set("X-Forwarded-For", tt.fwdFor)
  140. }
  141. maybeSetForwardedFor(req)
  142. got := req.Header.Get("X-Forwarded-For")
  143. if tt.want != got {
  144. t.Errorf("#%d: incorrect header: want = %q, got = %q", i, tt.want, got)
  145. }
  146. }
  147. }
  148. func TestRemoveSingleHopHeaders(t *testing.T) {
  149. hdr := http.Header(map[string][]string{
  150. // single-hop headers that should be removed
  151. "Connection": {"close"},
  152. "Keep-Alive": {"foo"},
  153. "Proxy-Authenticate": {"Basic realm=example.com"},
  154. "Proxy-Authorization": {"foo"},
  155. "Te": {"deflate,gzip"},
  156. "Trailers": {"ETag"},
  157. "Transfer-Encoding": {"chunked"},
  158. "Upgrade": {"WebSocket"},
  159. // headers that should persist
  160. "Accept": {"application/json"},
  161. "X-Foo": {"Bar"},
  162. })
  163. removeSingleHopHeaders(&hdr)
  164. want := http.Header(map[string][]string{
  165. "Accept": {"application/json"},
  166. "X-Foo": {"Bar"},
  167. })
  168. if !reflect.DeepEqual(want, hdr) {
  169. t.Fatalf("unexpected result: want = %#v, got = %#v", want, hdr)
  170. }
  171. }
  172. func TestCopyHeader(t *testing.T) {
  173. tests := []struct {
  174. src http.Header
  175. dst http.Header
  176. want http.Header
  177. }{
  178. {
  179. src: http.Header(map[string][]string{
  180. "Foo": {"bar", "baz"},
  181. }),
  182. dst: http.Header(map[string][]string{}),
  183. want: http.Header(map[string][]string{
  184. "Foo": {"bar", "baz"},
  185. }),
  186. },
  187. {
  188. src: http.Header(map[string][]string{
  189. "Foo": {"bar"},
  190. "Ping": {"pong"},
  191. }),
  192. dst: http.Header(map[string][]string{}),
  193. want: http.Header(map[string][]string{
  194. "Foo": {"bar"},
  195. "Ping": {"pong"},
  196. }),
  197. },
  198. {
  199. src: http.Header(map[string][]string{
  200. "Foo": {"bar", "baz"},
  201. }),
  202. dst: http.Header(map[string][]string{
  203. "Foo": {"qux"},
  204. }),
  205. want: http.Header(map[string][]string{
  206. "Foo": {"qux", "bar", "baz"},
  207. }),
  208. },
  209. }
  210. for i, tt := range tests {
  211. copyHeader(tt.dst, tt.src)
  212. if !reflect.DeepEqual(tt.dst, tt.want) {
  213. t.Errorf("#%d: unexpected headers: want = %v, got = %v", i, tt.want, tt.dst)
  214. }
  215. }
  216. }