datadriven.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. // Copyright 2018 The Cockroach Authors.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  12. // implied. See the License for the specific language governing
  13. // permissions and limitations under the License.
  14. package datadriven
  15. import (
  16. "bufio"
  17. "flag"
  18. "fmt"
  19. "io"
  20. "io/ioutil"
  21. "os"
  22. "path/filepath"
  23. "strconv"
  24. "strings"
  25. "testing"
  26. )
  // rewriteTestFiles is bound to the -rewrite command-line flag. When set, the
  // expected results recorded in test files are ignored and the files are
  // overwritten with the results actually produced by the run.
  var rewriteTestFiles = flag.Bool("rewrite", false,
  	"ignore the expected results and rewrite the test files with the actual results from this "+
  		"run. Used to update tests when a change affects many cases; please verify the testfile "+
  		"diffs carefully!")
  35. // RunTest invokes a data-driven test. The test cases are contained in a
  36. // separate test file and are dynamically loaded, parsed, and executed by this
  37. // testing framework. By convention, test files are typically located in a
  38. // sub-directory called "testdata". Each test file has the following format:
  39. //
  40. // <command>[,<command>...] [arg | arg=val | arg=(val1, val2, ...)]...
  41. // <input to the command>
  42. // ----
  43. // <expected results>
  44. //
  45. // The command input can contain blank lines. However, by default, the expected
  46. // results cannot contain blank lines. This alternate syntax allows the use of
  47. // blank lines:
  48. //
  49. // <command>[,<command>...] [arg | arg=val | arg=(val1, val2, ...)]...
  50. // <input to the command>
  51. // ----
  52. // ----
  53. // <expected results>
  54. //
  55. // <more expected results>
  56. // ----
  57. // ----
  58. //
  59. // To execute data-driven tests, pass the path of the test file as well as a
  60. // function which can interpret and execute whatever commands are present in
  61. // the test file. The framework invokes the function, passing it information
  62. // about the test case in a TestData struct. The function then returns the
  63. // actual results of the case, which this function compares with the expected
  64. // results, and either succeeds or fails the test.
  65. func RunTest(t *testing.T, path string, f func(d *TestData) string) {
  66. t.Helper()
  67. file, err := os.OpenFile(path, os.O_RDWR, 0644 /* irrelevant */)
  68. if err != nil {
  69. t.Fatal(err)
  70. }
  71. defer func() {
  72. _ = file.Close()
  73. }()
  74. runTestInternal(t, path, file, f, *rewriteTestFiles)
  75. }
  76. // RunTestFromString is a version of RunTest which takes the contents of a test
  77. // directly.
  78. func RunTestFromString(t *testing.T, input string, f func(d *TestData) string) {
  79. t.Helper()
  80. runTestInternal(t, "<string>" /* optionalPath */, strings.NewReader(input), f, *rewriteTestFiles)
  81. }
  82. func runTestInternal(
  83. t *testing.T, sourceName string, reader io.Reader, f func(d *TestData) string, rewrite bool,
  84. ) {
  85. t.Helper()
  86. r := newTestDataReader(t, sourceName, reader, rewrite)
  87. for r.Next(t) {
  88. d := &r.data
  89. actual := func() string {
  90. defer func() {
  91. if r := recover(); r != nil {
  92. fmt.Printf("\npanic during %s:\n%s\n", d.Pos, d.Input)
  93. panic(r)
  94. }
  95. }()
  96. return f(d)
  97. }()
  98. if r.rewrite != nil {
  99. r.emit("----")
  100. if hasBlankLine(actual) {
  101. r.emit("----")
  102. r.rewrite.WriteString(actual)
  103. r.emit("----")
  104. r.emit("----")
  105. } else {
  106. r.emit(actual)
  107. }
  108. } else if d.Expected != actual {
  109. t.Fatalf("\n%s: %s\nexpected:\n%s\nfound:\n%s", d.Pos, d.Input, d.Expected, actual)
  110. } else if testing.Verbose() {
  111. input := d.Input
  112. if input == "" {
  113. input = "<no input to command>"
  114. }
  115. // TODO(tbg): it's awkward to reproduce the args, but it would be helpful.
  116. fmt.Printf("\n%s:\n%s [%d args]\n%s\n----\n%s", d.Pos, d.Cmd, len(d.CmdArgs), input, actual)
  117. }
  118. }
  119. if r.rewrite != nil {
  120. data := r.rewrite.Bytes()
  121. if l := len(data); l > 2 && data[l-1] == '\n' && data[l-2] == '\n' {
  122. data = data[:l-1]
  123. }
  124. if dest, ok := reader.(*os.File); ok {
  125. if _, err := dest.WriteAt(data, 0); err != nil {
  126. t.Fatal(err)
  127. }
  128. if err := dest.Truncate(int64(len(data))); err != nil {
  129. t.Fatal(err)
  130. }
  131. if err := dest.Sync(); err != nil {
  132. t.Fatal(err)
  133. }
  134. } else {
  135. t.Logf("input is not a file; rewritten output is:\n%s", data)
  136. }
  137. }
  138. }
  // Walk goes through all the files in a subdirectory, creating subtests to match
  // the file hierarchy; for each "leaf" file, the given function is called.
  //
  // This can be used in conjunction with RunTest. For example:
  //
  //    datadriven.Walk(t, path, func(t *testing.T, path string) {
  //        // initialize per-test state
  //        datadriven.RunTest(t, path, func(d *datadriven.TestData) string {
  //            // ...
  //        })
  //    })
  //
  //    Files:
  //      testdata/typing
  //      testdata/logprops/scan
  //      testdata/logprops/select
  //
  // If path is "testdata/typing", the function is called once and no subtests
  // are created.
  //
  // If path is "testdata/logprops", the function is called two times, in
  // separate subtests /scan, /select.
  //
  // If path is "testdata", the function is called three times, in subtest
  // hierarchy /typing, /logprops/scan, /logprops/select.
  //
  // Directory entries are visited in the order returned by ioutil.ReadDir,
  // i.e. sorted by file name.
  func Walk(t *testing.T, path string, f func(t *testing.T, path string)) {
  	t.Helper()
  	finfo, err := os.Stat(path)
  	if err != nil {
  		t.Fatal(err)
  	}
  	if !finfo.IsDir() {
  		f(t, path)
  		return
  	}
  	files, err := ioutil.ReadDir(path)
  	if err != nil {
  		t.Fatal(err)
  	}
  	for _, file := range files {
  		// Bind the name per iteration so the closure never depends on the
  		// shared loop variable.
  		name := file.Name()
  		t.Run(name, func(t *testing.T) {
  			Walk(t, filepath.Join(path, name), f)
  		})
  	}
  }
  184. // TestData contains information about one data-driven test case that was
  185. // parsed from the test file.
  186. type TestData struct {
  187. Pos string // reader and line number
  188. // Cmd is the first string on the directive line (up to the first whitespace).
  189. Cmd string
  190. CmdArgs []CmdArg
  191. Input string
  192. Expected string
  193. }
  194. // ScanArgs looks up the first CmdArg matching the given key and scans it into
  195. // the given destinations in order. If the arg does not exist, the number of
  196. // destinations does not match that of the arguments, or a destination can not
  197. // be populated from its matching value, a fatal error results.
  198. //
  199. // For example, for a TestData originating from
  200. //
  201. // cmd arg1=50 arg2=yoruba arg3=(50, 50, 50)
  202. //
  203. // the following would be valid:
  204. //
  205. // var i1, i2, i3, i4 int
  206. // var s string
  207. // td.ScanArgs(t, "arg1", &i1)
  208. // td.ScanArgs(t, "arg2", &s)
  209. // td.ScanArgs(t, "arg3", &i2, &i3, &i4)
  210. func (td *TestData) ScanArgs(t *testing.T, key string, dests ...interface{}) {
  211. t.Helper()
  212. var arg CmdArg
  213. for i := range td.CmdArgs {
  214. if td.CmdArgs[i].Key == key {
  215. arg = td.CmdArgs[i]
  216. break
  217. }
  218. }
  219. if arg.Key == "" {
  220. t.Fatalf("missing argument: %s", key)
  221. }
  222. if len(dests) != len(arg.Vals) {
  223. t.Fatalf("%s: got %d destinations, but %d values", arg.Key, len(dests), len(arg.Vals))
  224. }
  225. for i := range dests {
  226. arg.Scan(t, i, dests[i])
  227. }
  228. }
  // CmdArg describes one argument on a directive line. Arguments take one of
  // three forms:
  //   - argument
  //   - argument=value
  //   - argument=(values, ...)
  type CmdArg struct {
  	Key  string
  	Vals []string
  }

  // String renders the argument back in its directive-line form: a bare key,
  // key=value, or key=(v1, v2, ...).
  func (arg CmdArg) String() string {
  	var rendered string
  	switch n := len(arg.Vals); {
  	case n == 0:
  		return arg.Key
  	case n == 1:
  		rendered = arg.Vals[0]
  	default:
  		rendered = "(" + strings.Join(arg.Vals, ", ") + ")"
  	}
  	return fmt.Sprintf("%s=%s", arg.Key, rendered)
  }
  248. // Scan attempts to parse the value at index i into the dest.
  249. func (arg CmdArg) Scan(t *testing.T, i int, dest interface{}) {
  250. if i < 0 || i >= len(arg.Vals) {
  251. t.Fatalf("cannot scan index %d of key %s", i, arg.Key)
  252. }
  253. val := arg.Vals[i]
  254. switch dest := dest.(type) {
  255. case *string:
  256. *dest = val
  257. case *int:
  258. n, err := strconv.ParseInt(val, 10, 64)
  259. if err != nil {
  260. t.Fatal(err)
  261. }
  262. *dest = int(n) // assume 64bit ints
  263. case *uint64:
  264. n, err := strconv.ParseUint(val, 10, 64)
  265. if err != nil {
  266. t.Fatal(err)
  267. }
  268. *dest = n
  269. case *bool:
  270. b, err := strconv.ParseBool(val)
  271. if err != nil {
  272. t.Fatal(err)
  273. }
  274. *dest = b
  275. default:
  276. t.Fatalf("unsupported type %T for destination #%d (might be easy to add it)", dest, i+1)
  277. }
  278. }
  279. // Fatalf wraps a fatal testing error with test file position information, so
  280. // that it's easy to locate the source of the error.
  281. func (td TestData) Fatalf(tb testing.TB, format string, args ...interface{}) {
  282. tb.Helper()
  283. tb.Fatalf("%s: %s", td.Pos, fmt.Sprintf(format, args...))
  284. }
  // hasBlankLine reports whether s contains a line that is empty or consists
  // only of whitespace. A single trailing newline does not by itself create a
  // blank line: "a\n" has none, while "a\n\n" does.
  func hasBlankLine(s string) bool {
  	scanner := bufio.NewScanner(strings.NewReader(s))
  	// The default maximum token size is 64 KiB. A longer line would make Scan
  	// stop early with ErrTooLong, and we would wrongly report false. No line
  	// can exceed len(s), so this limit is always sufficient.
  	scanner.Buffer(nil, len(s)+1)
  	for scanner.Scan() {
  		if strings.TrimSpace(scanner.Text()) == "" {
  			return true
  		}
  	}
  	return false
  }