server_push_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. // Copyright 2016 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package http2
  5. import (
  6. "errors"
  7. "fmt"
  8. "io"
  9. "io/ioutil"
  10. "net/http"
  11. "reflect"
  12. "strconv"
  13. "sync"
  14. "testing"
  15. "time"
  16. )
  17. func TestServer_Push_Success(t *testing.T) {
  18. const (
  19. mainBody = "<html>index page</html>"
  20. pushedBody = "<html>pushed page</html>"
  21. userAgent = "testagent"
  22. cookie = "testcookie"
  23. )
  24. var stURL string
  25. checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error {
  26. if got, want := r.Method, wantMethod; got != want {
  27. return fmt.Errorf("promised Req.Method=%q, want %q", got, want)
  28. }
  29. if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) {
  30. return fmt.Errorf("promised Req.Header=%q, want %q", got, want)
  31. }
  32. if got, want := "https://"+r.Host, stURL; got != want {
  33. return fmt.Errorf("promised Req.Host=%q, want %q", got, want)
  34. }
  35. if r.Body == nil {
  36. return fmt.Errorf("nil Body")
  37. }
  38. if buf, err := ioutil.ReadAll(r.Body); err != nil || len(buf) != 0 {
  39. return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err)
  40. }
  41. return nil
  42. }
  43. errc := make(chan error, 3)
  44. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  45. switch r.URL.RequestURI() {
  46. case "/":
  47. // Push "/pushed?get" as a GET request, using an absolute URL.
  48. opt := &http.PushOptions{
  49. Header: http.Header{
  50. "User-Agent": {userAgent},
  51. },
  52. }
  53. if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil {
  54. errc <- fmt.Errorf("error pushing /pushed?get: %v", err)
  55. return
  56. }
  57. // Push "/pushed?head" as a HEAD request, using a path.
  58. opt = &http.PushOptions{
  59. Method: "HEAD",
  60. Header: http.Header{
  61. "User-Agent": {userAgent},
  62. "Cookie": {cookie},
  63. },
  64. }
  65. if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil {
  66. errc <- fmt.Errorf("error pushing /pushed?head: %v", err)
  67. return
  68. }
  69. w.Header().Set("Content-Type", "text/html")
  70. w.Header().Set("Content-Length", strconv.Itoa(len(mainBody)))
  71. w.WriteHeader(200)
  72. io.WriteString(w, mainBody)
  73. errc <- nil
  74. case "/pushed?get":
  75. wantH := http.Header{}
  76. wantH.Set("User-Agent", userAgent)
  77. if err := checkPromisedReq(r, "GET", wantH); err != nil {
  78. errc <- fmt.Errorf("/pushed?get: %v", err)
  79. return
  80. }
  81. w.Header().Set("Content-Type", "text/html")
  82. w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody)))
  83. w.WriteHeader(200)
  84. io.WriteString(w, pushedBody)
  85. errc <- nil
  86. case "/pushed?head":
  87. wantH := http.Header{}
  88. wantH.Set("User-Agent", userAgent)
  89. wantH.Set("Cookie", cookie)
  90. if err := checkPromisedReq(r, "HEAD", wantH); err != nil {
  91. errc <- fmt.Errorf("/pushed?head: %v", err)
  92. return
  93. }
  94. w.WriteHeader(204)
  95. errc <- nil
  96. default:
  97. errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
  98. }
  99. })
  100. stURL = st.ts.URL
  101. // Send one request, which should push two responses.
  102. st.greet()
  103. getSlash(st)
  104. for k := 0; k < 3; k++ {
  105. select {
  106. case <-time.After(2 * time.Second):
  107. t.Errorf("timeout waiting for handler %d to finish", k)
  108. case err := <-errc:
  109. if err != nil {
  110. t.Fatal(err)
  111. }
  112. }
  113. }
  114. checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error {
  115. pp, ok := f.(*PushPromiseFrame)
  116. if !ok {
  117. return fmt.Errorf("got a %T; want *PushPromiseFrame", f)
  118. }
  119. if !pp.HeadersEnded() {
  120. return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame")
  121. }
  122. if got, want := pp.PromiseID, promiseID; got != want {
  123. return fmt.Errorf("got PromiseID %v; want %v", got, want)
  124. }
  125. gotH := st.decodeHeader(pp.HeaderBlockFragment())
  126. if !reflect.DeepEqual(gotH, wantH) {
  127. return fmt.Errorf("got promised headers %v; want %v", gotH, wantH)
  128. }
  129. return nil
  130. }
  131. checkHeaders := func(f Frame, wantH [][2]string) error {
  132. hf, ok := f.(*HeadersFrame)
  133. if !ok {
  134. return fmt.Errorf("got a %T; want *HeadersFrame", f)
  135. }
  136. gotH := st.decodeHeader(hf.HeaderBlockFragment())
  137. if !reflect.DeepEqual(gotH, wantH) {
  138. return fmt.Errorf("got response headers %v; want %v", gotH, wantH)
  139. }
  140. return nil
  141. }
  142. checkData := func(f Frame, wantData string) error {
  143. df, ok := f.(*DataFrame)
  144. if !ok {
  145. return fmt.Errorf("got a %T; want *DataFrame", f)
  146. }
  147. if gotData := string(df.Data()); gotData != wantData {
  148. return fmt.Errorf("got response data %q; want %q", gotData, wantData)
  149. }
  150. return nil
  151. }
  152. // Stream 1 has 2 PUSH_PROMISE + HEADERS + DATA
  153. // Stream 2 has HEADERS + DATA
  154. // Stream 4 has HEADERS
  155. expected := map[uint32][]func(Frame) error{
  156. 1: {
  157. func(f Frame) error {
  158. return checkPushPromise(f, 2, [][2]string{
  159. {":method", "GET"},
  160. {":scheme", "https"},
  161. {":authority", st.ts.Listener.Addr().String()},
  162. {":path", "/pushed?get"},
  163. {"user-agent", userAgent},
  164. })
  165. },
  166. func(f Frame) error {
  167. return checkPushPromise(f, 4, [][2]string{
  168. {":method", "HEAD"},
  169. {":scheme", "https"},
  170. {":authority", st.ts.Listener.Addr().String()},
  171. {":path", "/pushed?head"},
  172. {"cookie", cookie},
  173. {"user-agent", userAgent},
  174. })
  175. },
  176. func(f Frame) error {
  177. return checkHeaders(f, [][2]string{
  178. {":status", "200"},
  179. {"content-type", "text/html"},
  180. {"content-length", strconv.Itoa(len(mainBody))},
  181. })
  182. },
  183. func(f Frame) error {
  184. return checkData(f, mainBody)
  185. },
  186. },
  187. 2: {
  188. func(f Frame) error {
  189. return checkHeaders(f, [][2]string{
  190. {":status", "200"},
  191. {"content-type", "text/html"},
  192. {"content-length", strconv.Itoa(len(pushedBody))},
  193. })
  194. },
  195. func(f Frame) error {
  196. return checkData(f, pushedBody)
  197. },
  198. },
  199. 4: {
  200. func(f Frame) error {
  201. return checkHeaders(f, [][2]string{
  202. {":status", "204"},
  203. })
  204. },
  205. },
  206. }
  207. consumed := map[uint32]int{}
  208. for k := 0; len(expected) > 0; k++ {
  209. f, err := st.readFrame()
  210. if err != nil {
  211. for id, left := range expected {
  212. t.Errorf("stream %d: missing %d frames", id, len(left))
  213. }
  214. t.Fatalf("readFrame %d: %v", k, err)
  215. }
  216. id := f.Header().StreamID
  217. label := fmt.Sprintf("stream %d, frame %d", id, consumed[id])
  218. if len(expected[id]) == 0 {
  219. t.Fatalf("%s: unexpected frame %#+v", label, f)
  220. }
  221. check := expected[id][0]
  222. expected[id] = expected[id][1:]
  223. if len(expected[id]) == 0 {
  224. delete(expected, id)
  225. }
  226. if err := check(f); err != nil {
  227. t.Fatalf("%s: %v", label, err)
  228. }
  229. consumed[id]++
  230. }
  231. }
  232. func TestServer_Push_SuccessNoRace(t *testing.T) {
  233. // Regression test for issue #18326. Ensure the request handler can mutate
  234. // pushed request headers without racing with the PUSH_PROMISE write.
  235. errc := make(chan error, 2)
  236. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  237. switch r.URL.RequestURI() {
  238. case "/":
  239. opt := &http.PushOptions{
  240. Header: http.Header{"User-Agent": {"testagent"}},
  241. }
  242. if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
  243. errc <- fmt.Errorf("error pushing: %v", err)
  244. return
  245. }
  246. w.WriteHeader(200)
  247. errc <- nil
  248. case "/pushed":
  249. // Update request header, ensure there is no race.
  250. r.Header.Set("User-Agent", "newagent")
  251. r.Header.Set("Cookie", "cookie")
  252. w.WriteHeader(200)
  253. errc <- nil
  254. default:
  255. errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
  256. }
  257. })
  258. // Send one request, which should push one response.
  259. st.greet()
  260. getSlash(st)
  261. for k := 0; k < 2; k++ {
  262. select {
  263. case <-time.After(2 * time.Second):
  264. t.Errorf("timeout waiting for handler %d to finish", k)
  265. case err := <-errc:
  266. if err != nil {
  267. t.Fatal(err)
  268. }
  269. }
  270. }
  271. }
  272. func TestServer_Push_RejectRecursivePush(t *testing.T) {
  273. // Expect two requests, but might get three if there's a bug and the second push succeeds.
  274. errc := make(chan error, 3)
  275. handler := func(w http.ResponseWriter, r *http.Request) error {
  276. baseURL := "https://" + r.Host
  277. switch r.URL.Path {
  278. case "/":
  279. if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil {
  280. return fmt.Errorf("first Push()=%v, want nil", err)
  281. }
  282. return nil
  283. case "/push1":
  284. if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want {
  285. return fmt.Errorf("Push()=%v, want %v", got, want)
  286. }
  287. return nil
  288. default:
  289. return fmt.Errorf("unexpected path: %q", r.URL.Path)
  290. }
  291. }
  292. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  293. errc <- handler(w, r)
  294. })
  295. defer st.Close()
  296. st.greet()
  297. getSlash(st)
  298. if err := <-errc; err != nil {
  299. t.Errorf("First request failed: %v", err)
  300. }
  301. if err := <-errc; err != nil {
  302. t.Errorf("Second request failed: %v", err)
  303. }
  304. }
  305. func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) {
  306. // Expect one request, but might get two if there's a bug and the push succeeds.
  307. errc := make(chan error, 2)
  308. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  309. errc <- doPush(w.(http.Pusher), r)
  310. })
  311. defer st.Close()
  312. st.greet()
  313. if err := st.fr.WriteSettings(settings...); err != nil {
  314. st.t.Fatalf("WriteSettings: %v", err)
  315. }
  316. st.wantSettingsAck()
  317. getSlash(st)
  318. if err := <-errc; err != nil {
  319. t.Error(err)
  320. }
  321. // Should not get a PUSH_PROMISE frame.
  322. hf := st.wantHeaders()
  323. if !hf.StreamEnded() {
  324. t.Error("stream should end after headers")
  325. }
  326. }
  327. func TestServer_Push_RejectIfDisabled(t *testing.T) {
  328. testServer_Push_RejectSingleRequest(t,
  329. func(p http.Pusher, r *http.Request) error {
  330. if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
  331. return fmt.Errorf("Push()=%v, want %v", got, want)
  332. }
  333. return nil
  334. },
  335. Setting{SettingEnablePush, 0})
  336. }
  337. func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) {
  338. testServer_Push_RejectSingleRequest(t,
  339. func(p http.Pusher, r *http.Request) error {
  340. if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want {
  341. return fmt.Errorf("Push()=%v, want %v", got, want)
  342. }
  343. return nil
  344. },
  345. Setting{SettingMaxConcurrentStreams, 0})
  346. }
  347. func TestServer_Push_RejectWrongScheme(t *testing.T) {
  348. testServer_Push_RejectSingleRequest(t,
  349. func(p http.Pusher, r *http.Request) error {
  350. if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil {
  351. return errors.New("Push() should have failed (push target URL is http)")
  352. }
  353. return nil
  354. })
  355. }
  356. func TestServer_Push_RejectMissingHost(t *testing.T) {
  357. testServer_Push_RejectSingleRequest(t,
  358. func(p http.Pusher, r *http.Request) error {
  359. if err := p.Push("https:pushed", nil); err == nil {
  360. return errors.New("Push() should have failed (push target URL missing host)")
  361. }
  362. return nil
  363. })
  364. }
  365. func TestServer_Push_RejectRelativePath(t *testing.T) {
  366. testServer_Push_RejectSingleRequest(t,
  367. func(p http.Pusher, r *http.Request) error {
  368. if err := p.Push("../test", nil); err == nil {
  369. return errors.New("Push() should have failed (push target is a relative path)")
  370. }
  371. return nil
  372. })
  373. }
  374. func TestServer_Push_RejectForbiddenMethod(t *testing.T) {
  375. testServer_Push_RejectSingleRequest(t,
  376. func(p http.Pusher, r *http.Request) error {
  377. if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil {
  378. return errors.New("Push() should have failed (cannot promise a POST)")
  379. }
  380. return nil
  381. })
  382. }
  383. func TestServer_Push_RejectForbiddenHeader(t *testing.T) {
  384. testServer_Push_RejectSingleRequest(t,
  385. func(p http.Pusher, r *http.Request) error {
  386. header := http.Header{
  387. "Content-Length": {"10"},
  388. "Content-Encoding": {"gzip"},
  389. "Trailer": {"Foo"},
  390. "Te": {"trailers"},
  391. "Host": {"test.com"},
  392. ":authority": {"test.com"},
  393. }
  394. if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil {
  395. return errors.New("Push() should have failed (forbidden headers)")
  396. }
  397. return nil
  398. })
  399. }
  400. func TestServer_Push_StateTransitions(t *testing.T) {
  401. const body = "foo"
  402. gotPromise := make(chan bool)
  403. finishedPush := make(chan bool)
  404. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  405. switch r.URL.RequestURI() {
  406. case "/":
  407. if err := w.(http.Pusher).Push("/pushed", nil); err != nil {
  408. t.Errorf("Push error: %v", err)
  409. }
  410. // Don't finish this request until the push finishes so we don't
  411. // nondeterministically interleave output frames with the push.
  412. <-finishedPush
  413. case "/pushed":
  414. <-gotPromise
  415. }
  416. w.Header().Set("Content-Type", "text/html")
  417. w.Header().Set("Content-Length", strconv.Itoa(len(body)))
  418. w.WriteHeader(200)
  419. io.WriteString(w, body)
  420. })
  421. defer st.Close()
  422. st.greet()
  423. if st.stream(2) != nil {
  424. t.Fatal("stream 2 should be empty")
  425. }
  426. if got, want := st.streamState(2), stateIdle; got != want {
  427. t.Fatalf("streamState(2)=%v, want %v", got, want)
  428. }
  429. getSlash(st)
  430. // After the PUSH_PROMISE is sent, the stream should be stateHalfClosedRemote.
  431. st.wantPushPromise()
  432. if got, want := st.streamState(2), stateHalfClosedRemote; got != want {
  433. t.Fatalf("streamState(2)=%v, want %v", got, want)
  434. }
  435. // We stall the HTTP handler for "/pushed" until the above check. If we don't
  436. // stall the handler, then the handler might write HEADERS and DATA and finish
  437. // the stream before we check st.streamState(2) -- should that happen, we'll
  438. // see stateClosed and fail the above check.
  439. close(gotPromise)
  440. st.wantHeaders()
  441. if df := st.wantData(); !df.StreamEnded() {
  442. t.Fatal("expected END_STREAM flag on DATA")
  443. }
  444. if got, want := st.streamState(2), stateClosed; got != want {
  445. t.Fatalf("streamState(2)=%v, want %v", got, want)
  446. }
  447. close(finishedPush)
  448. }
  449. func TestServer_Push_RejectAfterGoAway(t *testing.T) {
  450. var readyOnce sync.Once
  451. ready := make(chan struct{})
  452. errc := make(chan error, 2)
  453. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  454. select {
  455. case <-ready:
  456. case <-time.After(5 * time.Second):
  457. errc <- fmt.Errorf("timeout waiting for GOAWAY to be processed")
  458. }
  459. if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
  460. errc <- fmt.Errorf("Push()=%v, want %v", got, want)
  461. }
  462. errc <- nil
  463. })
  464. defer st.Close()
  465. st.greet()
  466. getSlash(st)
  467. // Send GOAWAY and wait for it to be processed.
  468. st.fr.WriteGoAway(1, ErrCodeNo, nil)
  469. go func() {
  470. for {
  471. select {
  472. case <-ready:
  473. return
  474. default:
  475. }
  476. st.sc.serveMsgCh <- func(loopNum int) {
  477. if !st.sc.pushEnabled {
  478. readyOnce.Do(func() { close(ready) })
  479. }
  480. }
  481. }
  482. }()
  483. if err := <-errc; err != nil {
  484. t.Error(err)
  485. }
  486. }