server_test.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. package rest
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "net/http/httptest"
  7. "testing"
  8. "github.com/stretchr/testify/assert"
  9. "github.com/tal-tech/go-zero/core/conf"
  10. "github.com/tal-tech/go-zero/rest/httpx"
  11. "github.com/tal-tech/go-zero/rest/router"
  12. )
  13. func TestNewServer(t *testing.T) {
  14. const configYaml = `
  15. Name: foo
  16. Port: 54321
  17. `
  18. var cnf RestConf
  19. assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
  20. failStart := func(server *Server) {
  21. server.opts.start = func(e *engine) error {
  22. return http.ErrServerClosed
  23. }
  24. }
  25. tests := []struct {
  26. c RestConf
  27. opts []RunOption
  28. fail bool
  29. }{
  30. {
  31. c: RestConf{},
  32. opts: []RunOption{failStart},
  33. fail: true,
  34. },
  35. {
  36. c: cnf,
  37. opts: []RunOption{failStart},
  38. },
  39. {
  40. c: cnf,
  41. opts: []RunOption{WithNotAllowedHandler(nil), failStart},
  42. },
  43. {
  44. c: cnf,
  45. opts: []RunOption{WithNotFoundHandler(nil), failStart},
  46. },
  47. {
  48. c: cnf,
  49. opts: []RunOption{WithUnauthorizedCallback(nil), failStart},
  50. },
  51. {
  52. c: cnf,
  53. opts: []RunOption{WithUnsignedCallback(nil), failStart},
  54. },
  55. }
  56. for _, test := range tests {
  57. srv, err := NewServer(test.c, test.opts...)
  58. if test.fail {
  59. assert.NotNil(t, err)
  60. }
  61. if err != nil {
  62. continue
  63. }
  64. srv.Use(ToMiddleware(func(next http.Handler) http.Handler {
  65. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  66. next.ServeHTTP(w, r)
  67. })
  68. }))
  69. srv.AddRoute(Route{
  70. Method: http.MethodGet,
  71. Path: "/",
  72. Handler: nil,
  73. }, WithJwt("thesecret"), WithSignature(SignatureConf{}),
  74. WithJwtTransition("preivous", "thenewone"))
  75. srv.Start()
  76. srv.Stop()
  77. }
  78. }
  79. func TestWithMiddleware(t *testing.T) {
  80. m := make(map[string]string)
  81. rt := router.NewRouter()
  82. handler := func(w http.ResponseWriter, r *http.Request) {
  83. var v struct {
  84. Nickname string `form:"nickname"`
  85. Zipcode int64 `form:"zipcode"`
  86. }
  87. err := httpx.Parse(r, &v)
  88. assert.Nil(t, err)
  89. _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode))
  90. assert.Nil(t, err)
  91. }
  92. rs := WithMiddleware(func(next http.HandlerFunc) http.HandlerFunc {
  93. return func(w http.ResponseWriter, r *http.Request) {
  94. var v struct {
  95. Name string `path:"name"`
  96. Year string `path:"year"`
  97. }
  98. assert.Nil(t, httpx.ParsePath(r, &v))
  99. m[v.Name] = v.Year
  100. next.ServeHTTP(w, r)
  101. }
  102. }, Route{
  103. Method: http.MethodGet,
  104. Path: "/first/:name/:year",
  105. Handler: handler,
  106. }, Route{
  107. Method: http.MethodGet,
  108. Path: "/second/:name/:year",
  109. Handler: handler,
  110. })
  111. urls := []string{
  112. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  113. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  114. }
  115. for _, route := range rs {
  116. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  117. }
  118. for _, url := range urls {
  119. r, err := http.NewRequest(http.MethodGet, url, nil)
  120. assert.Nil(t, err)
  121. rr := httptest.NewRecorder()
  122. rt.ServeHTTP(rr, r)
  123. assert.Equal(t, "whatever:200000", rr.Body.String())
  124. }
  125. assert.EqualValues(t, map[string]string{
  126. "kevin": "2017",
  127. "wan": "2020",
  128. }, m)
  129. }
  130. func TestMultiMiddlewares(t *testing.T) {
  131. m := make(map[string]string)
  132. rt := router.NewRouter()
  133. handler := func(w http.ResponseWriter, r *http.Request) {
  134. var v struct {
  135. Nickname string `form:"nickname"`
  136. Zipcode int64 `form:"zipcode"`
  137. }
  138. err := httpx.Parse(r, &v)
  139. assert.Nil(t, err)
  140. _, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
  141. assert.Nil(t, err)
  142. }
  143. rs := WithMiddlewares([]Middleware{
  144. func(next http.HandlerFunc) http.HandlerFunc {
  145. return func(w http.ResponseWriter, r *http.Request) {
  146. var v struct {
  147. Name string `path:"name"`
  148. Year string `path:"year"`
  149. }
  150. assert.Nil(t, httpx.ParsePath(r, &v))
  151. m[v.Name] = v.Year
  152. next.ServeHTTP(w, r)
  153. }
  154. },
  155. func(next http.HandlerFunc) http.HandlerFunc {
  156. return func(w http.ResponseWriter, r *http.Request) {
  157. var v struct {
  158. Name string `form:"nickname"`
  159. Zipcode string `form:"zipcode"`
  160. }
  161. assert.Nil(t, httpx.ParseForm(r, &v))
  162. assert.NotEmpty(t, m)
  163. m[v.Name] = v.Zipcode + v.Zipcode
  164. next.ServeHTTP(w, r)
  165. }
  166. },
  167. }, Route{
  168. Method: http.MethodGet,
  169. Path: "/first/:name/:year",
  170. Handler: handler,
  171. }, Route{
  172. Method: http.MethodGet,
  173. Path: "/second/:name/:year",
  174. Handler: handler,
  175. })
  176. urls := []string{
  177. "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
  178. "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
  179. }
  180. for _, route := range rs {
  181. assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
  182. }
  183. for _, url := range urls {
  184. r, err := http.NewRequest(http.MethodGet, url, nil)
  185. assert.Nil(t, err)
  186. rr := httptest.NewRecorder()
  187. rt.ServeHTTP(rr, r)
  188. assert.Equal(t, "whatever:200000200000", rr.Body.String())
  189. }
  190. assert.EqualValues(t, map[string]string{
  191. "kevin": "2017",
  192. "wan": "2020",
  193. "whatever": "200000200000",
  194. }, m)
  195. }
  196. func TestWithPriority(t *testing.T) {
  197. var fr featuredRoutes
  198. WithPriority()(&fr)
  199. assert.True(t, fr.priority)
  200. }