浏览代码

Fix race in command example

Close files in handler goroutine to avoid race.
Remove use of exec.Cmd. It's not adding anything.
Gary Burd 10 年之前
父节点
当前提交
0e33ab35f9
共有 2 个文件被更改,包括 66 次插入48 次删除
  1. 5 1
      examples/command/home.html
  2. 61 47
      examples/command/main.go

+ 5 - 1
examples/command/home.html

@@ -37,7 +37,7 @@
             appendLog($("<div><b>Connection closed.</b></div>"))
         }
         conn.onmessage = function(evt) {
-            appendLog($("<div/>").text(evt.data))
+            appendLog($("<pre/>").text(evt.data))
         }
     } else {
         appendLog($("<div><b>Your browser does not support WebSockets.</b></div>"))
@@ -70,6 +70,10 @@ body {
     overflow: auto;
 }
 
+#log pre {
+  margin: 0;
+}
+
 #form {
     padding: 0 0.5em 0 0.5em;
     margin: 0;

+ 61 - 47
examples/command/main.go

@@ -20,6 +20,7 @@ import (
 
 var (
 	addr      = flag.String("addr", "127.0.0.1:8080", "http service address")
+	cmdPath   string
 	homeTempl = template.Must(template.ParseFiles("home.html"))
 )
 
@@ -31,49 +32,40 @@ const (
 	maxMessageSize = 8192
 )
 
-// connection is an middleman between the websocket connection and the command.
-type connection struct {
-	ws     *websocket.Conn
-	stdout io.ReadCloser
-	stdin  io.WriteCloser
-	cmd    *exec.Cmd
-}
-
-func (c *connection) pumpStdin() {
-	defer c.ws.Close()
-	c.ws.SetReadLimit(maxMessageSize)
+func pumpStdin(ws *websocket.Conn, w io.Writer) {
+	defer ws.Close()
+	ws.SetReadLimit(maxMessageSize)
 	for {
-		_, message, err := c.ws.ReadMessage()
+		_, message, err := ws.ReadMessage()
 		if err != nil {
 			break
 		}
 		message = append(message, '\n')
-		if _, err := c.stdin.Write(message); err != nil {
+		if _, err := w.Write(message); err != nil {
 			break
 		}
 	}
-	c.stdin.Close()
 	log.Println("exit stdin pump")
 }
 
-func (c *connection) pumpStdout() {
-	defer c.ws.Close()
-	s := bufio.NewScanner(c.stdout)
+func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) {
+	defer ws.Close()
+	s := bufio.NewScanner(r)
 	for s.Scan() {
-		c.ws.SetWriteDeadline(time.Now().Add(writeWait))
-		if err := c.ws.WriteMessage(websocket.TextMessage, s.Bytes()); err != nil {
+		ws.SetWriteDeadline(time.Now().Add(writeWait))
+		if err := ws.WriteMessage(websocket.TextMessage, s.Bytes()); err != nil {
 			break
 		}
 	}
 	if s.Err() != nil {
 		log.Println("scan:", s.Err())
 	}
-	c.stdout.Close()
+	close(done)
 	log.Println("exit stdout pump")
 }
 
-func internalError(ws *websocket.Conn, fmt string, err error) {
-	log.Println(fmt, err)
+func internalError(ws *websocket.Conn, msg string, err error) {
+	log.Println(msg, err)
 	ws.WriteMessage(websocket.TextMessage, []byte("Internal server error."))
 }
 
@@ -87,49 +79,66 @@ func serveWs(w http.ResponseWriter, r *http.Request) {
 
 	ws, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
-		log.Println(err)
+		log.Println("upgrade:", err)
 		return
 	}
 
-	c := &connection{
-		cmd: exec.Command(flag.Args()[0], flag.Args()[1:]...),
-		ws:  ws,
-	}
+	defer ws.Close()
 
-	c.stdout, err = c.cmd.StdoutPipe()
+	outr, outw, err := os.Pipe()
 	if err != nil {
-		internalError(ws, "stdout: %v", err)
-		ws.Close()
+		internalError(ws, "stdout:", err)
 		return
 	}
+	defer outr.Close()
+	defer outw.Close()
 
-	c.stdin, err = c.cmd.StdinPipe()
+	inr, inw, err := os.Pipe()
 	if err != nil {
-		internalError(ws, "stdin: %v", err)
-		c.stdout.Close()
-		if closer, ok := c.cmd.Stdout.(io.Closer); ok {
-			closer.Close()
-		}
-		ws.Close()
+		internalError(ws, "stdin:", err)
 		return
 	}
+	defer inr.Close()
+	defer inw.Close()
 
-	if err := c.cmd.Start(); err != nil {
-		internalError(ws, "start: %v", err)
-		c.stdout.Close()
-		c.stdin.Close()
-		ws.Close()
+	proc, err := os.StartProcess(cmdPath, flag.Args(), &os.ProcAttr{
+		Files: []*os.File{inr, outw, outw},
+	})
+	if err != nil {
+		internalError(ws, "start:", err)
 		return
 	}
 
-	go c.pumpStdout()
-	c.pumpStdin()
+	inr.Close()
+	outw.Close()
+
+	done := make(chan struct{})
+	go pumpStdout(ws, outr, done)
+
+	pumpStdin(ws, inw)
+
+	// Some commands will exit when stdin is closed.
+	inw.Close()
 
-	c.cmd.Process.Signal(os.Interrupt)
-	if err := c.cmd.Wait(); err != nil {
+	// Other comamnds need a bonk on the head.
+	if err := proc.Signal(os.Interrupt); err != nil {
+		log.Println("inter:", err)
+	}
+
+	select {
+	case <-done:
+	case <-time.After(time.Second):
+		// A bigger bonk on the head.
+		if err := proc.Signal(os.Kill); err != nil {
+			log.Println("term:", err)
+		}
+		<-done
+	}
+
+	if _, err := proc.Wait(); err != nil {
 		log.Println("wait:", err)
 	}
-	log.Println("exit serveWs")
+	log.Println("exiting handler")
 }
 
 func serveHome(w http.ResponseWriter, r *http.Request) {
@@ -150,6 +159,11 @@ func main() {
 	if len(flag.Args()) < 1 {
 		log.Fatal("must specify at least one argument")
 	}
+	var err error
+	cmdPath, err = exec.LookPath(flag.Args()[0])
+	if err != nil {
+		log.Fatal(err)
+	}
 	http.HandleFunc("/", serveHome)
 	http.HandleFunc("/ws", serveWs)
 	log.Fatal(http.ListenAndServe(*addr, nil))