mock.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. package mock
  2. import (
  3. "fmt"
  4. "github.com/stretchr/objx"
  5. "github.com/stretchr/testify/assert"
  6. "reflect"
  7. "runtime"
  8. "strings"
  9. "testing"
  10. )
  11. /*
  12. Call
  13. */
// Call represents a method call and is used for setting expectations,
// as well as recording activity.
type Call struct {
	// Method is the name of the method that was or will be called.
	Method string

	// Arguments holds the arguments of the method.
	Arguments Arguments

	// ReturnArguments holds the arguments that should be returned when
	// this method is called.
	ReturnArguments Arguments
}
// Mock is the workhorse used to track activity on another object.
// For an example of its usage, refer to the "Example Usage" section at the top of this document.
type Mock struct {
	// onMethodName is the method name that is currently being referred to
	// by the On method; Return reads it when it records the expectation.
	onMethodName string

	// onMethodArguments holds the arguments that are currently being
	// referred to by the On method; Return reads them as well.
	onMethodArguments Arguments

	// ExpectedCalls represents the calls that are expected of an object.
	// Return appends one entry per On(...).Return(...) pair.
	ExpectedCalls []Call

	// Calls holds the calls that were made to this mocked object.
	// Called appends one entry per recorded invocation.
	Calls []Call

	// testData holds any data that might be useful for testing. Testify
	// ignores this data completely, allowing you to do whatever you like
	// with it. It is created lazily by the TestData method.
	testData objx.Map
}
  43. // TestData holds any data that might be useful for testing. Testify ignores
  44. // this data completely allowing you to do whatever you like with it.
  45. func (m *Mock) TestData() objx.Map {
  46. if m.testData == nil {
  47. m.testData = make(objx.Map)
  48. }
  49. return m.testData
  50. }
  51. /*
  52. Setting expectations
  53. */
  54. // On starts a description of an expectation of the specified method
  55. // being called.
  56. //
  57. // Mock.On("MyMethod", arg1, arg2)
  58. func (m *Mock) On(methodName string, arguments ...interface{}) *Mock {
  59. m.onMethodName = methodName
  60. m.onMethodArguments = arguments
  61. return m
  62. }
  63. // Return finishes a description of an expectation of the method (and arguments)
  64. // specified in the most recent On method call.
  65. //
  66. // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2)
  67. func (m *Mock) Return(returnArguments ...interface{}) *Mock {
  68. m.ExpectedCalls = append(m.ExpectedCalls, Call{m.onMethodName, m.onMethodArguments, returnArguments})
  69. return m
  70. }
  71. /*
  72. Recording and responding to activity
  73. */
  74. func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (bool, *Call) {
  75. for _, call := range m.ExpectedCalls {
  76. if call.Method == method {
  77. _, diffCount := call.Arguments.Diff(arguments)
  78. if diffCount == 0 {
  79. return true, &call
  80. }
  81. }
  82. }
  83. return false, nil
  84. }
  85. func (m *Mock) findClosestCall(method string, arguments ...interface{}) (bool, *Call) {
  86. diffCount := 0
  87. var closestCall *Call = nil
  88. for _, call := range m.ExpectedCalls {
  89. if call.Method == method {
  90. _, tempDiffCount := call.Arguments.Diff(arguments)
  91. if tempDiffCount < diffCount || diffCount == 0 {
  92. diffCount = tempDiffCount
  93. closestCall = &call
  94. }
  95. }
  96. }
  97. if closestCall == nil {
  98. return false, nil
  99. }
  100. return true, closestCall
  101. }
  102. func callString(method string, arguments Arguments, includeArgumentValues bool) string {
  103. var argValsString string = ""
  104. if includeArgumentValues {
  105. var argVals []string
  106. for argIndex, arg := range arguments {
  107. argVals = append(argVals, fmt.Sprintf("%d: %v", argIndex, arg))
  108. }
  109. argValsString = fmt.Sprintf("\n\t\t%s", strings.Join(argVals, "\n\t\t"))
  110. }
  111. return fmt.Sprintf("%s(%s)%s", method, arguments.String(), argValsString)
  112. }
  113. // Called tells the mock object that a method has been called, and gets an array
  114. // of arguments to return. Panics if the call is unexpected (i.e. not preceeded by
  115. // appropriate .On .Return() calls)
  116. func (m *Mock) Called(arguments ...interface{}) Arguments {
  117. // get the calling function's name
  118. pc, _, _, ok := runtime.Caller(1)
  119. if !ok {
  120. panic("Couldn't get the caller information")
  121. }
  122. functionPath := runtime.FuncForPC(pc).Name()
  123. parts := strings.Split(functionPath, ".")
  124. functionName := parts[len(parts)-1]
  125. found, call := m.findExpectedCall(functionName, arguments...)
  126. if !found {
  127. // we have to fail here - because we don't know what to do
  128. // as the return arguments. This is because:
  129. //
  130. // a) this is a totally unexpected call to this method,
  131. // b) the arguments are not what was expected, or
  132. // c) the developer has forgotten to add an accompanying On...Return pair.
  133. closestFound, closestCall := m.findClosestCall(functionName, arguments...)
  134. if closestFound {
  135. panic(fmt.Sprintf("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n", callString(functionName, arguments, true), callString(functionName, closestCall.Arguments, true)))
  136. } else {
  137. panic(fmt.Sprintf("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", functionName, functionName, callString(functionName, arguments, true), assert.CallerInfo()))
  138. }
  139. }
  140. // add the call
  141. m.Calls = append(m.Calls, Call{functionName, arguments, make([]interface{}, 0)})
  142. return call.ReturnArguments
  143. }
  144. /*
  145. Assertions
  146. */
  147. // AssertExpectationsForObjects asserts that everything specified with On and Return
  148. // of the specified objects was in fact called as expected.
  149. //
  150. // Calls may have occurred in any order.
  151. func AssertExpectationsForObjects(t *testing.T, testObjects ...interface{}) bool {
  152. var success bool = true
  153. for _, obj := range testObjects {
  154. mockObj := obj.(Mock)
  155. success = success && mockObj.AssertExpectations(t)
  156. }
  157. return success
  158. }
  159. // AssertExpectations asserts that everything specified with On and Return was
  160. // in fact called as expected. Calls may have occurred in any order.
  161. func (m *Mock) AssertExpectations(t *testing.T) bool {
  162. var somethingMissing bool = false
  163. var failedExpectations int = 0
  164. // iterate through each expectation
  165. for _, expectedCall := range m.ExpectedCalls {
  166. if !m.methodWasCalled(expectedCall.Method, expectedCall.Arguments) {
  167. somethingMissing = true
  168. failedExpectations++
  169. t.Logf("\u274C\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String())
  170. } else {
  171. t.Logf("\u2705\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String())
  172. }
  173. }
  174. if somethingMissing {
  175. t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(m.ExpectedCalls)-failedExpectations, len(m.ExpectedCalls), failedExpectations, assert.CallerInfo())
  176. }
  177. return !somethingMissing
  178. }
  179. // AssertNumberOfCalls asserts that the method was called expectedCalls times.
  180. func (m *Mock) AssertNumberOfCalls(t *testing.T, methodName string, expectedCalls int) bool {
  181. var actualCalls int = 0
  182. for _, call := range m.Calls {
  183. if call.Method == methodName {
  184. actualCalls++
  185. }
  186. }
  187. return assert.Equal(t, actualCalls, expectedCalls, fmt.Sprintf("Expected number of calls (%d) does not match the actual number of calls (%d).", expectedCalls, actualCalls))
  188. }
  189. // AssertCalled asserts that the method was called.
  190. func (m *Mock) AssertCalled(t *testing.T, methodName string, arguments ...interface{}) bool {
  191. if !assert.True(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method should have been called with %d argument(s), but was not.", methodName, len(arguments))) {
  192. t.Logf("%s", m.ExpectedCalls)
  193. return false
  194. }
  195. return true
  196. }
  197. // AssertNotCalled asserts that the method was not called.
  198. func (m *Mock) AssertNotCalled(t *testing.T, methodName string, arguments ...interface{}) bool {
  199. if !assert.False(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method was called with %d argument(s), but should NOT have been.", methodName, len(arguments))) {
  200. t.Logf("%s", m.ExpectedCalls)
  201. return false
  202. }
  203. return true
  204. }
  205. func (m *Mock) methodWasCalled(methodName string, arguments []interface{}) bool {
  206. for _, call := range m.Calls {
  207. if call.Method == methodName {
  208. _, differences := call.Arguments.Diff(arguments)
  209. if differences == 0 {
  210. // found the expected call
  211. return true
  212. }
  213. }
  214. }
  215. // we didn't find the expected call
  216. return false
  217. }
  218. /*
  219. Arguments
  220. */
// Arguments holds an array of method arguments or return values.
// Elements are untyped (interface{}); the typed accessors defined on
// Arguments (String, Int, Error, Bool) panic when an element has a
// different type.
type Arguments []interface{}
const (
	// Anything is the "any" argument. It is used in Diff and Assert when
	// the argument being tested shouldn't be taken into consideration:
	// Diff counts a position as matching when either side equals Anything.
	Anything string = "mock.Anything"
)
// AnythingOfTypeArgument is a string that contains the type of an argument
// for use when type checking. Used in Diff and Assert, where it is compared
// against the name reported by reflect for the actual argument's type.
type AnythingOfTypeArgument string
// AnythingOfType returns an AnythingOfTypeArgument object containing the
// name of the type to check for. Used in Diff and Assert.
//
// For example:
//
//	args.Assert(t, AnythingOfType("string"), AnythingOfType("int"))
func AnythingOfType(t string) AnythingOfTypeArgument {
	return AnythingOfTypeArgument(t)
}
  239. // Get Returns the argument at the specified index.
  240. func (args Arguments) Get(index int) interface{} {
  241. if index+1 > len(args) {
  242. panic(fmt.Sprintf("assert: arguments: Cannot call Get(%d) because there are %d argument(s).", index, len(args)))
  243. }
  244. return args[index]
  245. }
  246. // Is gets whether the objects match the arguments specified.
  247. func (args Arguments) Is(objects ...interface{}) bool {
  248. for i, obj := range args {
  249. if obj != objects[i] {
  250. return false
  251. }
  252. }
  253. return true
  254. }
  255. // Diff gets a string describing the differences between the arguments
  256. // and the specified objects.
  257. //
  258. // Returns the diff string and number of differences found.
  259. func (args Arguments) Diff(objects []interface{}) (string, int) {
  260. var output string = "\n"
  261. var differences int
  262. var maxArgCount int = len(args)
  263. if len(objects) > maxArgCount {
  264. maxArgCount = len(objects)
  265. }
  266. for i := 0; i < maxArgCount; i++ {
  267. var actual, expected interface{}
  268. if len(args) <= i {
  269. actual = "(Missing)"
  270. } else {
  271. actual = args[i]
  272. }
  273. if len(objects) <= i {
  274. expected = "(Missing)"
  275. } else {
  276. expected = objects[i]
  277. }
  278. if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() {
  279. // type checking
  280. if reflect.TypeOf(actual).Name() != string(expected.(AnythingOfTypeArgument)) {
  281. // not match
  282. differences++
  283. output = fmt.Sprintf("%s\t%d: \u274C type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actual)
  284. }
  285. } else {
  286. // normal checking
  287. if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
  288. // match
  289. output = fmt.Sprintf("%s\t%d: \u2705 %s == %s\n", output, i, expected, actual)
  290. } else {
  291. // not match
  292. differences++
  293. output = fmt.Sprintf("%s\t%d: \u274C %s != %s\n", output, i, expected, actual)
  294. }
  295. }
  296. }
  297. if differences == 0 {
  298. return "No differences.", differences
  299. }
  300. return output, differences
  301. }
  302. // Assert compares the arguments with the specified objects and fails if
  303. // they do not exactly match.
  304. func (args Arguments) Assert(t *testing.T, objects ...interface{}) bool {
  305. // get the differences
  306. diff, diffCount := args.Diff(objects)
  307. if diffCount == 0 {
  308. return true
  309. }
  310. // there are differences... report them...
  311. t.Logf(diff)
  312. t.Errorf("%sArguments do not match.", assert.CallerInfo())
  313. return false
  314. }
  315. // String gets the argument at the specified index. Panics if there is no argument, or
  316. // if the argument is of the wrong type.
  317. //
  318. // If no index is provided, String() returns a complete string representation
  319. // of the arguments.
  320. func (args Arguments) String(indexOrNil ...int) string {
  321. if len(indexOrNil) == 0 {
  322. // normal String() method - return a string representation of the args
  323. var argsStr []string
  324. for _, arg := range args {
  325. argsStr = append(argsStr, fmt.Sprintf("%s", reflect.TypeOf(arg)))
  326. }
  327. return strings.Join(argsStr, ",")
  328. } else if len(indexOrNil) == 1 {
  329. // Index has been specified - get the argument at that index
  330. var index int = indexOrNil[0]
  331. var s string
  332. var ok bool
  333. if s, ok = args.Get(index).(string); !ok {
  334. panic(fmt.Sprintf("assert: arguments: String(%d) failed because object wasn't correct type: %s", index, args.Get(index)))
  335. }
  336. return s
  337. }
  338. panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String. Must be 0 or 1, not %d", len(indexOrNil)))
  339. }
  340. // Int gets the argument at the specified index. Panics if there is no argument, or
  341. // if the argument is of the wrong type.
  342. func (args Arguments) Int(index int) int {
  343. var s int
  344. var ok bool
  345. if s, ok = args.Get(index).(int); !ok {
  346. panic(fmt.Sprintf("assert: arguments: Int(%d) failed because object wasn't correct type: %s", index, args.Get(index)))
  347. }
  348. return s
  349. }
  350. // Error gets the argument at the specified index. Panics if there is no argument, or
  351. // if the argument is of the wrong type.
  352. func (args Arguments) Error(index int) error {
  353. obj := args.Get(index)
  354. var s error
  355. var ok bool
  356. if obj == nil {
  357. return nil
  358. }
  359. if s, ok = obj.(error); !ok {
  360. panic(fmt.Sprintf("assert: arguments: Error(%d) failed because object wasn't correct type: %s", index, args.Get(index)))
  361. }
  362. return s
  363. }
  364. // Bool gets the argument at the specified index. Panics if there is no argument, or
  365. // if the argument is of the wrong type.
  366. func (args Arguments) Bool(index int) bool {
  367. var s bool
  368. var ok bool
  369. if s, ok = args.Get(index).(bool); !ok {
  370. panic(fmt.Sprintf("assert: arguments: Bool(%d) failed because object wasn't correct type: %s", index, args.Get(index)))
  371. }
  372. return s
  373. }