mock.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. package mock
  2. import (
  3. "fmt"
  4. "github.com/stretchr/objx"
  5. "github.com/stretchr/testify/assert"
  6. "reflect"
  7. "runtime"
  8. "strings"
  9. "testing"
  10. )
  11. /*
  12. Call
  13. */
  14. // Call represents a method call and is used for setting expectations,
  15. // as well as recording activity.
  16. type Call struct {
  17. // The name of the method that was or will be called.
  18. Method string
  19. // Holds the arguments of the method.
  20. Arguments Arguments
  21. // Holds the arguments that should be returned when
  22. // this method is called.
  23. ReturnArguments Arguments
  24. // The number of times to return the return arguments when setting
  25. // expectations. 0 means to always return the value.
  26. Repeatability int
  27. }
  28. // Mock is the workhorse used to track activity on another object.
  29. // For an example of its usage, refer to the "Example Usage" section at the top of this document.
  30. type Mock struct {
  31. // The method name that is currently
  32. // being referred to by the On method.
  33. onMethodName string
  34. // An array of the arguments that are
  35. // currently being referred to by the On method.
  36. onMethodArguments Arguments
  37. // Represents the calls that are expected of
  38. // an object.
  39. ExpectedCalls []Call
  40. // Holds the calls that were made to this mocked object.
  41. Calls []Call
  42. // TestData holds any data that might be useful for testing. Testify ignores
  43. // this data completely allowing you to do whatever you like with it.
  44. testData objx.Map
  45. }
  46. // TestData holds any data that might be useful for testing. Testify ignores
  47. // this data completely allowing you to do whatever you like with it.
  48. func (m *Mock) TestData() objx.Map {
  49. if m.testData == nil {
  50. m.testData = make(objx.Map)
  51. }
  52. return m.testData
  53. }
  54. /*
  55. Setting expectations
  56. */
  57. // On starts a description of an expectation of the specified method
  58. // being called.
  59. //
  60. // Mock.On("MyMethod", arg1, arg2)
  61. func (m *Mock) On(methodName string, arguments ...interface{}) *Mock {
  62. m.onMethodName = methodName
  63. m.onMethodArguments = arguments
  64. return m
  65. }
  66. // Return finishes a description of an expectation of the method (and arguments)
  67. // specified in the most recent On method call.
  68. //
  69. // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2)
  70. func (m *Mock) Return(returnArguments ...interface{}) *Mock {
  71. m.ExpectedCalls = append(m.ExpectedCalls, Call{m.onMethodName, m.onMethodArguments, returnArguments, 0})
  72. return m
  73. }
  74. // Once indicates that that the mock should only return the value once.
  75. //
  76. // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once()
  77. func (m *Mock) Once() {
  78. m.ExpectedCalls[len(m.ExpectedCalls)-1].Repeatability = 1
  79. }
  80. // Twice indicates that that the mock should only return the value twice.
  81. //
  82. // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice()
  83. func (m *Mock) Twice() {
  84. m.ExpectedCalls[len(m.ExpectedCalls)-1].Repeatability = 2
  85. }
  86. // Times indicates that that the mock should only return the indicated number
  87. // of times.
  88. //
  89. // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5)
  90. func (m *Mock) Times(i int) {
  91. m.ExpectedCalls[len(m.ExpectedCalls)-1].Repeatability = i
  92. }
  93. /*
  94. Recording and responding to activity
  95. */
  96. func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) {
  97. for i, call := range m.ExpectedCalls {
  98. if call.Method == method && call.Repeatability > -1 {
  99. _, diffCount := call.Arguments.Diff(arguments)
  100. if diffCount == 0 {
  101. return i, &call
  102. }
  103. }
  104. }
  105. return -1, nil
  106. }
  107. func (m *Mock) findClosestCall(method string, arguments ...interface{}) (bool, *Call) {
  108. diffCount := 0
  109. var closestCall *Call = nil
  110. for _, call := range m.ExpectedCalls {
  111. if call.Method == method {
  112. _, tempDiffCount := call.Arguments.Diff(arguments)
  113. if tempDiffCount < diffCount || diffCount == 0 {
  114. diffCount = tempDiffCount
  115. closestCall = &call
  116. }
  117. }
  118. }
  119. if closestCall == nil {
  120. return false, nil
  121. }
  122. return true, closestCall
  123. }
  124. func callString(method string, arguments Arguments, includeArgumentValues bool) string {
  125. var argValsString string = ""
  126. if includeArgumentValues {
  127. var argVals []string
  128. for argIndex, arg := range arguments {
  129. argVals = append(argVals, fmt.Sprintf("%d: %v", argIndex, arg))
  130. }
  131. argValsString = fmt.Sprintf("\n\t\t%s", strings.Join(argVals, "\n\t\t"))
  132. }
  133. return fmt.Sprintf("%s(%s)%s", method, arguments.String(), argValsString)
  134. }
  135. // Called tells the mock object that a method has been called, and gets an array
  136. // of arguments to return. Panics if the call is unexpected (i.e. not preceeded by
  137. // appropriate .On .Return() calls)
  138. func (m *Mock) Called(arguments ...interface{}) Arguments {
  139. // get the calling function's name
  140. pc, _, _, ok := runtime.Caller(1)
  141. if !ok {
  142. panic("Couldn't get the caller information")
  143. }
  144. functionPath := runtime.FuncForPC(pc).Name()
  145. parts := strings.Split(functionPath, ".")
  146. functionName := parts[len(parts)-1]
  147. found, call := m.findExpectedCall(functionName, arguments...)
  148. switch {
  149. case found < 0:
  150. // we have to fail here - because we don't know what to do
  151. // as the return arguments. This is because:
  152. //
  153. // a) this is a totally unexpected call to this method,
  154. // b) the arguments are not what was expected, or
  155. // c) the developer has forgotten to add an accompanying On...Return pair.
  156. closestFound, closestCall := m.findClosestCall(functionName, arguments...)
  157. if closestFound {
  158. panic(fmt.Sprintf("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n", callString(functionName, arguments, true), callString(functionName, closestCall.Arguments, true)))
  159. } else {
  160. panic(fmt.Sprintf("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", functionName, functionName, callString(functionName, arguments, true), assert.CallerInfo()))
  161. }
  162. case call.Repeatability == 1:
  163. call.Repeatability = -1
  164. m.ExpectedCalls[found] = *call
  165. case call.Repeatability > 1:
  166. call.Repeatability -= 1
  167. m.ExpectedCalls[found] = *call
  168. }
  169. // add the call
  170. m.Calls = append(m.Calls, Call{functionName, arguments, make([]interface{}, 0), 0})
  171. return call.ReturnArguments
  172. }
  173. /*
  174. Assertions
  175. */
  176. // AssertExpectationsForObjects asserts that everything specified with On and Return
  177. // of the specified objects was in fact called as expected.
  178. //
  179. // Calls may have occurred in any order.
  180. func AssertExpectationsForObjects(t *testing.T, testObjects ...interface{}) bool {
  181. var success bool = true
  182. for _, obj := range testObjects {
  183. mockObj := obj.(Mock)
  184. success = success && mockObj.AssertExpectations(t)
  185. }
  186. return success
  187. }
  188. // AssertExpectations asserts that everything specified with On and Return was
  189. // in fact called as expected. Calls may have occurred in any order.
  190. func (m *Mock) AssertExpectations(t *testing.T) bool {
  191. var somethingMissing bool = false
  192. var failedExpectations int = 0
  193. // iterate through each expectation
  194. for _, expectedCall := range m.ExpectedCalls {
  195. switch {
  196. case !m.methodWasCalled(expectedCall.Method, expectedCall.Arguments):
  197. somethingMissing = true
  198. failedExpectations++
  199. t.Logf("\u274C\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String())
  200. case expectedCall.Repeatability > 0:
  201. somethingMissing = true
  202. failedExpectations++
  203. default:
  204. t.Logf("\u2705\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String())
  205. }
  206. }
  207. if somethingMissing {
  208. t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(m.ExpectedCalls)-failedExpectations, len(m.ExpectedCalls), failedExpectations, assert.CallerInfo())
  209. }
  210. return !somethingMissing
  211. }
  212. // AssertNumberOfCalls asserts that the method was called expectedCalls times.
  213. func (m *Mock) AssertNumberOfCalls(t *testing.T, methodName string, expectedCalls int) bool {
  214. var actualCalls int = 0
  215. for _, call := range m.Calls {
  216. if call.Method == methodName {
  217. actualCalls++
  218. }
  219. }
  220. return assert.Equal(t, actualCalls, expectedCalls, fmt.Sprintf("Expected number of calls (%d) does not match the actual number of calls (%d).", expectedCalls, actualCalls))
  221. }
  222. // AssertCalled asserts that the method was called.
  223. func (m *Mock) AssertCalled(t *testing.T, methodName string, arguments ...interface{}) bool {
  224. if !assert.True(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method should have been called with %d argument(s), but was not.", methodName, len(arguments))) {
  225. t.Logf("%s", m.ExpectedCalls)
  226. return false
  227. }
  228. return true
  229. }
  230. // AssertNotCalled asserts that the method was not called.
  231. func (m *Mock) AssertNotCalled(t *testing.T, methodName string, arguments ...interface{}) bool {
  232. if !assert.False(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method was called with %d argument(s), but should NOT have been.", methodName, len(arguments))) {
  233. t.Logf("%s", m.ExpectedCalls)
  234. return false
  235. }
  236. return true
  237. }
  238. func (m *Mock) methodWasCalled(methodName string, arguments []interface{}) bool {
  239. for _, call := range m.Calls {
  240. if call.Method == methodName {
  241. _, differences := call.Arguments.Diff(arguments)
  242. if differences == 0 {
  243. // found the expected call
  244. return true
  245. }
  246. }
  247. }
  248. // we didn't find the expected call
  249. return false
  250. }
  /*
	Arguments
*/

// Arguments is a list of method arguments or return values.
type Arguments []interface{}
  // Anything is the wildcard argument. Wherever it appears in Diff (and so in
// Assert), on either side, the corresponding argument is not taken into
// consideration.
const Anything string = "mock.Anything"
  // AnythingOfTypeArgument is a string naming the type an argument is expected
// to have. When it appears as an expected argument in Diff (and so in Assert),
// any actual value whose type name or type string equals it matches.
type AnythingOfTypeArgument string

// AnythingOfType returns an AnythingOfTypeArgument object containing the
// name of the type to check for. Used in Diff and Assert.
//
// For example:
//
//	args := Arguments{AnythingOfType("string"), AnythingOfType("int")}
//	args.Assert(t, "hello", 42)
func AnythingOfType(t string) AnythingOfTypeArgument {
	typeName := AnythingOfTypeArgument(t)
	return typeName
}
  272. // Get Returns the argument at the specified index.
  273. func (args Arguments) Get(index int) interface{} {
  274. if index+1 > len(args) {
  275. panic(fmt.Sprintf("assert: arguments: Cannot call Get(%d) because there are %d argument(s).", index, len(args)))
  276. }
  277. return args[index]
  278. }
  279. // Is gets whether the objects match the arguments specified.
  280. func (args Arguments) Is(objects ...interface{}) bool {
  281. for i, obj := range args {
  282. if obj != objects[i] {
  283. return false
  284. }
  285. }
  286. return true
  287. }
  288. // Diff gets a string describing the differences between the arguments
  289. // and the specified objects.
  290. //
  291. // Returns the diff string and number of differences found.
  292. func (args Arguments) Diff(objects []interface{}) (string, int) {
  293. var output string = "\n"
  294. var differences int
  295. var maxArgCount int = len(args)
  296. if len(objects) > maxArgCount {
  297. maxArgCount = len(objects)
  298. }
  299. for i := 0; i < maxArgCount; i++ {
  300. var actual, expected interface{}
  301. if len(objects) <= i {
  302. actual = "(Missing)"
  303. } else {
  304. actual = objects[i]
  305. }
  306. if len(args) <= i {
  307. expected = "(Missing)"
  308. } else {
  309. expected = args[i]
  310. }
  311. if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() {
  312. // type checking
  313. if reflect.TypeOf(actual).Name() != string(expected.(AnythingOfTypeArgument)) && reflect.TypeOf(actual).String() != string(expected.(AnythingOfTypeArgument)) {
  314. // not match
  315. differences++
  316. output = fmt.Sprintf("%s\t%d: \u274C type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actual)
  317. }
  318. } else {
  319. // normal checking
  320. if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
  321. // match
  322. output = fmt.Sprintf("%s\t%d: \u2705 %s == %s\n", output, i, actual, expected)
  323. } else {
  324. // not match
  325. differences++
  326. output = fmt.Sprintf("%s\t%d: \u274C %s != %s\n", output, i, actual, expected)
  327. }
  328. }
  329. }
  330. if differences == 0 {
  331. return "No differences.", differences
  332. }
  333. return output, differences
  334. }
  335. // Assert compares the arguments with the specified objects and fails if
  336. // they do not exactly match.
  337. func (args Arguments) Assert(t *testing.T, objects ...interface{}) bool {
  338. // get the differences
  339. diff, diffCount := args.Diff(objects)
  340. if diffCount == 0 {
  341. return true
  342. }
  343. // there are differences... report them...
  344. t.Logf(diff)
  345. t.Errorf("%sArguments do not match.", assert.CallerInfo())
  346. return false
  347. }
  348. // String gets the argument at the specified index. Panics if there is no argument, or
  349. // if the argument is of the wrong type.
  350. //
  351. // If no index is provided, String() returns a complete string representation
  352. // of the arguments.
  353. func (args Arguments) String(indexOrNil ...int) string {
  354. if len(indexOrNil) == 0 {
  355. // normal String() method - return a string representation of the args
  356. var argsStr []string
  357. for _, arg := range args {
  358. argsStr = append(argsStr, fmt.Sprintf("%s", reflect.TypeOf(arg)))
  359. }
  360. return strings.Join(argsStr, ",")
  361. } else if len(indexOrNil) == 1 {
  362. // Index has been specified - get the argument at that index
  363. var index int = indexOrNil[0]
  364. var s string
  365. var ok bool
  366. if s, ok = args.Get(index).(string); !ok {
  367. panic(fmt.Sprintf("assert: arguments: String(%d) failed because object wasn't correct type: %s", index, args.Get(index)))
  368. }
  369. return s
  370. }
  371. panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String. Must be 0 or 1, not %d", len(indexOrNil)))
  372. }
  373. // Int gets the argument at the specified index. Panics if there is no argument, or
  374. // if the argument is of the wrong type.
  375. func (args Arguments) Int(index int) int {
  376. var s int
  377. var ok bool
  378. if s, ok = args.Get(index).(int); !ok {
  379. panic(fmt.Sprintf("assert: arguments: Int(%d) failed because object wasn't correct type: %s", index, args.Get(index)))
  380. }
  381. return s
  382. }
  383. // Error gets the argument at the specified index. Panics if there is no argument, or
  384. // if the argument is of the wrong type.
  385. func (args Arguments) Error(index int) error {
  386. obj := args.Get(index)
  387. var s error
  388. var ok bool
  389. if obj == nil {
  390. return nil
  391. }
  392. if s, ok = obj.(error); !ok {
  393. panic(fmt.Sprintf("assert: arguments: Error(%d) failed because object wasn't correct type: %s", index, args.Get(index)))
  394. }
  395. return s
  396. }
  397. // Bool gets the argument at the specified index. Panics if there is no argument, or
  398. // if the argument is of the wrong type.
  399. func (args Arguments) Bool(index int) bool {
  400. var s bool
  401. var ok bool
  402. if s, ok = args.Get(index).(bool); !ok {
  403. panic(fmt.Sprintf("assert: arguments: Bool(%d) failed because object wasn't correct type: %s", index, args.Get(index)))
  404. }
  405. return s
  406. }