channel.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "errors"
  7. "io"
  8. "sync"
  9. )
  10. // extendedDataTypeCode identifies an OpenSSL extended data type. See RFC 4254,
  11. // section 5.2.
  12. type extendedDataTypeCode uint32
  13. // extendedDataStderr is the extended data type that is used for stderr.
  14. const extendedDataStderr extendedDataTypeCode = 1
  15. // A Channel is an ordered, reliable, duplex stream that is multiplexed over an
  16. // SSH connection. Channel.Read can return a ChannelRequest as an error.
  17. type Channel interface {
  18. // Accept accepts the channel creation request.
  19. Accept() error
  20. // Reject rejects the channel creation request. After calling this, no
  21. // other methods on the Channel may be called. If they are then the
  22. // peer is likely to signal a protocol error and drop the connection.
  23. Reject(reason RejectionReason, message string) error
  24. // Read may return a ChannelRequest as an error.
  25. Read(data []byte) (int, error)
  26. Write(data []byte) (int, error)
  27. Close() error
  28. // Stderr returns an io.Writer that writes to this channel with the
  29. // extended data type set to stderr.
  30. Stderr() io.Writer
  31. // AckRequest either sends an ack or nack to the channel request.
  32. AckRequest(ok bool) error
  33. // ChannelType returns the type of the channel, as supplied by the
  34. // client.
  35. ChannelType() string
  36. // ExtraData returns the arbitary payload for this channel, as supplied
  37. // by the client. This data is specific to the channel type.
  38. ExtraData() []byte
  39. }
  40. // ChannelRequest represents a request sent on a channel, outside of the normal
  41. // stream of bytes. It may result from calling Read on a Channel.
  42. type ChannelRequest struct {
  43. Request string
  44. WantReply bool
  45. Payload []byte
  46. }
  47. func (c ChannelRequest) Error() string {
  48. return "ssh: channel request received"
  49. }
  50. // RejectionReason is an enumeration used when rejecting channel creation
  51. // requests. See RFC 4254, section 5.1.
  52. type RejectionReason int
  53. const (
  54. Prohibited RejectionReason = iota + 1
  55. ConnectionFailed
  56. UnknownChannelType
  57. ResourceShortage
  58. )
  59. type serverChan struct {
  60. // immutable once created
  61. chanType string
  62. extraData []byte
  63. theyClosed bool
  64. theySentEOF bool
  65. weClosed bool
  66. dead bool
  67. serverConn *ServerConn
  68. localId, remoteId uint32
  69. myWindow, theirWindow uint32
  70. maxPacketSize uint32
  71. err error
  72. pendingRequests []ChannelRequest
  73. pendingData []byte
  74. head, length int
  75. // This lock is inferior to serverConn.lock
  76. lock sync.Mutex
  77. cond *sync.Cond
  78. }
  79. func (c *serverChan) Accept() error {
  80. c.serverConn.lock.Lock()
  81. defer c.serverConn.lock.Unlock()
  82. if c.serverConn.err != nil {
  83. return c.serverConn.err
  84. }
  85. confirm := channelOpenConfirmMsg{
  86. PeersId: c.remoteId,
  87. MyId: c.localId,
  88. MyWindow: c.myWindow,
  89. MaxPacketSize: c.maxPacketSize,
  90. }
  91. return c.serverConn.writePacket(marshal(msgChannelOpenConfirm, confirm))
  92. }
  93. func (c *serverChan) Reject(reason RejectionReason, message string) error {
  94. c.serverConn.lock.Lock()
  95. defer c.serverConn.lock.Unlock()
  96. if c.serverConn.err != nil {
  97. return c.serverConn.err
  98. }
  99. reject := channelOpenFailureMsg{
  100. PeersId: c.remoteId,
  101. Reason: reason,
  102. Message: message,
  103. Language: "en",
  104. }
  105. return c.serverConn.writePacket(marshal(msgChannelOpenFailure, reject))
  106. }
  107. func (c *serverChan) handlePacket(packet interface{}) {
  108. c.lock.Lock()
  109. defer c.lock.Unlock()
  110. switch packet := packet.(type) {
  111. case *channelRequestMsg:
  112. req := ChannelRequest{
  113. Request: packet.Request,
  114. WantReply: packet.WantReply,
  115. Payload: packet.RequestSpecificData,
  116. }
  117. c.pendingRequests = append(c.pendingRequests, req)
  118. c.cond.Signal()
  119. case *channelCloseMsg:
  120. c.theyClosed = true
  121. c.cond.Signal()
  122. case *channelEOFMsg:
  123. c.theySentEOF = true
  124. c.cond.Signal()
  125. case *windowAdjustMsg:
  126. c.theirWindow += packet.AdditionalBytes
  127. c.cond.Signal()
  128. default:
  129. panic("unknown packet type")
  130. }
  131. }
  132. func (c *serverChan) handleData(data []byte) {
  133. c.lock.Lock()
  134. defer c.lock.Unlock()
  135. // The other side should never send us more than our window.
  136. if len(data)+c.length > len(c.pendingData) {
  137. // TODO(agl): we should tear down the channel with a protocol
  138. // error.
  139. return
  140. }
  141. c.myWindow -= uint32(len(data))
  142. for i := 0; i < 2; i++ {
  143. tail := c.head + c.length
  144. if tail >= len(c.pendingData) {
  145. tail -= len(c.pendingData)
  146. }
  147. n := copy(c.pendingData[tail:], data)
  148. data = data[n:]
  149. c.length += n
  150. }
  151. c.cond.Signal()
  152. }
  153. func (c *serverChan) Stderr() io.Writer {
  154. return extendedDataChannel{c: c, t: extendedDataStderr}
  155. }
  156. // extendedDataChannel is an io.Writer that writes any data to c as extended
  157. // data of the given type.
  158. type extendedDataChannel struct {
  159. t extendedDataTypeCode
  160. c *serverChan
  161. }
  162. func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
  163. c := edc.c
  164. for len(data) > 0 {
  165. var space uint32
  166. if space, err = c.getWindowSpace(uint32(len(data))); err != nil {
  167. return 0, err
  168. }
  169. todo := data
  170. if uint32(len(todo)) > space {
  171. todo = todo[:space]
  172. }
  173. packet := make([]byte, 1+4+4+4+len(todo))
  174. packet[0] = msgChannelExtendedData
  175. marshalUint32(packet[1:], c.remoteId)
  176. marshalUint32(packet[5:], uint32(edc.t))
  177. marshalUint32(packet[9:], uint32(len(todo)))
  178. copy(packet[13:], todo)
  179. c.serverConn.lock.Lock()
  180. err = c.serverConn.writePacket(packet)
  181. c.serverConn.lock.Unlock()
  182. if err != nil {
  183. return
  184. }
  185. n += len(todo)
  186. data = data[len(todo):]
  187. }
  188. return
  189. }
  190. func (c *serverChan) Read(data []byte) (n int, err error) {
  191. n, err, windowAdjustment := c.read(data)
  192. if windowAdjustment > 0 {
  193. packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
  194. PeersId: c.remoteId,
  195. AdditionalBytes: windowAdjustment,
  196. })
  197. c.serverConn.lock.Lock()
  198. err = c.serverConn.writePacket(packet)
  199. c.serverConn.lock.Unlock()
  200. if err != nil {
  201. return
  202. }
  203. }
  204. return
  205. }
  206. func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint32) {
  207. c.lock.Lock()
  208. defer c.lock.Unlock()
  209. if c.err != nil {
  210. return 0, c.err, 0
  211. }
  212. for {
  213. if c.theySentEOF || c.theyClosed || c.dead {
  214. return 0, io.EOF, 0
  215. }
  216. if len(c.pendingRequests) > 0 {
  217. req := c.pendingRequests[0]
  218. if len(c.pendingRequests) == 1 {
  219. c.pendingRequests = nil
  220. } else {
  221. oldPendingRequests := c.pendingRequests
  222. c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
  223. copy(c.pendingRequests, oldPendingRequests[1:])
  224. }
  225. return 0, req, 0
  226. }
  227. if c.length > 0 {
  228. tail := min(c.head+c.length, len(c.pendingData))
  229. n = copy(data, c.pendingData[c.head:tail])
  230. c.head += n
  231. c.length -= n
  232. if c.head == len(c.pendingData) {
  233. c.head = 0
  234. }
  235. windowAdjustment = uint32(len(c.pendingData)-c.length) - c.myWindow
  236. if windowAdjustment < uint32(len(c.pendingData)/2) {
  237. windowAdjustment = 0
  238. }
  239. c.myWindow += windowAdjustment
  240. return
  241. }
  242. c.cond.Wait()
  243. }
  244. panic("unreachable")
  245. }
  246. // getWindowSpace takes, at most, max bytes of space from the peer's window. It
  247. // returns the number of bytes actually reserved.
  248. func (c *serverChan) getWindowSpace(max uint32) (uint32, error) {
  249. c.lock.Lock()
  250. defer c.lock.Unlock()
  251. for {
  252. if c.dead || c.weClosed {
  253. return 0, io.EOF
  254. }
  255. if c.theirWindow > 0 {
  256. break
  257. }
  258. c.cond.Wait()
  259. }
  260. taken := c.theirWindow
  261. if taken > max {
  262. taken = max
  263. }
  264. c.theirWindow -= taken
  265. return taken, nil
  266. }
  267. func (c *serverChan) Write(data []byte) (n int, err error) {
  268. for len(data) > 0 {
  269. var space uint32
  270. if space, err = c.getWindowSpace(uint32(len(data))); err != nil {
  271. return 0, err
  272. }
  273. todo := data
  274. if uint32(len(todo)) > space {
  275. todo = todo[:space]
  276. }
  277. packet := make([]byte, 1+4+4+len(todo))
  278. packet[0] = msgChannelData
  279. marshalUint32(packet[1:], c.remoteId)
  280. marshalUint32(packet[5:], uint32(len(todo)))
  281. copy(packet[9:], todo)
  282. c.serverConn.lock.Lock()
  283. err = c.serverConn.writePacket(packet)
  284. c.serverConn.lock.Unlock()
  285. if err != nil {
  286. return
  287. }
  288. n += len(todo)
  289. data = data[len(todo):]
  290. }
  291. return
  292. }
  293. func (c *serverChan) Close() error {
  294. c.serverConn.lock.Lock()
  295. defer c.serverConn.lock.Unlock()
  296. if c.serverConn.err != nil {
  297. return c.serverConn.err
  298. }
  299. if c.weClosed {
  300. return errors.New("ssh: channel already closed")
  301. }
  302. c.weClosed = true
  303. closeMsg := channelCloseMsg{
  304. PeersId: c.remoteId,
  305. }
  306. return c.serverConn.writePacket(marshal(msgChannelClose, closeMsg))
  307. }
  308. func (c *serverChan) AckRequest(ok bool) error {
  309. c.serverConn.lock.Lock()
  310. defer c.serverConn.lock.Unlock()
  311. if c.serverConn.err != nil {
  312. return c.serverConn.err
  313. }
  314. if !ok {
  315. ack := channelRequestFailureMsg{
  316. PeersId: c.remoteId,
  317. }
  318. return c.serverConn.writePacket(marshal(msgChannelFailure, ack))
  319. }
  320. ack := channelRequestSuccessMsg{
  321. PeersId: c.remoteId,
  322. }
  323. return c.serverConn.writePacket(marshal(msgChannelSuccess, ack))
  324. }
  325. func (c *serverChan) ChannelType() string {
  326. return c.chanType
  327. }
  328. func (c *serverChan) ExtraData() []byte {
  329. return c.extraData
  330. }