Browse Source

Fix race in command example

Close files in handler goroutine to avoid race.
Remove use of exec.Cmd. It's not adding anything.
Gary Burd 10 years ago
parent
commit
0e33ab35f9
2 changed files with 66 additions and 48 deletions
  1. 5 1
      examples/command/home.html
  2. 61 47
      examples/command/main.go

+ 5 - 1
examples/command/home.html

@@ -37,7 +37,7 @@
             appendLog($("<div><b>Connection closed.</b></div>"))
             appendLog($("<div><b>Connection closed.</b></div>"))
         }
         }
         conn.onmessage = function(evt) {
         conn.onmessage = function(evt) {
-            appendLog($("<div/>").text(evt.data))
+            appendLog($("<pre/>").text(evt.data))
         }
         }
     } else {
     } else {
         appendLog($("<div><b>Your browser does not support WebSockets.</b></div>"))
         appendLog($("<div><b>Your browser does not support WebSockets.</b></div>"))
@@ -70,6 +70,10 @@ body {
     overflow: auto;
     overflow: auto;
 }
 }
 
 
+#log pre {
+  margin: 0;
+}
+
 #form {
 #form {
     padding: 0 0.5em 0 0.5em;
     padding: 0 0.5em 0 0.5em;
     margin: 0;
     margin: 0;

+ 61 - 47
examples/command/main.go

@@ -20,6 +20,7 @@ import (
 
 
 var (
 var (
 	addr      = flag.String("addr", "127.0.0.1:8080", "http service address")
 	addr      = flag.String("addr", "127.0.0.1:8080", "http service address")
+	cmdPath   string
 	homeTempl = template.Must(template.ParseFiles("home.html"))
 	homeTempl = template.Must(template.ParseFiles("home.html"))
 )
 )
 
 
@@ -31,49 +32,40 @@ const (
 	maxMessageSize = 8192
 	maxMessageSize = 8192
 )
 )
 
 
-// connection is an middleman between the websocket connection and the command.
-type connection struct {
-	ws     *websocket.Conn
-	stdout io.ReadCloser
-	stdin  io.WriteCloser
-	cmd    *exec.Cmd
-}
-
-func (c *connection) pumpStdin() {
-	defer c.ws.Close()
-	c.ws.SetReadLimit(maxMessageSize)
+func pumpStdin(ws *websocket.Conn, w io.Writer) {
+	defer ws.Close()
+	ws.SetReadLimit(maxMessageSize)
 	for {
 	for {
-		_, message, err := c.ws.ReadMessage()
+		_, message, err := ws.ReadMessage()
 		if err != nil {
 		if err != nil {
 			break
 			break
 		}
 		}
 		message = append(message, '\n')
 		message = append(message, '\n')
-		if _, err := c.stdin.Write(message); err != nil {
+		if _, err := w.Write(message); err != nil {
 			break
 			break
 		}
 		}
 	}
 	}
-	c.stdin.Close()
 	log.Println("exit stdin pump")
 	log.Println("exit stdin pump")
 }
 }
 
 
-func (c *connection) pumpStdout() {
-	defer c.ws.Close()
-	s := bufio.NewScanner(c.stdout)
+func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) {
+	defer ws.Close()
+	s := bufio.NewScanner(r)
 	for s.Scan() {
 	for s.Scan() {
-		c.ws.SetWriteDeadline(time.Now().Add(writeWait))
-		if err := c.ws.WriteMessage(websocket.TextMessage, s.Bytes()); err != nil {
+		ws.SetWriteDeadline(time.Now().Add(writeWait))
+		if err := ws.WriteMessage(websocket.TextMessage, s.Bytes()); err != nil {
 			break
 			break
 		}
 		}
 	}
 	}
 	if s.Err() != nil {
 	if s.Err() != nil {
 		log.Println("scan:", s.Err())
 		log.Println("scan:", s.Err())
 	}
 	}
-	c.stdout.Close()
+	close(done)
 	log.Println("exit stdout pump")
 	log.Println("exit stdout pump")
 }
 }
 
 
-func internalError(ws *websocket.Conn, fmt string, err error) {
-	log.Println(fmt, err)
+func internalError(ws *websocket.Conn, msg string, err error) {
+	log.Println(msg, err)
 	ws.WriteMessage(websocket.TextMessage, []byte("Internal server error."))
 	ws.WriteMessage(websocket.TextMessage, []byte("Internal server error."))
 }
 }
 
 
@@ -87,49 +79,66 @@ func serveWs(w http.ResponseWriter, r *http.Request) {
 
 
 	ws, err := upgrader.Upgrade(w, r, nil)
 	ws, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 	if err != nil {
-		log.Println(err)
+		log.Println("upgrade:", err)
 		return
 		return
 	}
 	}
 
 
-	c := &connection{
-		cmd: exec.Command(flag.Args()[0], flag.Args()[1:]...),
-		ws:  ws,
-	}
+	defer ws.Close()
 
 
-	c.stdout, err = c.cmd.StdoutPipe()
+	outr, outw, err := os.Pipe()
 	if err != nil {
 	if err != nil {
-		internalError(ws, "stdout: %v", err)
-		ws.Close()
+		internalError(ws, "stdout:", err)
 		return
 		return
 	}
 	}
+	defer outr.Close()
+	defer outw.Close()
 
 
-	c.stdin, err = c.cmd.StdinPipe()
+	inr, inw, err := os.Pipe()
 	if err != nil {
 	if err != nil {
-		internalError(ws, "stdin: %v", err)
-		c.stdout.Close()
-		if closer, ok := c.cmd.Stdout.(io.Closer); ok {
-			closer.Close()
-		}
-		ws.Close()
+		internalError(ws, "stdin:", err)
 		return
 		return
 	}
 	}
+	defer inr.Close()
+	defer inw.Close()
 
 
-	if err := c.cmd.Start(); err != nil {
-		internalError(ws, "start: %v", err)
-		c.stdout.Close()
-		c.stdin.Close()
-		ws.Close()
+	proc, err := os.StartProcess(cmdPath, flag.Args(), &os.ProcAttr{
+		Files: []*os.File{inr, outw, outw},
+	})
+	if err != nil {
+		internalError(ws, "start:", err)
 		return
 		return
 	}
 	}
 
 
-	go c.pumpStdout()
-	c.pumpStdin()
+	inr.Close()
+	outw.Close()
+
+	done := make(chan struct{})
+	go pumpStdout(ws, outr, done)
+
+	pumpStdin(ws, inw)
+
+	// Some commands will exit when stdin is closed.
+	inw.Close()
 
 
-	c.cmd.Process.Signal(os.Interrupt)
-	if err := c.cmd.Wait(); err != nil {
+	// Other comamnds need a bonk on the head.
+	if err := proc.Signal(os.Interrupt); err != nil {
+		log.Println("inter:", err)
+	}
+
+	select {
+	case <-done:
+	case <-time.After(time.Second):
+		// A bigger bonk on the head.
+		if err := proc.Signal(os.Kill); err != nil {
+			log.Println("term:", err)
+		}
+		<-done
+	}
+
+	if _, err := proc.Wait(); err != nil {
 		log.Println("wait:", err)
 		log.Println("wait:", err)
 	}
 	}
-	log.Println("exit serveWs")
+	log.Println("exiting handler")
 }
 }
 
 
 func serveHome(w http.ResponseWriter, r *http.Request) {
 func serveHome(w http.ResponseWriter, r *http.Request) {
@@ -150,6 +159,11 @@ func main() {
 	if len(flag.Args()) < 1 {
 	if len(flag.Args()) < 1 {
 		log.Fatal("must specify at least one argument")
 		log.Fatal("must specify at least one argument")
 	}
 	}
+	var err error
+	cmdPath, err = exec.LookPath(flag.Args()[0])
+	if err != nil {
+		log.Fatal(err)
+	}
 	http.HandleFunc("/", serveHome)
 	http.HandleFunc("/", serveHome)
 	http.HandleFunc("/ws", serveWs)
 	http.HandleFunc("/ws", serveWs)
 	log.Fatal(http.ListenAndServe(*addr, nil))
 	log.Fatal(http.ListenAndServe(*addr, nil))