channel.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "errors"
  7. "io"
  8. "sync"
  9. )
  10. // extendedDataTypeCode identifies an OpenSSL extended data type. See RFC 4254,
  11. // section 5.2.
  12. type extendedDataTypeCode uint32
  13. // extendedDataStderr is the extended data type that is used for stderr.
  14. const extendedDataStderr extendedDataTypeCode = 1
  15. // A Channel is an ordered, reliable, duplex stream that is multiplexed over an
  16. // SSH connection. Channel.Read can return a ChannelRequest as an error.
  17. type Channel interface {
  18. // Accept accepts the channel creation request.
  19. Accept() error
  20. // Reject rejects the channel creation request. After calling this, no
  21. // other methods on the Channel may be called. If they are then the
  22. // peer is likely to signal a protocol error and drop the connection.
  23. Reject(reason RejectionReason, message string) error
  24. // Read may return a ChannelRequest as an error.
  25. Read(data []byte) (int, error)
  26. Write(data []byte) (int, error)
  27. Close() error
  28. // Stderr returns an io.Writer that writes to this channel with the
  29. // extended data type set to stderr.
  30. Stderr() io.Writer
  31. // AckRequest either sends an ack or nack to the channel request.
  32. AckRequest(ok bool) error
  33. // ChannelType returns the type of the channel, as supplied by the
  34. // client.
  35. ChannelType() string
  36. // ExtraData returns the arbitary payload for this channel, as supplied
  37. // by the client. This data is specific to the channel type.
  38. ExtraData() []byte
  39. }
  40. // ChannelRequest represents a request sent on a channel, outside of the normal
  41. // stream of bytes. It may result from calling Read on a Channel.
  42. type ChannelRequest struct {
  43. Request string
  44. WantReply bool
  45. Payload []byte
  46. }
  47. func (c ChannelRequest) Error() string {
  48. return "ssh: channel request received"
  49. }
  50. // RejectionReason is an enumeration used when rejecting channel creation
  51. // requests. See RFC 4254, section 5.1.
  52. type RejectionReason int
  53. const (
  54. Prohibited RejectionReason = iota + 1
  55. ConnectionFailed
  56. UnknownChannelType
  57. ResourceShortage
  58. )
  59. type channel struct {
  60. // immutable once created
  61. chanType string
  62. extraData []byte
  63. theyClosed bool
  64. theySentEOF bool
  65. weClosed bool
  66. dead bool
  67. serverConn *ServerConn
  68. myId, theirId uint32
  69. myWindow, theirWindow uint32
  70. maxPacketSize uint32
  71. err error
  72. pendingRequests []ChannelRequest
  73. pendingData []byte
  74. head, length int
  75. // This lock is inferior to serverConn.lock
  76. lock sync.Mutex
  77. cond *sync.Cond
  78. }
  79. func (c *channel) Accept() error {
  80. c.serverConn.lock.Lock()
  81. defer c.serverConn.lock.Unlock()
  82. if c.serverConn.err != nil {
  83. return c.serverConn.err
  84. }
  85. confirm := channelOpenConfirmMsg{
  86. PeersId: c.theirId,
  87. MyId: c.myId,
  88. MyWindow: c.myWindow,
  89. MaxPacketSize: c.maxPacketSize,
  90. }
  91. return c.serverConn.writePacket(marshal(msgChannelOpenConfirm, confirm))
  92. }
  93. func (c *channel) Reject(reason RejectionReason, message string) error {
  94. c.serverConn.lock.Lock()
  95. defer c.serverConn.lock.Unlock()
  96. if c.serverConn.err != nil {
  97. return c.serverConn.err
  98. }
  99. reject := channelOpenFailureMsg{
  100. PeersId: c.theirId,
  101. Reason: reason,
  102. Message: message,
  103. Language: "en",
  104. }
  105. return c.serverConn.writePacket(marshal(msgChannelOpenFailure, reject))
  106. }
  107. func (c *channel) handlePacket(packet interface{}) {
  108. c.lock.Lock()
  109. defer c.lock.Unlock()
  110. switch packet := packet.(type) {
  111. case *channelRequestMsg:
  112. req := ChannelRequest{
  113. Request: packet.Request,
  114. WantReply: packet.WantReply,
  115. Payload: packet.RequestSpecificData,
  116. }
  117. c.pendingRequests = append(c.pendingRequests, req)
  118. c.cond.Signal()
  119. case *channelCloseMsg:
  120. c.theyClosed = true
  121. c.cond.Signal()
  122. case *channelEOFMsg:
  123. c.theySentEOF = true
  124. c.cond.Signal()
  125. case *windowAdjustMsg:
  126. c.theirWindow += packet.AdditionalBytes
  127. c.cond.Signal()
  128. default:
  129. panic("unknown packet type")
  130. }
  131. }
  132. func (c *channel) handleData(data []byte) {
  133. c.lock.Lock()
  134. defer c.lock.Unlock()
  135. // The other side should never send us more than our window.
  136. if len(data)+c.length > len(c.pendingData) {
  137. // TODO(agl): we should tear down the channel with a protocol
  138. // error.
  139. return
  140. }
  141. c.myWindow -= uint32(len(data))
  142. for i := 0; i < 2; i++ {
  143. tail := c.head + c.length
  144. if tail >= len(c.pendingData) {
  145. tail -= len(c.pendingData)
  146. }
  147. n := copy(c.pendingData[tail:], data)
  148. data = data[n:]
  149. c.length += n
  150. }
  151. c.cond.Signal()
  152. }
  153. func (c *channel) Stderr() io.Writer {
  154. return extendedDataChannel{c: c, t: extendedDataStderr}
  155. }
  156. // extendedDataChannel is an io.Writer that writes any data to c as extended
  157. // data of the given type.
  158. type extendedDataChannel struct {
  159. t extendedDataTypeCode
  160. c *channel
  161. }
  162. func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
  163. c := edc.c
  164. for len(data) > 0 {
  165. var space uint32
  166. if space, err = c.getWindowSpace(uint32(len(data))); err != nil {
  167. return 0, err
  168. }
  169. todo := data
  170. if uint32(len(todo)) > space {
  171. todo = todo[:space]
  172. }
  173. packet := make([]byte, 1+4+4+4+len(todo))
  174. packet[0] = msgChannelExtendedData
  175. marshalUint32(packet[1:], c.theirId)
  176. marshalUint32(packet[5:], uint32(edc.t))
  177. marshalUint32(packet[9:], uint32(len(todo)))
  178. copy(packet[13:], todo)
  179. c.serverConn.lock.Lock()
  180. err = c.serverConn.writePacket(packet)
  181. c.serverConn.lock.Unlock()
  182. if err != nil {
  183. return
  184. }
  185. n += len(todo)
  186. data = data[len(todo):]
  187. }
  188. return
  189. }
  190. func (c *channel) Read(data []byte) (n int, err error) {
  191. c.lock.Lock()
  192. defer c.lock.Unlock()
  193. if c.err != nil {
  194. return 0, c.err
  195. }
  196. for {
  197. if c.theySentEOF || c.theyClosed || c.dead {
  198. return 0, io.EOF
  199. }
  200. if len(c.pendingRequests) > 0 {
  201. req := c.pendingRequests[0]
  202. if len(c.pendingRequests) == 1 {
  203. c.pendingRequests = nil
  204. } else {
  205. oldPendingRequests := c.pendingRequests
  206. c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
  207. copy(c.pendingRequests, oldPendingRequests[1:])
  208. }
  209. return 0, req
  210. }
  211. if c.length > 0 {
  212. tail := min(c.head+c.length, len(c.pendingData))
  213. n = copy(data, c.pendingData[c.head:tail])
  214. c.head += n
  215. c.length -= n
  216. if c.head == len(c.pendingData) {
  217. c.head = 0
  218. }
  219. windowAdjustment := uint32(len(c.pendingData)-c.length) - c.myWindow
  220. if windowAdjustment >= uint32(len(c.pendingData)/2) {
  221. packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
  222. PeersId: c.theirId,
  223. AdditionalBytes: windowAdjustment,
  224. })
  225. c.serverConn.lock.Lock()
  226. err = c.serverConn.writePacket(packet)
  227. c.serverConn.lock.Unlock()
  228. if err != nil {
  229. return
  230. }
  231. c.myWindow += windowAdjustment
  232. }
  233. return
  234. }
  235. c.cond.Wait()
  236. }
  237. panic("unreachable")
  238. }
  239. // getWindowSpace takes, at most, max bytes of space from the peer's window. It
  240. // returns the number of bytes actually reserved.
  241. func (c *channel) getWindowSpace(max uint32) (uint32, error) {
  242. c.lock.Lock()
  243. defer c.lock.Unlock()
  244. for {
  245. if c.dead || c.weClosed {
  246. return 0, io.EOF
  247. }
  248. if c.theirWindow > 0 {
  249. break
  250. }
  251. c.cond.Wait()
  252. }
  253. taken := c.theirWindow
  254. if taken > max {
  255. taken = max
  256. }
  257. c.theirWindow -= taken
  258. return taken, nil
  259. }
  260. func (c *channel) Write(data []byte) (n int, err error) {
  261. for len(data) > 0 {
  262. var space uint32
  263. if space, err = c.getWindowSpace(uint32(len(data))); err != nil {
  264. return 0, err
  265. }
  266. todo := data
  267. if uint32(len(todo)) > space {
  268. todo = todo[:space]
  269. }
  270. packet := make([]byte, 1+4+4+len(todo))
  271. packet[0] = msgChannelData
  272. marshalUint32(packet[1:], c.theirId)
  273. marshalUint32(packet[5:], uint32(len(todo)))
  274. copy(packet[9:], todo)
  275. c.serverConn.lock.Lock()
  276. err = c.serverConn.writePacket(packet)
  277. c.serverConn.lock.Unlock()
  278. if err != nil {
  279. return
  280. }
  281. n += len(todo)
  282. data = data[len(todo):]
  283. }
  284. return
  285. }
  286. func (c *channel) Close() error {
  287. c.serverConn.lock.Lock()
  288. defer c.serverConn.lock.Unlock()
  289. if c.serverConn.err != nil {
  290. return c.serverConn.err
  291. }
  292. if c.weClosed {
  293. return errors.New("ssh: channel already closed")
  294. }
  295. c.weClosed = true
  296. closeMsg := channelCloseMsg{
  297. PeersId: c.theirId,
  298. }
  299. return c.serverConn.writePacket(marshal(msgChannelClose, closeMsg))
  300. }
  301. func (c *channel) AckRequest(ok bool) error {
  302. c.serverConn.lock.Lock()
  303. defer c.serverConn.lock.Unlock()
  304. if c.serverConn.err != nil {
  305. return c.serverConn.err
  306. }
  307. if !ok {
  308. ack := channelRequestFailureMsg{
  309. PeersId: c.theirId,
  310. }
  311. return c.serverConn.writePacket(marshal(msgChannelFailure, ack))
  312. }
  313. ack := channelRequestSuccessMsg{
  314. PeersId: c.theirId,
  315. }
  316. return c.serverConn.writePacket(marshal(msgChannelSuccess, ack))
  317. }
  318. func (c *channel) ChannelType() string {
  319. return c.chanType
  320. }
  321. func (c *channel) ExtraData() []byte {
  322. return c.extraData
  323. }