浏览代码

go.crypto/ssh: replace window channel with an atomic variable and condition

Fixes golang/go#3479.

Using a channel to model window size was a mistake. Unlike stdin and
stdout, which are streams of data, window size is an variable and
should be modeled as such.

R=golang-dev, agl, gustav.paul, kardianos, dvyukov
CC=golang-dev
https://golang.org/cl/5986053
Dave Cheney 13 年之前
父节点
当前提交
bd686358fc
共有 1 个文件被更改,包括 54 次插入16 次删除
  1. 54 16
      ssh/client.go

+ 54 - 16
ssh/client.go

@@ -192,7 +192,7 @@ func (c *ClientConn) mainLoop() {
 			break
 		}
 		// TODO(dfc) A note on blocking channel use.
-		// The msg, win, data and dataExt channels of a clientChan can
+		// The msg, data and dataExt channels of a clientChan can
 		// cause this loop to block indefinately if the consumer does
 		// not service them.
 		switch packet[0] {
@@ -233,7 +233,6 @@ func (c *ClientConn) mainLoop() {
 			case *channelCloseMsg:
 				ch := c.getChan(msg.PeersId)
 				ch.theyClosed = true
-				close(ch.stdin.win)
 				ch.stdout.eof()
 				ch.stderr.eof()
 				close(ch.msg)
@@ -255,7 +254,10 @@ func (c *ClientConn) mainLoop() {
 			case *channelRequestMsg:
 				c.getChan(msg.PeersId).msg <- msg
 			case *windowAdjustMsg:
-				c.getChan(msg.PeersId).stdin.win <- int(msg.AdditionalBytes)
+				if !c.getChan(msg.PeersId).stdin.win.add(msg.AdditionalBytes) {
+					// invalid window update
+					break
+				}
 			case *disconnectMsg:
 				break
 			default:
@@ -324,7 +326,7 @@ func newClientChan(t *transport, id uint32) *clientChan {
 		msg:          make(chan interface{}, 16),
 	}
 	c.stdin = &chanWriter{
-		win:        make(chan int, 16),
+		win:        &window{Cond: sync.NewCond(new(sync.Mutex))},
 		clientChan: c,
 	}
 	c.stdout = &chanReader{
@@ -345,7 +347,7 @@ func (c *clientChan) waitForChannelOpenResponse() error {
 	case *channelOpenConfirmMsg:
 		// fixup peersId field
 		c.peersId = msg.MyId
-		c.stdin.win <- int(msg.MyWindow)
+		c.stdin.win.add(msg.MyWindow)
 		return nil
 	case *channelOpenFailureMsg:
 		return errors.New(safeString(msg.Message))
@@ -417,22 +419,16 @@ func (c *chanlist) remove(id uint32) {
 
 // A chanWriter represents the stdin of a remote process.
 type chanWriter struct {
-	win        chan int    // receives window adjustments
-	rwin       int         // current rwin size
+	win        *window
 	clientChan *clientChan // the channel backing this writer
 }
 
 // Write writes data to the remote process's standard input.
 func (w *chanWriter) Write(data []byte) (written int, err error) {
 	for len(data) > 0 {
-		for w.rwin < 1 {
-			win, ok := <-w.win
-			if !ok {
-				return 0, io.EOF
-			}
-			w.rwin += win
-		}
-		n := min(len(data), w.rwin)
+		// n cannot be larger than 2^31 as len(data) cannot
+		// be larger than 2^31
+		n := int(w.win.reserve(uint32(len(data))))
 		peersId := w.clientChan.peersId
 		packet := []byte{
 			msgChannelData,
@@ -443,7 +439,6 @@ func (w *chanWriter) Write(data []byte) (written int, err error) {
 			break
 		}
 		data = data[n:]
-		w.rwin -= n
 		written += n
 	}
 	return
@@ -507,3 +502,46 @@ func (r *chanReader) Read(data []byte) (int, error) {
 	}
 	panic("unreachable")
 }
+
+// window represents the buffer available to clients 
+// wishing to write to a channel.
+type window struct {
+	*sync.Cond
+	win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
+}
+
+// add adds win to the amount of window available
+// for consumers.
+func (w *window) add(win uint32) bool {
+	if win == 0 {
+		return false
+	}
+	w.L.Lock()
+	if w.win+win < win {
+		w.L.Unlock()
+		return false
+	}
+	w.win += win
+	// It is unusual that multiple goroutines would be attempting to reserve
+	// window space, but not guaranteed. Use broadcast to notify all waiters 
+	// that additional window is available.
+	w.Broadcast()
+	w.L.Unlock()
+	return true
+}
+
+// reserve reserves win from the available window capacity.
+// If no capacity remains, reserve will block. reserve may
+// return less than requested.
+func (w *window) reserve(win uint32) uint32 {
+	w.L.Lock()
+	for w.win == 0 {
+		w.Wait()
+	}
+	if w.win < win {
+		win = w.win
+	}
+	w.win -= win
+	w.L.Unlock()
+	return win
+}