Quellcode durchsuchen

Fix race in command example

Close files in handler goroutine to avoid race.
Remove use of exec.Cmd. It's not adding anything.
Gary Burd vor 10 Jahren
Ursprung
Commit
0e33ab35f9
2 geänderte Dateien mit 66 neuen und 48 gelöschten Zeilen
  1. 5 1
      examples/command/home.html
  2. 61 47
      examples/command/main.go

+ 5 - 1
examples/command/home.html

@@ -37,7 +37,7 @@
             appendLog($("<div><b>Connection closed.</b></div>"))
         }
         conn.onmessage = function(evt) {
-            appendLog($("<div/>").text(evt.data))
+            appendLog($("<pre/>").text(evt.data))
         }
     } else {
         appendLog($("<div><b>Your browser does not support WebSockets.</b></div>"))
@@ -70,6 +70,10 @@ body {
     overflow: auto;
 }
 
+#log pre {
+  margin: 0;
+}
+
 #form {
     padding: 0 0.5em 0 0.5em;
     margin: 0;

+ 61 - 47
examples/command/main.go

@@ -20,6 +20,7 @@ import (
 
 var (
 	addr      = flag.String("addr", "127.0.0.1:8080", "http service address")
+	cmdPath   string
 	homeTempl = template.Must(template.ParseFiles("home.html"))
 )
 
@@ -31,49 +32,40 @@ const (
 	maxMessageSize = 8192
 )
 
-// connection is an middleman between the websocket connection and the command.
-type connection struct {
-	ws     *websocket.Conn
-	stdout io.ReadCloser
-	stdin  io.WriteCloser
-	cmd    *exec.Cmd
-}
-
-func (c *connection) pumpStdin() {
-	defer c.ws.Close()
-	c.ws.SetReadLimit(maxMessageSize)
+func pumpStdin(ws *websocket.Conn, w io.Writer) {
+	defer ws.Close()
+	ws.SetReadLimit(maxMessageSize)
 	for {
-		_, message, err := c.ws.ReadMessage()
+		_, message, err := ws.ReadMessage()
 		if err != nil {
 			break
 		}
 		message = append(message, '\n')
-		if _, err := c.stdin.Write(message); err != nil {
+		if _, err := w.Write(message); err != nil {
 			break
 		}
 	}
-	c.stdin.Close()
 	log.Println("exit stdin pump")
 }
 
-func (c *connection) pumpStdout() {
-	defer c.ws.Close()
-	s := bufio.NewScanner(c.stdout)
+func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) {
+	defer ws.Close()
+	s := bufio.NewScanner(r)
 	for s.Scan() {
-		c.ws.SetWriteDeadline(time.Now().Add(writeWait))
-		if err := c.ws.WriteMessage(websocket.TextMessage, s.Bytes()); err != nil {
+		ws.SetWriteDeadline(time.Now().Add(writeWait))
+		if err := ws.WriteMessage(websocket.TextMessage, s.Bytes()); err != nil {
 			break
 		}
 	}
 	if s.Err() != nil {
 		log.Println("scan:", s.Err())
 	}
-	c.stdout.Close()
+	close(done)
 	log.Println("exit stdout pump")
 }
 
-func internalError(ws *websocket.Conn, fmt string, err error) {
-	log.Println(fmt, err)
+func internalError(ws *websocket.Conn, msg string, err error) {
+	log.Println(msg, err)
 	ws.WriteMessage(websocket.TextMessage, []byte("Internal server error."))
 }
 
@@ -87,49 +79,66 @@ func serveWs(w http.ResponseWriter, r *http.Request) {
 
 	ws, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
-		log.Println(err)
+		log.Println("upgrade:", err)
 		return
 	}
 
-	c := &connection{
-		cmd: exec.Command(flag.Args()[0], flag.Args()[1:]...),
-		ws:  ws,
-	}
+	defer ws.Close()
 
-	c.stdout, err = c.cmd.StdoutPipe()
+	outr, outw, err := os.Pipe()
 	if err != nil {
-		internalError(ws, "stdout: %v", err)
-		ws.Close()
+		internalError(ws, "stdout:", err)
 		return
 	}
+	defer outr.Close()
+	defer outw.Close()
 
-	c.stdin, err = c.cmd.StdinPipe()
+	inr, inw, err := os.Pipe()
 	if err != nil {
-		internalError(ws, "stdin: %v", err)
-		c.stdout.Close()
-		if closer, ok := c.cmd.Stdout.(io.Closer); ok {
-			closer.Close()
-		}
-		ws.Close()
+		internalError(ws, "stdin:", err)
 		return
 	}
+	defer inr.Close()
+	defer inw.Close()
 
-	if err := c.cmd.Start(); err != nil {
-		internalError(ws, "start: %v", err)
-		c.stdout.Close()
-		c.stdin.Close()
-		ws.Close()
+	proc, err := os.StartProcess(cmdPath, flag.Args(), &os.ProcAttr{
+		Files: []*os.File{inr, outw, outw},
+	})
+	if err != nil {
+		internalError(ws, "start:", err)
 		return
 	}
 
-	go c.pumpStdout()
-	c.pumpStdin()
+	inr.Close()
+	outw.Close()
+
+	done := make(chan struct{})
+	go pumpStdout(ws, outr, done)
+
+	pumpStdin(ws, inw)
+
+	// Some commands will exit when stdin is closed.
+	inw.Close()
 
-	c.cmd.Process.Signal(os.Interrupt)
-	if err := c.cmd.Wait(); err != nil {
+	// Other comamnds need a bonk on the head.
+	if err := proc.Signal(os.Interrupt); err != nil {
+		log.Println("inter:", err)
+	}
+
+	select {
+	case <-done:
+	case <-time.After(time.Second):
+		// A bigger bonk on the head.
+		if err := proc.Signal(os.Kill); err != nil {
+			log.Println("term:", err)
+		}
+		<-done
+	}
+
+	if _, err := proc.Wait(); err != nil {
 		log.Println("wait:", err)
 	}
-	log.Println("exit serveWs")
+	log.Println("exiting handler")
 }
 
 func serveHome(w http.ResponseWriter, r *http.Request) {
@@ -150,6 +159,11 @@ func main() {
 	if len(flag.Args()) < 1 {
 		log.Fatal("must specify at least one argument")
 	}
+	var err error
+	cmdPath, err = exec.LookPath(flag.Args()[0])
+	if err != nil {
+		log.Fatal(err)
+	}
 	http.HandleFunc("/", serveHome)
 	http.HandleFunc("/ws", serveWs)
 	log.Fatal(http.ListenAndServe(*addr, nil))