middleware_test.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. // Copyright 2014 Manu Martinez-Almeida. All rights reserved.
  2. // Use of this source code is governed by a MIT style
  3. // license that can be found in the LICENSE file.
  4. package gin
  5. import (
  6. "errors"
  7. "net/http"
  8. "strings"
  9. "testing"
  10. "github.com/gin-contrib/sse"
  11. "github.com/stretchr/testify/assert"
  12. )
  13. func TestMiddlewareGeneralCase(t *testing.T) {
  14. signature := ""
  15. router := New()
  16. router.Use(func(c *Context) {
  17. signature += "A"
  18. c.Next()
  19. signature += "B"
  20. })
  21. router.Use(func(c *Context) {
  22. signature += "C"
  23. })
  24. router.GET("/", func(c *Context) {
  25. signature += "D"
  26. })
  27. router.NoRoute(func(c *Context) {
  28. signature += " X "
  29. })
  30. router.NoMethod(func(c *Context) {
  31. signature += " XX "
  32. })
  33. // RUN
  34. w := performRequest(router, "GET", "/")
  35. // TEST
  36. assert.Equal(t, http.StatusOK, w.Code)
  37. assert.Equal(t, "ACDB", signature)
  38. }
  39. func TestMiddlewareNoRoute(t *testing.T) {
  40. signature := ""
  41. router := New()
  42. router.Use(func(c *Context) {
  43. signature += "A"
  44. c.Next()
  45. signature += "B"
  46. })
  47. router.Use(func(c *Context) {
  48. signature += "C"
  49. c.Next()
  50. c.Next()
  51. c.Next()
  52. c.Next()
  53. signature += "D"
  54. })
  55. router.NoRoute(func(c *Context) {
  56. signature += "E"
  57. c.Next()
  58. signature += "F"
  59. }, func(c *Context) {
  60. signature += "G"
  61. c.Next()
  62. signature += "H"
  63. })
  64. router.NoMethod(func(c *Context) {
  65. signature += " X "
  66. })
  67. // RUN
  68. w := performRequest(router, "GET", "/")
  69. // TEST
  70. assert.Equal(t, http.StatusNotFound, w.Code)
  71. assert.Equal(t, "ACEGHFDB", signature)
  72. }
  73. func TestMiddlewareNoMethodEnabled(t *testing.T) {
  74. signature := ""
  75. router := New()
  76. router.HandleMethodNotAllowed = true
  77. router.Use(func(c *Context) {
  78. signature += "A"
  79. c.Next()
  80. signature += "B"
  81. })
  82. router.Use(func(c *Context) {
  83. signature += "C"
  84. c.Next()
  85. signature += "D"
  86. })
  87. router.NoMethod(func(c *Context) {
  88. signature += "E"
  89. c.Next()
  90. signature += "F"
  91. }, func(c *Context) {
  92. signature += "G"
  93. c.Next()
  94. signature += "H"
  95. })
  96. router.NoRoute(func(c *Context) {
  97. signature += " X "
  98. })
  99. router.POST("/", func(c *Context) {
  100. signature += " XX "
  101. })
  102. // RUN
  103. w := performRequest(router, "GET", "/")
  104. // TEST
  105. assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
  106. assert.Equal(t, "ACEGHFDB", signature)
  107. }
  108. func TestMiddlewareNoMethodDisabled(t *testing.T) {
  109. signature := ""
  110. router := New()
  111. router.HandleMethodNotAllowed = false
  112. router.Use(func(c *Context) {
  113. signature += "A"
  114. c.Next()
  115. signature += "B"
  116. })
  117. router.Use(func(c *Context) {
  118. signature += "C"
  119. c.Next()
  120. signature += "D"
  121. })
  122. router.NoMethod(func(c *Context) {
  123. signature += "E"
  124. c.Next()
  125. signature += "F"
  126. }, func(c *Context) {
  127. signature += "G"
  128. c.Next()
  129. signature += "H"
  130. })
  131. router.NoRoute(func(c *Context) {
  132. signature += " X "
  133. })
  134. router.POST("/", func(c *Context) {
  135. signature += " XX "
  136. })
  137. // RUN
  138. w := performRequest(router, "GET", "/")
  139. // TEST
  140. assert.Equal(t, http.StatusNotFound, w.Code)
  141. assert.Equal(t, "AC X DB", signature)
  142. }
  143. func TestMiddlewareAbort(t *testing.T) {
  144. signature := ""
  145. router := New()
  146. router.Use(func(c *Context) {
  147. signature += "A"
  148. })
  149. router.Use(func(c *Context) {
  150. signature += "C"
  151. c.AbortWithStatus(http.StatusUnauthorized)
  152. c.Next()
  153. signature += "D"
  154. })
  155. router.GET("/", func(c *Context) {
  156. signature += " X "
  157. c.Next()
  158. signature += " XX "
  159. })
  160. // RUN
  161. w := performRequest(router, "GET", "/")
  162. // TEST
  163. assert.Equal(t, http.StatusUnauthorized, w.Code)
  164. assert.Equal(t, "ACD", signature)
  165. }
  166. func TestMiddlewareAbortHandlersChainAndNext(t *testing.T) {
  167. signature := ""
  168. router := New()
  169. router.Use(func(c *Context) {
  170. signature += "A"
  171. c.Next()
  172. c.AbortWithStatus(http.StatusGone)
  173. signature += "B"
  174. })
  175. router.GET("/", func(c *Context) {
  176. signature += "C"
  177. c.Next()
  178. })
  179. // RUN
  180. w := performRequest(router, "GET", "/")
  181. // TEST
  182. assert.Equal(t, http.StatusGone, w.Code)
  183. assert.Equal(t, "ACB", signature)
  184. }
  185. // TestFailHandlersChain - ensure that Fail interrupt used middleware in fifo order as
  186. // as well as Abort
  187. func TestMiddlewareFailHandlersChain(t *testing.T) {
  188. // SETUP
  189. signature := ""
  190. router := New()
  191. router.Use(func(context *Context) {
  192. signature += "A"
  193. context.AbortWithError(http.StatusInternalServerError, errors.New("foo")) // nolint: errcheck
  194. })
  195. router.Use(func(context *Context) {
  196. signature += "B"
  197. context.Next()
  198. signature += "C"
  199. })
  200. // RUN
  201. w := performRequest(router, "GET", "/")
  202. // TEST
  203. assert.Equal(t, http.StatusInternalServerError, w.Code)
  204. assert.Equal(t, "A", signature)
  205. }
  206. func TestMiddlewareWrite(t *testing.T) {
  207. router := New()
  208. router.Use(func(c *Context) {
  209. c.String(http.StatusBadRequest, "hola\n")
  210. })
  211. router.Use(func(c *Context) {
  212. c.XML(http.StatusBadRequest, H{"foo": "bar"})
  213. })
  214. router.Use(func(c *Context) {
  215. c.JSON(http.StatusBadRequest, H{"foo": "bar"})
  216. })
  217. router.GET("/", func(c *Context) {
  218. c.JSON(http.StatusBadRequest, H{"foo": "bar"})
  219. }, func(c *Context) {
  220. c.Render(http.StatusBadRequest, sse.Event{
  221. Event: "test",
  222. Data: "message",
  223. })
  224. })
  225. w := performRequest(router, "GET", "/")
  226. assert.Equal(t, http.StatusBadRequest, w.Code)
  227. assert.Equal(t, strings.Replace("hola\n<map><foo>bar</foo></map>{\"foo\":\"bar\"}\n{\"foo\":\"bar\"}\nevent:test\ndata:message\n\n", " ", "", -1), strings.Replace(w.Body.String(), " ", "", -1))
  228. }