123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245 |
- // Copyright 2015 CoreOS, Inc.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package proxy
- import (
- "bytes"
- "errors"
- "io/ioutil"
- "net/http"
- "net/http/httptest"
- "net/url"
- "reflect"
- "testing"
- )
// staticRoundTripper is an http.RoundTripper test double. It never looks at
// the outgoing request; every call hands back the same canned response and
// error, letting tests script exactly what the "backend" answers.
type staticRoundTripper struct {
	// res is returned as-is from every RoundTrip call.
	res *http.Response
	// err is returned as-is from every RoundTrip call.
	err error
}

// RoundTrip satisfies http.RoundTripper by returning the configured response
// and error pair unchanged; the request argument is ignored.
func (rt *staticRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
	res, err := rt.res, rt.err
	return res, err
}
- func TestReverseProxyServe(t *testing.T) {
- u := url.URL{Scheme: "http", Host: "192.0.2.3:4040"}
- tests := []struct {
- eps []*endpoint
- rt http.RoundTripper
- want int
- }{
- // no endpoints available so no requests are even made
- {
- eps: []*endpoint{},
- rt: &staticRoundTripper{
- res: &http.Response{
- StatusCode: http.StatusCreated,
- Body: ioutil.NopCloser(&bytes.Reader{}),
- },
- },
- want: http.StatusServiceUnavailable,
- },
- // error is returned from one endpoint that should be available
- {
- eps: []*endpoint{&endpoint{URL: u, Available: true}},
- rt: &staticRoundTripper{err: errors.New("what a bad trip")},
- want: http.StatusBadGateway,
- },
- // endpoint is available and returns success
- {
- eps: []*endpoint{&endpoint{URL: u, Available: true}},
- rt: &staticRoundTripper{
- res: &http.Response{
- StatusCode: http.StatusCreated,
- Body: ioutil.NopCloser(&bytes.Reader{}),
- Header: map[string][]string{"Content-Type": []string{"application/json"}},
- },
- },
- want: http.StatusCreated,
- },
- }
- for i, tt := range tests {
- rp := reverseProxy{
- director: &director{ep: tt.eps},
- transport: tt.rt,
- }
- req, _ := http.NewRequest("GET", "http://192.0.2.2:4001", nil)
- rr := httptest.NewRecorder()
- rp.ServeHTTP(rr, req)
- if rr.Code != tt.want {
- t.Errorf("#%d: unexpected HTTP status code: want = %d, got = %d", i, tt.want, rr.Code)
- }
- if gct := rr.Header().Get("Content-Type"); gct != "application/json" {
- t.Errorf("#%d: Content-Type = %s, want %s", i, gct, "application/json")
- }
- }
- }
- func TestRedirectRequest(t *testing.T) {
- loc := url.URL{
- Scheme: "http",
- Host: "bar.example.com",
- }
- req := &http.Request{
- Method: "GET",
- Host: "foo.example.com",
- URL: &url.URL{
- Host: "foo.example.com",
- Path: "/v2/keys/baz",
- },
- }
- redirectRequest(req, loc)
- want := &http.Request{
- Method: "GET",
- // this field must not change
- Host: "foo.example.com",
- URL: &url.URL{
- // the Scheme field is updated to that of the provided URL
- Scheme: "http",
- // the Host field is updated to that of the provided URL
- Host: "bar.example.com",
- Path: "/v2/keys/baz",
- },
- }
- if !reflect.DeepEqual(want, req) {
- t.Fatalf("HTTP request does not match expected criteria: want=%#v got=%#v", want, req)
- }
- }
- func TestMaybeSetForwardedFor(t *testing.T) {
- tests := []struct {
- raddr string
- fwdFor string
- want string
- }{
- {"192.0.2.3:8002", "", "192.0.2.3"},
- {"192.0.2.3:8002", "192.0.2.2", "192.0.2.2, 192.0.2.3"},
- {"192.0.2.3:8002", "192.0.2.1, 192.0.2.2", "192.0.2.1, 192.0.2.2, 192.0.2.3"},
- {"example.com:8002", "", "example.com"},
- // While these cases look valid, golang net/http will not let it happen
- // The RemoteAddr field will always be a valid host:port
- {":8002", "", ""},
- {"192.0.2.3", "", ""},
- // blatantly invalid host w/o a port
- {"12", "", ""},
- {"12", "192.0.2.3", "192.0.2.3"},
- }
- for i, tt := range tests {
- req := &http.Request{
- RemoteAddr: tt.raddr,
- Header: make(http.Header),
- }
- if tt.fwdFor != "" {
- req.Header.Set("X-Forwarded-For", tt.fwdFor)
- }
- maybeSetForwardedFor(req)
- got := req.Header.Get("X-Forwarded-For")
- if tt.want != got {
- t.Errorf("#%d: incorrect header: want = %q, got = %q", i, tt.want, got)
- }
- }
- }
- func TestRemoveSingleHopHeaders(t *testing.T) {
- hdr := http.Header(map[string][]string{
- // single-hop headers that should be removed
- "Connection": []string{"close"},
- "Keep-Alive": []string{"foo"},
- "Proxy-Authenticate": []string{"Basic realm=example.com"},
- "Proxy-Authorization": []string{"foo"},
- "Te": []string{"deflate,gzip"},
- "Trailers": []string{"ETag"},
- "Transfer-Encoding": []string{"chunked"},
- "Upgrade": []string{"WebSocket"},
- // headers that should persist
- "Accept": []string{"application/json"},
- "X-Foo": []string{"Bar"},
- })
- removeSingleHopHeaders(&hdr)
- want := http.Header(map[string][]string{
- "Accept": []string{"application/json"},
- "X-Foo": []string{"Bar"},
- })
- if !reflect.DeepEqual(want, hdr) {
- t.Fatalf("unexpected result: want = %#v, got = %#v", want, hdr)
- }
- }
- func TestCopyHeader(t *testing.T) {
- tests := []struct {
- src http.Header
- dst http.Header
- want http.Header
- }{
- {
- src: http.Header(map[string][]string{
- "Foo": []string{"bar", "baz"},
- }),
- dst: http.Header(map[string][]string{}),
- want: http.Header(map[string][]string{
- "Foo": []string{"bar", "baz"},
- }),
- },
- {
- src: http.Header(map[string][]string{
- "Foo": []string{"bar"},
- "Ping": []string{"pong"},
- }),
- dst: http.Header(map[string][]string{}),
- want: http.Header(map[string][]string{
- "Foo": []string{"bar"},
- "Ping": []string{"pong"},
- }),
- },
- {
- src: http.Header(map[string][]string{
- "Foo": []string{"bar", "baz"},
- }),
- dst: http.Header(map[string][]string{
- "Foo": []string{"qux"},
- }),
- want: http.Header(map[string][]string{
- "Foo": []string{"qux", "bar", "baz"},
- }),
- },
- }
- for i, tt := range tests {
- copyHeader(tt.dst, tt.src)
- if !reflect.DeepEqual(tt.dst, tt.want) {
- t.Errorf("#%d: unexpected headers: want = %v, got = %v", i, tt.want, tt.dst)
- }
- }
- }
|