|
@@ -192,7 +192,7 @@ func (c *ClientConn) mainLoop() {
|
|
|
break
|
|
break
|
|
|
}
|
|
}
|
|
|
// TODO(dfc) A note on blocking channel use.
|
|
// TODO(dfc) A note on blocking channel use.
|
|
|
- // The msg, win, data and dataExt channels of a clientChan can
|
|
|
|
|
|
|
+ // The msg, data and dataExt channels of a clientChan can
|
|
|
// cause this loop to block indefinately if the consumer does
|
|
// cause this loop to block indefinately if the consumer does
|
|
|
// not service them.
|
|
// not service them.
|
|
|
switch packet[0] {
|
|
switch packet[0] {
|
|
@@ -233,7 +233,6 @@ func (c *ClientConn) mainLoop() {
|
|
|
case *channelCloseMsg:
|
|
case *channelCloseMsg:
|
|
|
ch := c.getChan(msg.PeersId)
|
|
ch := c.getChan(msg.PeersId)
|
|
|
ch.theyClosed = true
|
|
ch.theyClosed = true
|
|
|
- close(ch.stdin.win)
|
|
|
|
|
ch.stdout.eof()
|
|
ch.stdout.eof()
|
|
|
ch.stderr.eof()
|
|
ch.stderr.eof()
|
|
|
close(ch.msg)
|
|
close(ch.msg)
|
|
@@ -255,7 +254,10 @@ func (c *ClientConn) mainLoop() {
|
|
|
case *channelRequestMsg:
|
|
case *channelRequestMsg:
|
|
|
c.getChan(msg.PeersId).msg <- msg
|
|
c.getChan(msg.PeersId).msg <- msg
|
|
|
case *windowAdjustMsg:
|
|
case *windowAdjustMsg:
|
|
|
- c.getChan(msg.PeersId).stdin.win <- int(msg.AdditionalBytes)
|
|
|
|
|
|
|
+ if !c.getChan(msg.PeersId).stdin.win.add(msg.AdditionalBytes) {
|
|
|
|
|
+ // invalid window update
|
|
|
|
|
+ break
|
|
|
|
|
+ }
|
|
|
case *disconnectMsg:
|
|
case *disconnectMsg:
|
|
|
break
|
|
break
|
|
|
default:
|
|
default:
|
|
@@ -324,7 +326,7 @@ func newClientChan(t *transport, id uint32) *clientChan {
|
|
|
msg: make(chan interface{}, 16),
|
|
msg: make(chan interface{}, 16),
|
|
|
}
|
|
}
|
|
|
c.stdin = &chanWriter{
|
|
c.stdin = &chanWriter{
|
|
|
- win: make(chan int, 16),
|
|
|
|
|
|
|
+ win: &window{Cond: sync.NewCond(new(sync.Mutex))},
|
|
|
clientChan: c,
|
|
clientChan: c,
|
|
|
}
|
|
}
|
|
|
c.stdout = &chanReader{
|
|
c.stdout = &chanReader{
|
|
@@ -345,7 +347,7 @@ func (c *clientChan) waitForChannelOpenResponse() error {
|
|
|
case *channelOpenConfirmMsg:
|
|
case *channelOpenConfirmMsg:
|
|
|
// fixup peersId field
|
|
// fixup peersId field
|
|
|
c.peersId = msg.MyId
|
|
c.peersId = msg.MyId
|
|
|
- c.stdin.win <- int(msg.MyWindow)
|
|
|
|
|
|
|
+ c.stdin.win.add(msg.MyWindow)
|
|
|
return nil
|
|
return nil
|
|
|
case *channelOpenFailureMsg:
|
|
case *channelOpenFailureMsg:
|
|
|
return errors.New(safeString(msg.Message))
|
|
return errors.New(safeString(msg.Message))
|
|
@@ -417,22 +419,16 @@ func (c *chanlist) remove(id uint32) {
|
|
|
|
|
|
|
|
// A chanWriter represents the stdin of a remote process.
|
|
// A chanWriter represents the stdin of a remote process.
|
|
|
type chanWriter struct {
|
|
type chanWriter struct {
|
|
|
- win chan int // receives window adjustments
|
|
|
|
|
- rwin int // current rwin size
|
|
|
|
|
|
|
+ win *window
|
|
|
clientChan *clientChan // the channel backing this writer
|
|
clientChan *clientChan // the channel backing this writer
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Write writes data to the remote process's standard input.
|
|
// Write writes data to the remote process's standard input.
|
|
|
func (w *chanWriter) Write(data []byte) (written int, err error) {
|
|
func (w *chanWriter) Write(data []byte) (written int, err error) {
|
|
|
for len(data) > 0 {
|
|
for len(data) > 0 {
|
|
|
- for w.rwin < 1 {
|
|
|
|
|
- win, ok := <-w.win
|
|
|
|
|
- if !ok {
|
|
|
|
|
- return 0, io.EOF
|
|
|
|
|
- }
|
|
|
|
|
- w.rwin += win
|
|
|
|
|
- }
|
|
|
|
|
- n := min(len(data), w.rwin)
|
|
|
|
|
|
|
+ // n cannot be larger than 2^31 as len(data) cannot
|
|
|
|
|
+ // be larger than 2^31
|
|
|
|
|
+ n := int(w.win.reserve(uint32(len(data))))
|
|
|
peersId := w.clientChan.peersId
|
|
peersId := w.clientChan.peersId
|
|
|
packet := []byte{
|
|
packet := []byte{
|
|
|
msgChannelData,
|
|
msgChannelData,
|
|
@@ -443,7 +439,6 @@ func (w *chanWriter) Write(data []byte) (written int, err error) {
|
|
|
break
|
|
break
|
|
|
}
|
|
}
|
|
|
data = data[n:]
|
|
data = data[n:]
|
|
|
- w.rwin -= n
|
|
|
|
|
written += n
|
|
written += n
|
|
|
}
|
|
}
|
|
|
return
|
|
return
|
|
@@ -507,3 +502,46 @@ func (r *chanReader) Read(data []byte) (int, error) {
|
|
|
}
|
|
}
|
|
|
panic("unreachable")
|
|
panic("unreachable")
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+// window represents the buffer available to clients
|
|
|
|
|
+// wishing to write to a channel.
|
|
|
|
|
+type window struct {
|
|
|
|
|
+ *sync.Cond
|
|
|
|
|
+ win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// add adds win to the amount of window available
|
|
|
|
|
+// for consumers.
|
|
|
|
|
+func (w *window) add(win uint32) bool {
|
|
|
|
|
+ if win == 0 {
|
|
|
|
|
+ return false
|
|
|
|
|
+ }
|
|
|
|
|
+ w.L.Lock()
|
|
|
|
|
+ if w.win+win < win {
|
|
|
|
|
+ w.L.Unlock()
|
|
|
|
|
+ return false
|
|
|
|
|
+ }
|
|
|
|
|
+ w.win += win
|
|
|
|
|
+ // It is unusual that multiple goroutines would be attempting to reserve
|
|
|
|
|
+ // window space, but not guaranteed. Use broadcast to notify all waiters
|
|
|
|
|
+ // that additional window is available.
|
|
|
|
|
+ w.Broadcast()
|
|
|
|
|
+ w.L.Unlock()
|
|
|
|
|
+ return true
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// reserve reserves win from the available window capacity.
|
|
|
|
|
+// If no capacity remains, reserve will block. reserve may
|
|
|
|
|
+// return less than requested.
|
|
|
|
|
+func (w *window) reserve(win uint32) uint32 {
|
|
|
|
|
+ w.L.Lock()
|
|
|
|
|
+ for w.win == 0 {
|
|
|
|
|
+ w.Wait()
|
|
|
|
|
+ }
|
|
|
|
|
+ if w.win < win {
|
|
|
|
|
+ win = w.win
|
|
|
|
|
+ }
|
|
|
|
|
+ w.win -= win
|
|
|
|
|
+ w.L.Unlock()
|
|
|
|
|
+ return win
|
|
|
|
|
+}
|