datadriven.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. // Copyright 2018 The Cockroach Authors.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  12. // implied. See the License for the specific language governing
  13. // permissions and limitations under the License.
  14. package datadriven
  15. import (
  16. "bufio"
  17. "flag"
  18. "fmt"
  19. "io"
  20. "io/ioutil"
  21. "os"
  22. "path/filepath"
  23. "strconv"
  24. "strings"
  25. "testing"
  26. )
  // rewriteTestFiles is bound to the -rewrite command-line flag. When it is
// set, the expected results stored in each test file are not checked;
// instead they are replaced by the actual results produced by the current
// run.
var rewriteTestFiles = flag.Bool(
	"rewrite",
	false,
	"ignore the expected results and rewrite the test files with the actual results from this "+
		"run. Used to update tests when a change affects many cases; please verify the testfile "+
		"diffs carefully!",
)
  35. // RunTest invokes a data-driven test. The test cases are contained in a
  36. // separate test file and are dynamically loaded, parsed, and executed by this
  37. // testing framework. By convention, test files are typically located in a
  38. // sub-directory called "testdata". Each test file has the following format:
  39. //
  40. // <command>[,<command>...] [arg | arg=val | arg=(val1, val2, ...)]...
  41. // <input to the command>
  42. // ----
  43. // <expected results>
  44. //
  45. // The command input can contain blank lines. However, by default, the expected
  46. // results cannot contain blank lines. This alternate syntax allows the use of
  47. // blank lines:
  48. //
  49. // <command>[,<command>...] [arg | arg=val | arg=(val1, val2, ...)]...
  50. // <input to the command>
  51. // ----
  52. // ----
  53. // <expected results>
  54. //
  55. // <more expected results>
  56. // ----
  57. // ----
  58. //
  59. // To execute data-driven tests, pass the path of the test file as well as a
  60. // function which can interpret and execute whatever commands are present in
  61. // the test file. The framework invokes the function, passing it information
  62. // about the test case in a TestData struct. The function then returns the
  63. // actual results of the case, which this function compares with the expected
  64. // results, and either succeeds or fails the test.
  65. func RunTest(t *testing.T, path string, f func(d *TestData) string) {
  66. t.Helper()
  67. file, err := os.OpenFile(path, os.O_RDWR, 0644 /* irrelevant */)
  68. if err != nil {
  69. t.Fatal(err)
  70. }
  71. defer func() {
  72. _ = file.Close()
  73. }()
  74. runTestInternal(t, path, file, f, *rewriteTestFiles)
  75. }
  76. // RunTestFromString is a version of RunTest which takes the contents of a test
  77. // directly.
  78. func RunTestFromString(t *testing.T, input string, f func(d *TestData) string) {
  79. t.Helper()
  80. runTestInternal(t, "<string>" /* optionalPath */, strings.NewReader(input), f, *rewriteTestFiles)
  81. }
  82. func runTestInternal(
  83. t *testing.T, sourceName string, reader io.Reader, f func(d *TestData) string, rewrite bool,
  84. ) {
  85. t.Helper()
  86. r := newTestDataReader(t, sourceName, reader, rewrite)
  87. for r.Next(t) {
  88. t.Run("", func(t *testing.T) {
  89. d := &r.data
  90. actual := func() string {
  91. defer func() {
  92. if r := recover(); r != nil {
  93. fmt.Printf("\npanic during %s:\n%s\n", d.Pos, d.Input)
  94. panic(r)
  95. }
  96. }()
  97. actual := f(d)
  98. if !strings.HasSuffix(actual, "\n") {
  99. actual += "\n"
  100. }
  101. return actual
  102. }()
  103. if r.rewrite != nil {
  104. r.emit("----")
  105. if hasBlankLine(actual) {
  106. r.emit("----")
  107. r.rewrite.WriteString(actual)
  108. r.emit("----")
  109. r.emit("----")
  110. } else {
  111. r.emit(actual)
  112. }
  113. } else if d.Expected != actual {
  114. t.Fatalf("\n%s: %s\nexpected:\n%s\nfound:\n%s", d.Pos, d.Input, d.Expected, actual)
  115. } else if testing.Verbose() {
  116. input := d.Input
  117. if input == "" {
  118. input = "<no input to command>"
  119. }
  120. // TODO(tbg): it's awkward to reproduce the args, but it would be helpful.
  121. fmt.Printf("\n%s:\n%s [%d args]\n%s\n----\n%s", d.Pos, d.Cmd, len(d.CmdArgs), input, actual)
  122. }
  123. })
  124. if t.Failed() {
  125. t.FailNow()
  126. }
  127. }
  128. if r.rewrite != nil {
  129. data := r.rewrite.Bytes()
  130. if l := len(data); l > 2 && data[l-1] == '\n' && data[l-2] == '\n' {
  131. data = data[:l-1]
  132. }
  133. if dest, ok := reader.(*os.File); ok {
  134. if _, err := dest.WriteAt(data, 0); err != nil {
  135. t.Fatal(err)
  136. }
  137. if err := dest.Truncate(int64(len(data))); err != nil {
  138. t.Fatal(err)
  139. }
  140. if err := dest.Sync(); err != nil {
  141. t.Fatal(err)
  142. }
  143. } else {
  144. t.Logf("input is not a file; rewritten output is:\n%s", data)
  145. }
  146. }
  147. }
  // Walk goes through all the files in a subdirectory, creating subtests to match
// the file hierarchy; for each "leaf" file, the given function is called.
// Directory entries are visited in lexical order, as returned by
// ioutil.ReadDir.
//
// This can be used in conjunction with RunTest. For example:
//
//	datadriven.Walk(t, path, func(t *testing.T, path string) {
//		// initialize per-test state
//		datadriven.RunTest(t, path, func(d *datadriven.TestData) string {
//			// ...
//		})
//	})
//
// Files:
//
//	testdata/typing
//	testdata/logprops/scan
//	testdata/logprops/select
//
// If path is "testdata/typing", the function is called once and no subtests
// are created.
//
// If path is "testdata/logprops", the function is called two times, in
// separate subtests /scan, /select.
//
// If path is "testdata", the function is called three times, in subtest
// hierarchy /typing, /logprops/scan, /logprops/select.
func Walk(t *testing.T, path string, f func(t *testing.T, path string)) {
	t.Helper()
	finfo, err := os.Stat(path)
	if err != nil {
		t.Fatal(err)
	}
	if !finfo.IsDir() {
		f(t, path)
		return
	}
	files, err := ioutil.ReadDir(path)
	if err != nil {
		t.Fatal(err)
	}
	for _, file := range files {
		// Bind the name per iteration so the subtest closure never depends on
		// the shared loop variable.
		name := file.Name()
		t.Run(name, func(t *testing.T) {
			Walk(t, filepath.Join(path, name), f)
		})
	}
}
  193. // TestData contains information about one data-driven test case that was
  194. // parsed from the test file.
  195. type TestData struct {
  196. Pos string // reader and line number
  197. // Cmd is the first string on the directive line (up to the first whitespace).
  198. Cmd string
  199. CmdArgs []CmdArg
  200. Input string
  201. Expected string
  202. }
  203. // ScanArgs looks up the first CmdArg matching the given key and scans it into
  204. // the given destinations in order. If the arg does not exist, the number of
  205. // destinations does not match that of the arguments, or a destination can not
  206. // be populated from its matching value, a fatal error results.
  207. //
  208. // For example, for a TestData originating from
  209. //
  210. // cmd arg1=50 arg2=yoruba arg3=(50, 50, 50)
  211. //
  212. // the following would be valid:
  213. //
  214. // var i1, i2, i3, i4 int
  215. // var s string
  216. // td.ScanArgs(t, "arg1", &i1)
  217. // td.ScanArgs(t, "arg2", &s)
  218. // td.ScanArgs(t, "arg3", &i2, &i3, &i4)
  219. func (td *TestData) ScanArgs(t *testing.T, key string, dests ...interface{}) {
  220. t.Helper()
  221. var arg CmdArg
  222. for i := range td.CmdArgs {
  223. if td.CmdArgs[i].Key == key {
  224. arg = td.CmdArgs[i]
  225. break
  226. }
  227. }
  228. if arg.Key == "" {
  229. t.Fatalf("missing argument: %s", key)
  230. }
  231. if len(dests) != len(arg.Vals) {
  232. t.Fatalf("%s: got %d destinations, but %d values", arg.Key, len(dests), len(arg.Vals))
  233. }
  234. for i := range dests {
  235. arg.Scan(t, i, dests[i])
  236. }
  237. }
  // CmdArg contains information about an argument on the directive line. An
// argument is specified in one of the following forms:
//   - argument
//   - argument=value
//   - argument=(values, ...)
type CmdArg struct {
	// Key is the argument name.
	Key string
	// Vals holds the argument's values, if any.
	Vals []string
}

// String formats the argument the way it appears on a directive line: "key"
// with no values, "key=value" with one value, and "key=(v1, v2, ...)" with
// several.
func (arg CmdArg) String() string {
	switch len(arg.Vals) {
	case 0:
		return arg.Key
	case 1:
		return fmt.Sprintf("%s=%s", arg.Key, arg.Vals[0])
	default:
		return fmt.Sprintf("%s=(%s)", arg.Key, strings.Join(arg.Vals, ", "))
	}
}

// Scan parses the value at index i of arg.Vals into dest. Supported
// destination types are *string, *int, *int64, *uint64, *float64 and *bool.
// The test fails if i is out of range, if dest has an unsupported type, or if
// the value cannot be parsed into dest.
func (arg CmdArg) Scan(t *testing.T, i int, dest interface{}) {
	t.Helper()
	if i < 0 || i >= len(arg.Vals) {
		t.Fatalf("cannot scan index %d of key %s", i, arg.Key)
	}
	val := arg.Vals[i]
	switch dest := dest.(type) {
	case *string:
		*dest = val
	case *int:
		// Atoi parses at the platform's native int width, so out-of-range
		// values fail loudly instead of being silently truncated on 32-bit
		// targets.
		n, err := strconv.Atoi(val)
		if err != nil {
			t.Fatalf("%s: %v", arg.Key, err)
		}
		*dest = n
	case *int64:
		n, err := strconv.ParseInt(val, 10, 64)
		if err != nil {
			t.Fatalf("%s: %v", arg.Key, err)
		}
		*dest = n
	case *uint64:
		n, err := strconv.ParseUint(val, 10, 64)
		if err != nil {
			t.Fatalf("%s: %v", arg.Key, err)
		}
		*dest = n
	case *float64:
		x, err := strconv.ParseFloat(val, 64)
		if err != nil {
			t.Fatalf("%s: %v", arg.Key, err)
		}
		*dest = x
	case *bool:
		b, err := strconv.ParseBool(val)
		if err != nil {
			t.Fatalf("%s: %v", arg.Key, err)
		}
		*dest = b
	default:
		t.Fatalf("unsupported type %T for destination #%d (might be easy to add it)", dest, i+1)
	}
}
  288. // Fatalf wraps a fatal testing error with test file position information, so
  289. // that it's easy to locate the source of the error.
  290. func (td TestData) Fatalf(tb testing.TB, format string, args ...interface{}) {
  291. tb.Helper()
  292. tb.Fatalf("%s: %s", td.Pos, fmt.Sprintf(format, args...))
  293. }
  // hasBlankLine reports whether s contains a line that is empty or consists
// solely of whitespace. A trailing newline at the end of s does not by itself
// count as a blank line.
func hasBlankLine(s string) bool {
	scanner := bufio.NewScanner(strings.NewReader(s))
	// By default the scanner gives up (with bufio.ErrTooLong) on lines longer
	// than 64 KiB, which would silently hide any blank line that follows.
	// No line can be longer than s itself, so allow tokens up to that size.
	scanner.Buffer(nil, len(s)+1)
	for scanner.Scan() {
		if strings.TrimSpace(scanner.Text()) == "" {
			return true
		}
	}
	return false
}