浏览代码

go.crypto/ssh: make {client,server}Chan use common window management

R=agl, gustav.paul, kardianos
CC=golang-dev
https://golang.org/cl/6208043
Dave Cheney 13 年之前
父节点
当前提交
8a2e7c966a
共有 4 个文件被更改,包括 76 次插入82 次删除
  1. 17 27
      ssh/channel.go
  2. 7 51
      ssh/client.go
  3. 47 0
      ssh/common.go
  4. 5 4
      ssh/server.go

+ 17 - 27
ssh/channel.go

@@ -73,6 +73,7 @@ const (
 type channel struct {
 	conn              // the underlying transport
 	localId, remoteId uint32
+	remoteWin         window
 
 	theyClosed  bool // indicates the close msg has been received from the remote side
 	weClosed    bool // incidates the close msg has been sent from our side
@@ -118,10 +119,10 @@ type serverChan struct {
 	chanType  string
 	extraData []byte
 
-	serverConn            *ServerConn
-	myWindow, theirWindow uint32
-	maxPacketSize         uint32
-	err                   error
+	serverConn    *ServerConn
+	myWindow      uint32
+	maxPacketSize uint32
+	err           error
 
 	pendingRequests []ChannelRequest
 	pendingData     []byte
@@ -180,8 +181,9 @@ func (c *serverChan) handlePacket(packet interface{}) {
 		c.theySentEOF = true
 		c.cond.Signal()
 	case *windowAdjustMsg:
-		c.theirWindow += packet.AdditionalBytes
-		c.cond.Signal()
+		if !c.remoteWin.add(packet.AdditionalBytes) {
+			panic("illegal window update")
+		}
 	default:
 		panic("unknown packet type")
 	}
@@ -230,7 +232,6 @@ func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
 		if space, err = c.getWindowSpace(uint32(len(data))); err != nil {
 			return 0, err
 		}
-
 		todo := data
 		if uint32(len(todo)) > space {
 			todo = todo[:space]
@@ -321,28 +322,18 @@ func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint3
 // getWindowSpace takes, at most, max bytes of space from the peer's window. It
 // returns the number of bytes actually reserved.
 func (c *serverChan) getWindowSpace(max uint32) (uint32, error) {
+	var err error
+	// TODO(dfc) This lock and check of c.weClosed is necessary because unlike
+	// clientChan, c.weClosed is observed by more than one goroutine.
 	c.cond.L.Lock()
-	defer c.cond.L.Unlock()
-
-	for {
-		if c.dead || c.weClosed {
-			return 0, io.EOF
-		}
-
-		if c.theirWindow > 0 {
-			break
-		}
-
-		c.cond.Wait()
+	if c.dead || c.weClosed {
+		err = io.EOF
 	}
-
-	taken := c.theirWindow
-	if taken > max {
-		taken = max
+	c.cond.L.Unlock()
+	if err != nil {
+		return 0, err
 	}
-
-	c.theirWindow -= taken
-	return taken, nil
+	return c.remoteWin.reserve(max), nil
 }
 
 func (c *serverChan) Write(data []byte) (n int, err error) {
@@ -351,7 +342,6 @@ func (c *serverChan) Write(data []byte) (n int, err error) {
 		if space, err = c.getWindowSpace(uint32(len(data))); err != nil {
 			return 0, err
 		}
-
 		todo := data
 		if uint32(len(todo)) > space {
 			todo = todo[:space]

+ 7 - 51
ssh/client.go

@@ -276,7 +276,7 @@ func (c *ClientConn) mainLoop() {
 			case *channelRequestMsg:
 				c.getChan(msg.PeersId).msg <- msg
 			case *windowAdjustMsg:
-				if !c.getChan(msg.PeersId).stdin.win.add(msg.AdditionalBytes) {
+				if !c.getChan(msg.PeersId).remoteWin.add(msg.AdditionalBytes) {
 					// invalid window update
 					return
 				}
@@ -317,7 +317,7 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
 		}
 		ch := c.newChan(c.transport)
 		ch.remoteId = msg.PeersId
-		ch.stdin.win.add(msg.PeersWindow)
+		ch.remoteWin.add(msg.PeersWindow)
 
 		m := channelOpenConfirmMsg{
 			PeersId:       ch.remoteId,
@@ -438,13 +438,13 @@ type clientChan struct {
 func newClientChan(cc conn, id uint32) *clientChan {
 	c := &clientChan{
 		channel: channel{
-			conn:    cc,
-			localId: id,
+			conn:      cc,
+			localId:   id,
+			remoteWin: window{Cond: newCond()},
 		},
 		msg: make(chan interface{}, 16),
 	}
 	c.stdin = &chanWriter{
-		win:     &window{Cond: sync.NewCond(new(sync.Mutex))},
 		channel: &c.channel,
 	}
 	c.stdout = &chanReader{
@@ -465,7 +465,7 @@ func (c *clientChan) waitForChannelOpenResponse() error {
 	case *channelOpenConfirmMsg:
 		// fixup remoteId field
 		c.remoteId = msg.MyId
-		c.stdin.win.add(msg.MyWindow)
+		c.remoteWin.add(msg.MyWindow)
 		return nil
 	case *channelOpenFailureMsg:
 		return errors.New(safeString(msg.Message))
@@ -542,7 +542,6 @@ func (c *chanList) closeAll() {
 
 // A chanWriter represents the stdin of a remote process.
 type chanWriter struct {
-	win *window
 	*channel
 }
 
@@ -551,7 +550,7 @@ func (w *chanWriter) Write(data []byte) (written int, err error) {
 	for len(data) > 0 {
 		// n cannot be larger than 2^31 as len(data) cannot
 		// be larger than 2^31
-		n := int(w.win.reserve(uint32(len(data))))
+		n := int(w.remoteWin.reserve(uint32(len(data))))
 		remoteId := w.remoteId
 		packet := []byte{
 			msgChannelData,
@@ -621,46 +620,3 @@ func (r *chanReader) Read(data []byte) (int, error) {
 	}
 	panic("unreachable")
 }
-
-// window represents the buffer available to clients 
-// wishing to write to a channel.
-type window struct {
-	*sync.Cond
-	win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
-}
-
-// add adds win to the amount of window available
-// for consumers.
-func (w *window) add(win uint32) bool {
-	if win == 0 {
-		return false
-	}
-	w.L.Lock()
-	if w.win+win < win {
-		w.L.Unlock()
-		return false
-	}
-	w.win += win
-	// It is unusual that multiple goroutines would be attempting to reserve
-	// window space, but not guaranteed. Use broadcast to notify all waiters 
-	// that additional window is available.
-	w.Broadcast()
-	w.L.Unlock()
-	return true
-}
-
-// reserve reserves win from the available window capacity.
-// If no capacity remains, reserve will block. reserve may
-// return less than requested.
-func (w *window) reserve(win uint32) uint32 {
-	w.L.Lock()
-	for w.win == 0 {
-		w.Wait()
-	}
-	if w.win < win {
-		win = w.win
-	}
-	w.win -= win
-	w.L.Unlock()
-	return win
-}

+ 47 - 0
ssh/common.go

@@ -285,3 +285,50 @@ func appendU32(buf []byte, n uint32) []byte {
 func appendInt(buf []byte, n int) []byte {
 	return appendU32(buf, uint32(n))
 }
+
+// newCond is a helper to hide the fact that there is no usable zero 
+// value for sync.Cond.
+func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) }
+
+// window represents the buffer available to clients 
+// wishing to write to a channel.
+type window struct {
+	*sync.Cond
+	win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
+}
+
+// add adds win to the amount of window available
+// for consumers.
+func (w *window) add(win uint32) bool {
+	if win == 0 {
+		return false
+	}
+	w.L.Lock()
+	if w.win+win < win {
+		w.L.Unlock()
+		return false
+	}
+	w.win += win
+	// It is unusual that multiple goroutines would be attempting to reserve
+	// window space, but not guaranteed. Use broadcast to notify all waiters 
+	// that additional window is available.
+	w.Broadcast()
+	w.L.Unlock()
+	return true
+}
+
+// reserve reserves win from the available window capacity.
+// If no capacity remains, reserve will block. reserve may
+// return less than requested.
+func (w *window) reserve(win uint32) uint32 {
+	w.L.Lock()
+	for w.win == 0 {
+		w.Wait()
+	}
+	if w.win < win {
+		win = w.win
+	}
+	w.win -= win
+	w.L.Unlock()
+	return win
+}

+ 5 - 4
ssh/server.go

@@ -543,18 +543,19 @@ func (s *ServerConn) Accept() (Channel, error) {
 			case *channelOpenMsg:
 				c := &serverChan{
 					channel: channel{
-						conn:     s,
-						remoteId: msg.PeersId,
+						conn:      s,
+						remoteId:  msg.PeersId,
+						remoteWin: window{Cond: newCond()},
 					},
-					theirWindow:   msg.PeersWindow,
 					chanType:      msg.ChanType,
 					maxPacketSize: msg.MaxPacketSize,
 					extraData:     msg.TypeSpecificData,
 					myWindow:      defaultWindowSize,
 					serverConn:    s,
-					cond:          sync.NewCond(new(sync.Mutex)),
+					cond:          newCond(),
 					pendingData:   make([]byte, defaultWindowSize),
 				}
+				c.remoteWin.add(msg.PeersWindow)
 				s.lock.Lock()
 				c.localId = s.nextChanId
 				s.nextChanId++