client_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. // Copyright 2015 CoreOS, Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package client
  15. import (
  16. "errors"
  17. "io/ioutil"
  18. "net/http"
  19. "net/url"
  20. "reflect"
  21. "strings"
  22. "testing"
  23. "time"
  24. "github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
  25. )
  26. type actionAssertingHTTPClient struct {
  27. t *testing.T
  28. num int
  29. act httpAction
  30. resp http.Response
  31. body []byte
  32. err error
  33. }
  34. func (a *actionAssertingHTTPClient) Do(_ context.Context, act httpAction) (*http.Response, []byte, error) {
  35. if !reflect.DeepEqual(a.act, act) {
  36. a.t.Errorf("#%d: unexpected httpAction: want=%#v got=%#v", a.num, a.act, act)
  37. }
  38. return &a.resp, a.body, a.err
  39. }
  40. type staticHTTPClient struct {
  41. resp http.Response
  42. body []byte
  43. err error
  44. }
  45. func (s *staticHTTPClient) Do(context.Context, httpAction) (*http.Response, []byte, error) {
  46. return &s.resp, s.body, s.err
  47. }
  48. type staticHTTPAction struct {
  49. request http.Request
  50. }
  51. func (s *staticHTTPAction) HTTPRequest(url.URL) *http.Request {
  52. return &s.request
  53. }
  54. type staticHTTPResponse struct {
  55. resp http.Response
  56. err error
  57. }
  58. type multiStaticHTTPClient struct {
  59. responses []staticHTTPResponse
  60. cur int
  61. }
  62. func (s *multiStaticHTTPClient) Do(context.Context, httpAction) (*http.Response, []byte, error) {
  63. r := s.responses[s.cur]
  64. s.cur++
  65. return &r.resp, nil, r.err
  66. }
  67. func newStaticHTTPClientFactory(responses []staticHTTPResponse) httpClientFactory {
  68. var cur int
  69. return func(url.URL) httpClient {
  70. r := responses[cur]
  71. cur++
  72. return &staticHTTPClient{resp: r.resp, err: r.err}
  73. }
  74. }
  75. type fakeTransport struct {
  76. respchan chan *http.Response
  77. errchan chan error
  78. startCancel chan struct{}
  79. finishCancel chan struct{}
  80. }
  81. func newFakeTransport() *fakeTransport {
  82. return &fakeTransport{
  83. respchan: make(chan *http.Response, 1),
  84. errchan: make(chan error, 1),
  85. startCancel: make(chan struct{}, 1),
  86. finishCancel: make(chan struct{}, 1),
  87. }
  88. }
  89. func (t *fakeTransport) RoundTrip(*http.Request) (*http.Response, error) {
  90. select {
  91. case resp := <-t.respchan:
  92. return resp, nil
  93. case err := <-t.errchan:
  94. return nil, err
  95. case <-t.startCancel:
  96. // wait on finishCancel to simulate taking some amount of
  97. // time while calling CancelRequest
  98. <-t.finishCancel
  99. return nil, errors.New("cancelled")
  100. }
  101. }
  102. func (t *fakeTransport) CancelRequest(*http.Request) {
  103. t.startCancel <- struct{}{}
  104. }
  105. type fakeAction struct{}
  106. func (a *fakeAction) HTTPRequest(url.URL) *http.Request {
  107. return &http.Request{}
  108. }
  109. func TestSimpleHTTPClientDoSuccess(t *testing.T) {
  110. tr := newFakeTransport()
  111. c := &simpleHTTPClient{transport: tr}
  112. tr.respchan <- &http.Response{
  113. StatusCode: http.StatusTeapot,
  114. Body: ioutil.NopCloser(strings.NewReader("foo")),
  115. }
  116. resp, body, err := c.Do(context.Background(), &fakeAction{})
  117. if err != nil {
  118. t.Fatalf("incorrect error value: want=nil got=%v", err)
  119. }
  120. wantCode := http.StatusTeapot
  121. if wantCode != resp.StatusCode {
  122. t.Fatalf("invalid response code: want=%d got=%d", wantCode, resp.StatusCode)
  123. }
  124. wantBody := []byte("foo")
  125. if !reflect.DeepEqual(wantBody, body) {
  126. t.Fatalf("invalid response body: want=%q got=%q", wantBody, body)
  127. }
  128. }
  129. func TestSimpleHTTPClientDoError(t *testing.T) {
  130. tr := newFakeTransport()
  131. c := &simpleHTTPClient{transport: tr}
  132. tr.errchan <- errors.New("fixture")
  133. _, _, err := c.Do(context.Background(), &fakeAction{})
  134. if err == nil {
  135. t.Fatalf("expected non-nil error, got nil")
  136. }
  137. }
  138. func TestSimpleHTTPClientDoCancelContext(t *testing.T) {
  139. tr := newFakeTransport()
  140. c := &simpleHTTPClient{transport: tr}
  141. tr.startCancel <- struct{}{}
  142. tr.finishCancel <- struct{}{}
  143. _, _, err := c.Do(context.Background(), &fakeAction{})
  144. if err == nil {
  145. t.Fatalf("expected non-nil error, got nil")
  146. }
  147. }
  148. func TestSimpleHTTPClientDoCancelContextWaitForRoundTrip(t *testing.T) {
  149. tr := newFakeTransport()
  150. c := &simpleHTTPClient{transport: tr}
  151. donechan := make(chan struct{})
  152. ctx, cancel := context.WithCancel(context.Background())
  153. go func() {
  154. c.Do(ctx, &fakeAction{})
  155. close(donechan)
  156. }()
  157. // This should call CancelRequest and begin the cancellation process
  158. cancel()
  159. select {
  160. case <-donechan:
  161. t.Fatalf("simpleHTTPClient.Do should not have exited yet")
  162. default:
  163. }
  164. tr.finishCancel <- struct{}{}
  165. select {
  166. case <-donechan:
  167. //expected behavior
  168. return
  169. case <-time.After(time.Second):
  170. t.Fatalf("simpleHTTPClient.Do did not exit within 1s")
  171. }
  172. }
  173. func TestHTTPClusterClientDo(t *testing.T) {
  174. fakeErr := errors.New("fake!")
  175. fakeURL := url.URL{}
  176. tests := []struct {
  177. client *httpClusterClient
  178. wantCode int
  179. wantErr error
  180. }{
  181. // first good response short-circuits Do
  182. {
  183. client: &httpClusterClient{
  184. endpoints: []url.URL{fakeURL, fakeURL},
  185. clientFactory: newStaticHTTPClientFactory(
  186. []staticHTTPResponse{
  187. staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
  188. staticHTTPResponse{err: fakeErr},
  189. },
  190. ),
  191. },
  192. wantCode: http.StatusTeapot,
  193. },
  194. // fall through to good endpoint if err is arbitrary
  195. {
  196. client: &httpClusterClient{
  197. endpoints: []url.URL{fakeURL, fakeURL},
  198. clientFactory: newStaticHTTPClientFactory(
  199. []staticHTTPResponse{
  200. staticHTTPResponse{err: fakeErr},
  201. staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
  202. },
  203. ),
  204. },
  205. wantCode: http.StatusTeapot,
  206. },
  207. // context.DeadlineExceeded short-circuits Do
  208. {
  209. client: &httpClusterClient{
  210. endpoints: []url.URL{fakeURL, fakeURL},
  211. clientFactory: newStaticHTTPClientFactory(
  212. []staticHTTPResponse{
  213. staticHTTPResponse{err: context.DeadlineExceeded},
  214. staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
  215. },
  216. ),
  217. },
  218. wantErr: context.DeadlineExceeded,
  219. },
  220. // context.Canceled short-circuits Do
  221. {
  222. client: &httpClusterClient{
  223. endpoints: []url.URL{fakeURL, fakeURL},
  224. clientFactory: newStaticHTTPClientFactory(
  225. []staticHTTPResponse{
  226. staticHTTPResponse{err: context.Canceled},
  227. staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
  228. },
  229. ),
  230. },
  231. wantErr: context.Canceled,
  232. },
  233. // return err if there are no endpoints
  234. {
  235. client: &httpClusterClient{
  236. endpoints: []url.URL{},
  237. clientFactory: newHTTPClientFactory(nil, nil),
  238. },
  239. wantErr: ErrNoEndpoints,
  240. },
  241. // return err if all endpoints return arbitrary errors
  242. {
  243. client: &httpClusterClient{
  244. endpoints: []url.URL{fakeURL, fakeURL},
  245. clientFactory: newStaticHTTPClientFactory(
  246. []staticHTTPResponse{
  247. staticHTTPResponse{err: fakeErr},
  248. staticHTTPResponse{err: fakeErr},
  249. },
  250. ),
  251. },
  252. wantErr: fakeErr,
  253. },
  254. // 500-level errors cause Do to fallthrough to next endpoint
  255. {
  256. client: &httpClusterClient{
  257. endpoints: []url.URL{fakeURL, fakeURL},
  258. clientFactory: newStaticHTTPClientFactory(
  259. []staticHTTPResponse{
  260. staticHTTPResponse{resp: http.Response{StatusCode: http.StatusBadGateway}},
  261. staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
  262. },
  263. ),
  264. },
  265. wantCode: http.StatusTeapot,
  266. },
  267. }
  268. for i, tt := range tests {
  269. resp, _, err := tt.client.Do(context.Background(), nil)
  270. if !reflect.DeepEqual(tt.wantErr, err) {
  271. t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)
  272. continue
  273. }
  274. if resp == nil {
  275. if tt.wantCode != 0 {
  276. t.Errorf("#%d: resp is nil, want=%d", i, tt.wantCode)
  277. }
  278. continue
  279. }
  280. if resp.StatusCode != tt.wantCode {
  281. t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode)
  282. continue
  283. }
  284. }
  285. }
  286. func TestRedirectedHTTPAction(t *testing.T) {
  287. act := &redirectedHTTPAction{
  288. action: &staticHTTPAction{
  289. request: http.Request{
  290. Method: "DELETE",
  291. URL: &url.URL{
  292. Scheme: "https",
  293. Host: "foo.example.com",
  294. Path: "/ping",
  295. },
  296. },
  297. },
  298. location: url.URL{
  299. Scheme: "https",
  300. Host: "bar.example.com",
  301. Path: "/pong",
  302. },
  303. }
  304. want := &http.Request{
  305. Method: "DELETE",
  306. URL: &url.URL{
  307. Scheme: "https",
  308. Host: "bar.example.com",
  309. Path: "/pong",
  310. },
  311. }
  312. got := act.HTTPRequest(url.URL{Scheme: "http", Host: "baz.example.com", Path: "/pang"})
  313. if !reflect.DeepEqual(want, got) {
  314. t.Fatalf("HTTPRequest is %#v, want %#v", want, got)
  315. }
  316. }
  317. func TestRedirectFollowingHTTPClient(t *testing.T) {
  318. tests := []struct {
  319. checkRedirect CheckRedirectFunc
  320. client httpClient
  321. wantCode int
  322. wantErr error
  323. }{
  324. // errors bubbled up
  325. {
  326. checkRedirect: func(int) error { return ErrTooManyRedirects },
  327. client: &multiStaticHTTPClient{
  328. responses: []staticHTTPResponse{
  329. staticHTTPResponse{
  330. err: errors.New("fail!"),
  331. },
  332. },
  333. },
  334. wantErr: errors.New("fail!"),
  335. },
  336. // no need to follow redirect if none given
  337. {
  338. checkRedirect: func(int) error { return ErrTooManyRedirects },
  339. client: &multiStaticHTTPClient{
  340. responses: []staticHTTPResponse{
  341. staticHTTPResponse{
  342. resp: http.Response{
  343. StatusCode: http.StatusTeapot,
  344. },
  345. },
  346. },
  347. },
  348. wantCode: http.StatusTeapot,
  349. },
  350. // redirects if less than max
  351. {
  352. checkRedirect: func(via int) error {
  353. if via >= 2 {
  354. return ErrTooManyRedirects
  355. }
  356. return nil
  357. },
  358. client: &multiStaticHTTPClient{
  359. responses: []staticHTTPResponse{
  360. staticHTTPResponse{
  361. resp: http.Response{
  362. StatusCode: http.StatusTemporaryRedirect,
  363. Header: http.Header{"Location": []string{"http://example.com"}},
  364. },
  365. },
  366. staticHTTPResponse{
  367. resp: http.Response{
  368. StatusCode: http.StatusTeapot,
  369. },
  370. },
  371. },
  372. },
  373. wantCode: http.StatusTeapot,
  374. },
  375. // succeed after reaching max redirects
  376. {
  377. checkRedirect: func(via int) error {
  378. if via >= 3 {
  379. return ErrTooManyRedirects
  380. }
  381. return nil
  382. },
  383. client: &multiStaticHTTPClient{
  384. responses: []staticHTTPResponse{
  385. staticHTTPResponse{
  386. resp: http.Response{
  387. StatusCode: http.StatusTemporaryRedirect,
  388. Header: http.Header{"Location": []string{"http://example.com"}},
  389. },
  390. },
  391. staticHTTPResponse{
  392. resp: http.Response{
  393. StatusCode: http.StatusTemporaryRedirect,
  394. Header: http.Header{"Location": []string{"http://example.com"}},
  395. },
  396. },
  397. staticHTTPResponse{
  398. resp: http.Response{
  399. StatusCode: http.StatusTeapot,
  400. },
  401. },
  402. },
  403. },
  404. wantCode: http.StatusTeapot,
  405. },
  406. // fail if too many redirects
  407. {
  408. checkRedirect: func(via int) error {
  409. if via >= 2 {
  410. return ErrTooManyRedirects
  411. }
  412. return nil
  413. },
  414. client: &multiStaticHTTPClient{
  415. responses: []staticHTTPResponse{
  416. staticHTTPResponse{
  417. resp: http.Response{
  418. StatusCode: http.StatusTemporaryRedirect,
  419. Header: http.Header{"Location": []string{"http://example.com"}},
  420. },
  421. },
  422. staticHTTPResponse{
  423. resp: http.Response{
  424. StatusCode: http.StatusTemporaryRedirect,
  425. Header: http.Header{"Location": []string{"http://example.com"}},
  426. },
  427. },
  428. staticHTTPResponse{
  429. resp: http.Response{
  430. StatusCode: http.StatusTeapot,
  431. },
  432. },
  433. },
  434. },
  435. wantErr: ErrTooManyRedirects,
  436. },
  437. // fail if Location header not set
  438. {
  439. checkRedirect: func(int) error { return ErrTooManyRedirects },
  440. client: &multiStaticHTTPClient{
  441. responses: []staticHTTPResponse{
  442. staticHTTPResponse{
  443. resp: http.Response{
  444. StatusCode: http.StatusTemporaryRedirect,
  445. },
  446. },
  447. },
  448. },
  449. wantErr: errors.New("Location header not set"),
  450. },
  451. // fail if Location header is invalid
  452. {
  453. checkRedirect: func(int) error { return ErrTooManyRedirects },
  454. client: &multiStaticHTTPClient{
  455. responses: []staticHTTPResponse{
  456. staticHTTPResponse{
  457. resp: http.Response{
  458. StatusCode: http.StatusTemporaryRedirect,
  459. Header: http.Header{"Location": []string{":"}},
  460. },
  461. },
  462. },
  463. },
  464. wantErr: errors.New("Location header not valid URL: :"),
  465. },
  466. }
  467. for i, tt := range tests {
  468. client := &redirectFollowingHTTPClient{client: tt.client, checkRedirect: tt.checkRedirect}
  469. resp, _, err := client.Do(context.Background(), nil)
  470. if !reflect.DeepEqual(tt.wantErr, err) {
  471. t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)
  472. continue
  473. }
  474. if resp == nil {
  475. if tt.wantCode != 0 {
  476. t.Errorf("#%d: resp is nil, want=%d", i, tt.wantCode)
  477. }
  478. continue
  479. }
  480. if resp.StatusCode != tt.wantCode {
  481. t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode)
  482. continue
  483. }
  484. }
  485. }
  486. func TestDefaultCheckRedirect(t *testing.T) {
  487. tests := []struct {
  488. num int
  489. err error
  490. }{
  491. {0, nil},
  492. {5, nil},
  493. {10, nil},
  494. {11, ErrTooManyRedirects},
  495. {29, ErrTooManyRedirects},
  496. }
  497. for i, tt := range tests {
  498. err := DefaultCheckRedirect(tt.num)
  499. if !reflect.DeepEqual(tt.err, err) {
  500. t.Errorf("#%d: want=%#v got=%#v", i, tt.err, err)
  501. }
  502. }
  503. }