mksyscall_windows.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696
  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // +build ignore
  5. /*
  6. mksyscall_windows generates windows system call bodies
  7. It parses all files specified on command line containing function
  8. prototypes (like syscall_windows.go) and prints system call bodies
  9. to standard output.
  10. The prototypes are marked by lines beginning with "//sys" and read
  11. like func declarations if //sys is replaced by func, but:
  12. * The parameter lists must give a name for each argument. This
  13. includes return parameters.
  14. * The parameter lists must give a type for each argument:
  15. the (x, y, z int) shorthand is not allowed.
  16. * If the return parameter is an error number, it must be named err.
  17. * If go func name needs to be different from its winapi dll name,
  18. the winapi name could be specified at the end, after "=" sign, like
  19. //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA
  20. * Each function that returns err needs to supply a condition, that
  21. return value of winapi will be tested against to detect failure.
  22. This would set err to windows "last-error", otherwise it will be nil.
  23. The value can be provided at end of //sys declaration, like
  24. //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA
  25. and is [failretval==0] by default.
  26. Usage:
  27. mksyscall_windows [flags] [path ...]
  28. The flags are:
  29. -trace
  30. Generate print statement after every syscall.
  31. */
  32. package main
  33. import (
  34. "bufio"
  35. "errors"
  36. "flag"
  37. "fmt"
  38. "go/parser"
  39. "go/token"
  40. "io"
  41. "log"
  42. "os"
  43. "strconv"
  44. "strings"
  45. "text/template"
  46. )
// PrintTraceFlag is the -trace command line flag. newFn copies its value
// into Fn.PrintTrace, and the "printtrace" template emits a print statement
// after the syscall when that field is set.
var PrintTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall")
  48. func trim(s string) string {
  49. return strings.Trim(s, " \t")
  50. }
// packageName is the package name of the most recently parsed input file.
// ParseFile assigns it from the file's package clause.
var packageName string

// packagename returns packageName. Generate registers it as the
// "packagename" template function.
func packagename() string {
	return packageName
}
  55. func windowsdot() string {
  56. if packageName == "windows" {
  57. return ""
  58. }
  59. return "windows."
  60. }
// Param is function parameter
type Param struct {
	// Name is the parameter name taken from the //sys prototype.
	Name string
	// Type is the parameter's Go type, kept as source text (e.g. "string",
	// "bool", "[]byte", "*uint16").
	Type string
	// fn is the function this parameter belongs to; tmpVar uses it to
	// allocate a temp variable index.
	fn *Fn
	// tmpVarIdx is the number used in this parameter's "_pN" temp variable
	// name. extractParams initializes it to -1, and tmpVar assigns a real
	// index on first use.
	tmpVarIdx int
}
  68. // tmpVar returns temp variable name that will be used to represent p during syscall.
  69. func (p *Param) tmpVar() string {
  70. if p.tmpVarIdx < 0 {
  71. p.tmpVarIdx = p.fn.curTmpVarIdx
  72. p.fn.curTmpVarIdx++
  73. }
  74. return fmt.Sprintf("_p%d", p.tmpVarIdx)
  75. }
  76. // BoolTmpVarCode returns source code for bool temp variable.
  77. func (p *Param) BoolTmpVarCode() string {
  78. const code = `var %s uint32
  79. if %s {
  80. %s = 1
  81. } else {
  82. %s = 0
  83. }`
  84. tmp := p.tmpVar()
  85. return fmt.Sprintf(code, tmp, p.Name, tmp, tmp)
  86. }
  87. // SliceTmpVarCode returns source code for slice temp variable.
  88. func (p *Param) SliceTmpVarCode() string {
  89. const code = `var %s *%s
  90. if len(%s) > 0 {
  91. %s = &%s[0]
  92. }`
  93. tmp := p.tmpVar()
  94. return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name)
  95. }
  96. // StringTmpVarCode returns source code for string temp variable.
  97. func (p *Param) StringTmpVarCode() string {
  98. errvar := p.fn.Rets.ErrorVarName()
  99. if errvar == "" {
  100. errvar = "_"
  101. }
  102. tmp := p.tmpVar()
  103. const code = `var %s %s
  104. %s, %s = %s(%s)`
  105. s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name)
  106. if errvar == "-" {
  107. return s
  108. }
  109. const morecode = `
  110. if %s != nil {
  111. return
  112. }`
  113. return s + fmt.Sprintf(morecode, errvar)
  114. }
  115. // TmpVarCode returns source code for temp variable.
  116. func (p *Param) TmpVarCode() string {
  117. switch {
  118. case p.Type == "string":
  119. return p.StringTmpVarCode()
  120. case p.Type == "bool":
  121. return p.BoolTmpVarCode()
  122. case strings.HasPrefix(p.Type, "[]"):
  123. return p.SliceTmpVarCode()
  124. default:
  125. return ""
  126. }
  127. }
  128. // SyscallArgList returns source code fragments representing p parameter
  129. // in syscall. Slices are translated into 2 syscall parameters: pointer to
  130. // the first element and length.
  131. func (p *Param) SyscallArgList() []string {
  132. var s string
  133. switch {
  134. case p.Type[0] == '*':
  135. s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name)
  136. case p.Type == "string":
  137. s = fmt.Sprintf("unsafe.Pointer(%s)", p.tmpVar())
  138. case p.Type == "bool":
  139. s = p.tmpVar()
  140. case strings.HasPrefix(p.Type, "[]"):
  141. return []string{
  142. fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()),
  143. fmt.Sprintf("uintptr(len(%s))", p.Name),
  144. }
  145. default:
  146. s = p.Name
  147. }
  148. return []string{fmt.Sprintf("uintptr(%s)", s)}
  149. }
  150. // IsError determines if p parameter is used to return error.
  151. func (p *Param) IsError() bool {
  152. return p.Name == "err" && p.Type == "error"
  153. }
  154. // join concatenates parameters ps into a string with sep separator.
  155. // Each parameter is converted into string by applying fn to it
  156. // before conversion.
  157. func join(ps []*Param, fn func(*Param) string, sep string) string {
  158. if len(ps) == 0 {
  159. return ""
  160. }
  161. a := make([]string, 0)
  162. for _, p := range ps {
  163. a = append(a, fn(p))
  164. }
  165. return strings.Join(a, sep)
  166. }
// Rets describes function return parameters.
type Rets struct {
	// Name and Type describe the first return value when it is not the
	// "err error" parameter; both are empty when there is no such value.
	Name string
	Type string
	// ReturnsError is set when the prototype declares an "err error"
	// return parameter.
	ReturnsError bool
	// FailCond is the text found between "[" and "]" in the prototype. It
	// may mention "failretval", which is replaced with the actual return
	// variable. When empty, the condition "<retvar> == 0" is used.
	FailCond string
}
  174. // ErrorVarName returns error variable name for r.
  175. func (r *Rets) ErrorVarName() string {
  176. if r.ReturnsError {
  177. return "err"
  178. }
  179. if r.Type == "error" {
  180. return r.Name
  181. }
  182. return ""
  183. }
  184. // ToParams converts r into slice of *Param.
  185. func (r *Rets) ToParams() []*Param {
  186. ps := make([]*Param, 0)
  187. if len(r.Name) > 0 {
  188. ps = append(ps, &Param{Name: r.Name, Type: r.Type})
  189. }
  190. if r.ReturnsError {
  191. ps = append(ps, &Param{Name: "err", Type: "error"})
  192. }
  193. return ps
  194. }
  195. // List returns source code of syscall return parameters.
  196. func (r *Rets) List() string {
  197. s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ")
  198. if len(s) > 0 {
  199. s = "(" + s + ")"
  200. }
  201. return s
  202. }
  203. // PrintList returns source code of trace printing part correspondent
  204. // to syscall return values.
  205. func (r *Rets) PrintList() string {
  206. return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
  207. }
  208. // SetReturnValuesCode returns source code that accepts syscall return values.
  209. func (r *Rets) SetReturnValuesCode() string {
  210. if r.Name == "" && !r.ReturnsError {
  211. return ""
  212. }
  213. retvar := "r0"
  214. if r.Name == "" {
  215. retvar = "r1"
  216. }
  217. errvar := "_"
  218. if r.ReturnsError {
  219. errvar = "e1"
  220. }
  221. return fmt.Sprintf("%s, _, %s := ", retvar, errvar)
  222. }
  223. func (r *Rets) useLongHandleErrorCode(retvar string) string {
  224. const code = `if %s {
  225. if e1 != 0 {
  226. err = error(e1)
  227. } else {
  228. err = %sEINVAL
  229. }
  230. }`
  231. cond := retvar + " == 0"
  232. if r.FailCond != "" {
  233. cond = strings.Replace(r.FailCond, "failretval", retvar, 1)
  234. }
  235. return fmt.Sprintf(code, cond, windowsdot())
  236. }
  237. // SetErrorCode returns source code that sets return parameters.
  238. func (r *Rets) SetErrorCode() string {
  239. const code = `if r0 != 0 {
  240. %s = syscall.Errno(r0)
  241. }`
  242. if r.Name == "" && !r.ReturnsError {
  243. return ""
  244. }
  245. if r.Name == "" {
  246. return r.useLongHandleErrorCode("r1")
  247. }
  248. if r.Type == "error" {
  249. return fmt.Sprintf(code, r.Name)
  250. }
  251. s := ""
  252. switch {
  253. case r.Type[0] == '*':
  254. s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type)
  255. case r.Type == "bool":
  256. s = fmt.Sprintf("%s = r0 != 0", r.Name)
  257. default:
  258. s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type)
  259. }
  260. if !r.ReturnsError {
  261. return s
  262. }
  263. return s + "\n\t" + r.useLongHandleErrorCode(r.Name)
  264. }
// Fn describes a syscall function.
type Fn struct {
	// Name is the Go function name taken from the //sys prototype.
	Name string
	// Params are the input parameters in declaration order.
	Params []*Param
	// Rets describes the return parameters; newFn always sets it non-nil.
	Rets *Rets
	// PrintTrace, copied from the -trace flag by newFn, makes the generated
	// body print a trace line.
	PrintTrace bool
	// dllname and dllfuncname come from the optional "= dll.func" suffix of
	// the prototype; see DLLName and DLLFuncName for the defaults used when
	// they are empty.
	dllname     string
	dllfuncname string
	// src is the trimmed prototype text, kept for error messages.
	src string
	// TODO: get rid of this field and just use parameter index instead
	curTmpVarIdx int // next free temp variable index; keeps temp variable names unique
}
  277. // extractParams parses s to extract function parameters.
  278. func extractParams(s string, f *Fn) ([]*Param, error) {
  279. s = trim(s)
  280. if s == "" {
  281. return nil, nil
  282. }
  283. a := strings.Split(s, ",")
  284. ps := make([]*Param, len(a))
  285. for i := range ps {
  286. s2 := trim(a[i])
  287. b := strings.Split(s2, " ")
  288. if len(b) != 2 {
  289. b = strings.Split(s2, "\t")
  290. if len(b) != 2 {
  291. return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"")
  292. }
  293. }
  294. ps[i] = &Param{
  295. Name: trim(b[0]),
  296. Type: trim(b[1]),
  297. fn: f,
  298. tmpVarIdx: -1,
  299. }
  300. }
  301. return ps, nil
  302. }
  303. // extractSection extracts text out of string s starting after start
  304. // and ending just before end. found return value will indicate success,
  305. // and prefix, body and suffix will contain correspondent parts of string s.
  306. func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) {
  307. s = trim(s)
  308. if strings.HasPrefix(s, string(start)) {
  309. // no prefix
  310. body = s[1:]
  311. } else {
  312. a := strings.SplitN(s, string(start), 2)
  313. if len(a) != 2 {
  314. return "", "", s, false
  315. }
  316. prefix = a[0]
  317. body = a[1]
  318. }
  319. a := strings.SplitN(body, string(end), 2)
  320. if len(a) != 2 {
  321. return "", "", "", false
  322. }
  323. return prefix, a[0], a[1], true
  324. }
  325. // newFn parses string s and return created function Fn.
  326. func newFn(s string) (*Fn, error) {
  327. s = trim(s)
  328. f := &Fn{
  329. Rets: &Rets{},
  330. src: s,
  331. PrintTrace: *PrintTraceFlag,
  332. }
  333. // function name and args
  334. prefix, body, s, found := extractSection(s, '(', ')')
  335. if !found || prefix == "" {
  336. return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"")
  337. }
  338. f.Name = prefix
  339. var err error
  340. f.Params, err = extractParams(body, f)
  341. if err != nil {
  342. return nil, err
  343. }
  344. // return values
  345. _, body, s, found = extractSection(s, '(', ')')
  346. if found {
  347. r, err := extractParams(body, f)
  348. if err != nil {
  349. return nil, err
  350. }
  351. switch len(r) {
  352. case 0:
  353. case 1:
  354. if r[0].IsError() {
  355. f.Rets.ReturnsError = true
  356. } else {
  357. f.Rets.Name = r[0].Name
  358. f.Rets.Type = r[0].Type
  359. }
  360. case 2:
  361. if !r[1].IsError() {
  362. return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"")
  363. }
  364. f.Rets.ReturnsError = true
  365. f.Rets.Name = r[0].Name
  366. f.Rets.Type = r[0].Type
  367. default:
  368. return nil, errors.New("Too many return values in \"" + f.src + "\"")
  369. }
  370. }
  371. // fail condition
  372. _, body, s, found = extractSection(s, '[', ']')
  373. if found {
  374. f.Rets.FailCond = body
  375. }
  376. // dll and dll function names
  377. s = trim(s)
  378. if s == "" {
  379. return f, nil
  380. }
  381. if !strings.HasPrefix(s, "=") {
  382. return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
  383. }
  384. s = trim(s[1:])
  385. a := strings.Split(s, ".")
  386. switch len(a) {
  387. case 1:
  388. f.dllfuncname = a[0]
  389. case 2:
  390. f.dllname = a[0]
  391. f.dllfuncname = a[1]
  392. default:
  393. return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
  394. }
  395. return f, nil
  396. }
  397. // DLLName returns DLL name for function f.
  398. func (f *Fn) DLLName() string {
  399. if f.dllname == "" {
  400. return "kernel32"
  401. }
  402. return f.dllname
  403. }
  404. // DLLName returns DLL function name for function f.
  405. func (f *Fn) DLLFuncName() string {
  406. if f.dllfuncname == "" {
  407. return f.Name
  408. }
  409. return f.dllfuncname
  410. }
  411. // ParamList returns source code for function f parameters.
  412. func (f *Fn) ParamList() string {
  413. return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ")
  414. }
  415. // ParamPrintList returns source code of trace printing part correspondent
  416. // to syscall input parameters.
  417. func (f *Fn) ParamPrintList() string {
  418. return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
  419. }
  420. // ParamCount return number of syscall parameters for function f.
  421. func (f *Fn) ParamCount() int {
  422. n := 0
  423. for _, p := range f.Params {
  424. n += len(p.SyscallArgList())
  425. }
  426. return n
  427. }
  428. // SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/...
  429. // to use. It returns parameter count for correspondent SyscallX function.
  430. func (f *Fn) SyscallParamCount() int {
  431. n := f.ParamCount()
  432. switch {
  433. case n <= 3:
  434. return 3
  435. case n <= 6:
  436. return 6
  437. case n <= 9:
  438. return 9
  439. case n <= 12:
  440. return 12
  441. case n <= 15:
  442. return 15
  443. default:
  444. panic("too many arguments to system call")
  445. }
  446. }
  447. // Syscall determines which SyscallX function to use for function f.
  448. func (f *Fn) Syscall() string {
  449. c := f.SyscallParamCount()
  450. return "syscall.Syscall" + strconv.Itoa(c)
  451. }
  452. // SyscallParamList returns source code for SyscallX parameters for function f.
  453. func (f *Fn) SyscallParamList() string {
  454. a := make([]string, 0)
  455. for _, p := range f.Params {
  456. a = append(a, p.SyscallArgList()...)
  457. }
  458. for len(a) < f.SyscallParamCount() {
  459. a = append(a, "0")
  460. }
  461. return strings.Join(a, ", ")
  462. }
  463. // IsUTF16 is true, if f is W (utf16) function. It is false
  464. // for all A (ascii) functions.
  465. func (f *Fn) IsUTF16() bool {
  466. s := f.DLLFuncName()
  467. return s[len(s)-1] == 'W'
  468. }
  469. // StrconvFunc returns name of Go string to OS string function for f.
  470. func (f *Fn) StrconvFunc() string {
  471. if f.IsUTF16() {
  472. return windowsdot() + "UTF16PtrFromString"
  473. }
  474. return windowsdot() + "BytePtrFromString"
  475. }
  476. // StrconvType returns Go type name used for OS string for f.
  477. func (f *Fn) StrconvType() string {
  478. if f.IsUTF16() {
  479. return "*uint16"
  480. }
  481. return "*byte"
  482. }
// Source files and functions.
type Source struct {
	// Funcs holds every function parsed from //sys lines, in the order
	// encountered.
	Funcs []*Fn
	// Files holds the paths of the files scanned so far; the "main"
	// template lists them in the generated header comment.
	Files []string
}
  488. // ParseFiles parses files listed in fs and extracts all syscall
  489. // functions listed in sys comments. It returns source files
  490. // and functions collection *Source if successful.
  491. func ParseFiles(fs []string) (*Source, error) {
  492. src := &Source{
  493. Funcs: make([]*Fn, 0),
  494. Files: make([]string, 0),
  495. }
  496. for _, file := range fs {
  497. if err := src.ParseFile(file); err != nil {
  498. return nil, err
  499. }
  500. }
  501. return src, nil
  502. }
  503. // DLLs return dll names for a source set src.
  504. func (src *Source) DLLs() []string {
  505. uniq := make(map[string]bool)
  506. r := make([]string, 0)
  507. for _, f := range src.Funcs {
  508. name := f.DLLName()
  509. if _, found := uniq[name]; !found {
  510. uniq[name] = true
  511. r = append(r, name)
  512. }
  513. }
  514. return r
  515. }
  516. // ParseFile adds adition file path to a source set src.
  517. func (src *Source) ParseFile(path string) error {
  518. file, err := os.Open(path)
  519. if err != nil {
  520. return err
  521. }
  522. defer file.Close()
  523. s := bufio.NewScanner(file)
  524. for s.Scan() {
  525. t := trim(s.Text())
  526. if len(t) < 7 {
  527. continue
  528. }
  529. if !strings.HasPrefix(t, "//sys") {
  530. continue
  531. }
  532. t = t[5:]
  533. if !(t[0] == ' ' || t[0] == '\t') {
  534. continue
  535. }
  536. f, err := newFn(t[1:])
  537. if err != nil {
  538. return err
  539. }
  540. src.Funcs = append(src.Funcs, f)
  541. }
  542. if err := s.Err(); err != nil {
  543. return err
  544. }
  545. src.Files = append(src.Files, path)
  546. // get package name
  547. fset := token.NewFileSet()
  548. _, err = file.Seek(0, 0)
  549. if err != nil {
  550. return err
  551. }
  552. pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly)
  553. if err != nil {
  554. return err
  555. }
  556. packageName = pkg.Name.Name
  557. return nil
  558. }
  559. // Generate output source file from a source set src.
  560. func (src *Source) Generate(w io.Writer) error {
  561. funcMap := template.FuncMap{
  562. "windowsdot": windowsdot,
  563. "packagename": packagename,
  564. }
  565. t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate))
  566. err := t.Execute(w, src)
  567. if err != nil {
  568. return errors.New("Failed to execute template: " + err.Error())
  569. }
  570. return nil
  571. }
  572. func usage() {
  573. fmt.Fprintf(os.Stderr, "usage: mksyscall_windows [flags] [path ...]\n")
  574. flag.PrintDefaults()
  575. os.Exit(1)
  576. }
  577. func main() {
  578. flag.Usage = usage
  579. flag.Parse()
  580. if len(os.Args) <= 1 {
  581. fmt.Fprintf(os.Stderr, "no files to parse provided\n")
  582. usage()
  583. }
  584. src, err := ParseFiles(os.Args[1:])
  585. if err != nil {
  586. log.Fatal(err)
  587. }
  588. if err := src.Generate(os.Stdout); err != nil {
  589. log.Fatal(err)
  590. }
  591. }
// srcTemplate is the text/template source that Generate parses and executes
// with a *Source as data. It defines the "main" template, which emits the
// file header, the imports, the DLL and proc variables and one function
// body per entry in Funcs, plus the helper templates "dlls", "funcnames",
// "funcbody", "tmpvars", "syscall", "seterror" and "printtrace".
//
// TODO: use println instead to print in the following template
const srcTemplate = `
{{define "main"}}// go build mksyscall_windows.go && ./mksyscall_windows{{range .Files}} {{.}}{{end}}
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT
package {{packagename}}
import "syscall"
import "unsafe"{{if windowsdot}}
import "code.google.com/p/go.sys/windows"{{end}}
var (
{{template "dlls" .}}
{{template "funcnames" .}})
{{range .Funcs}}{{template "funcbody" .}}{{end}}
{{end}}
{{/* help functions */}}
{{define "dlls"}}{{range .DLLs}} mod{{.}} = {{windowsdot}}NewLazyDLL("{{.}}.dll")
{{end}}{{end}}
{{define "funcnames"}}{{range .Funcs}} proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}")
{{end}}{{end}}
{{define "funcbody"}}
func {{.Name}}({{.ParamList}}) {{if .Rets.List}}{{.Rets.List}} {{end}}{
{{template "tmpvars" .}} {{template "syscall" .}}
{{template "seterror" .}}{{template "printtrace" .}} return
}
{{end}}
{{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}} {{.TmpVarCode}}
{{end}}{{end}}{{end}}
{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}}
{{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}}
{{end}}{{end}}
{{define "printtrace"}}{{if .PrintTrace}} print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n")
{{end}}{{end}}
`