channel.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "errors"
  7. "io"
  8. "sync"
  9. )
  // extendedDataTypeCode identifies an SSH extended data type, as carried in
// channel extended-data messages. See RFC 4254, section 5.2.
type extendedDataTypeCode uint32

// extendedDataStderr is the extended data type code used for stderr
// (SSH_EXTENDED_DATA_STDERR in RFC 4254, section 5.2).
const extendedDataStderr = extendedDataTypeCode(1)
  15. // A Channel is an ordered, reliable, duplex stream that is multiplexed over an
  16. // SSH connection. Channel.Read can return a ChannelRequest as an error.
  17. type Channel interface {
  18. // Accept accepts the channel creation request.
  19. Accept() error
  20. // Reject rejects the channel creation request. After calling this, no
  21. // other methods on the Channel may be called. If they are then the
  22. // peer is likely to signal a protocol error and drop the connection.
  23. Reject(reason RejectionReason, message string) error
  24. // Read may return a ChannelRequest as an error.
  25. Read(data []byte) (int, error)
  26. Write(data []byte) (int, error)
  27. Close() error
  28. // Stderr returns an io.Writer that writes to this channel with the
  29. // extended data type set to stderr.
  30. Stderr() io.Writer
  31. // AckRequest either sends an ack or nack to the channel request.
  32. AckRequest(ok bool) error
  33. // ChannelType returns the type of the channel, as supplied by the
  34. // client.
  35. ChannelType() string
  36. // ExtraData returns the arbitary payload for this channel, as supplied
  37. // by the client. This data is specific to the channel type.
  38. ExtraData() []byte
  39. }
  // ChannelRequest represents a request sent on a channel, outside of the normal
// stream of bytes. It may result from calling Read on a Channel, which reports
// it through Read's error result.
type ChannelRequest struct {
	Request   string // request type name
	WantReply bool   // whether the peer asked for a reply
	Payload   []byte // request-specific data
}

// Error implements the error interface so that a ChannelRequest can be
// returned from Read. The message is fixed and does not describe the request.
func (ChannelRequest) Error() string {
	const msg = "ssh: channel request received"
	return msg
}
  // RejectionReason is an enumeration used when rejecting channel creation
// requests. Its values are the reason codes of RFC 4254, section 5.1.
type RejectionReason int

// Reason codes accepted by Channel.Reject.
const (
	Prohibited         RejectionReason = 1 // SSH_OPEN_ADMINISTRATIVELY_PROHIBITED
	ConnectionFailed   RejectionReason = 2 // SSH_OPEN_CONNECT_FAILED
	UnknownChannelType RejectionReason = 3 // SSH_OPEN_UNKNOWN_CHANNEL_TYPE
	ResourceShortage   RejectionReason = 4 // SSH_OPEN_RESOURCE_SHORTAGE
)
  59. type channel struct {
  60. conn // the underlying transport
  61. localId, remoteId uint32
  62. theyClosed bool // indicates the close msg has been received from the remote side
  63. weClosed bool // incidates the close msg has been sent from our side
  64. theySentEOF bool // used by serverChan
  65. dead bool // used by ServerChan to force close
  66. }
  67. func (c *channel) sendWindowAdj(n int) error {
  68. msg := windowAdjustMsg{
  69. PeersId: c.remoteId,
  70. AdditionalBytes: uint32(n),
  71. }
  72. return c.writePacket(marshal(msgChannelWindowAdjust, msg))
  73. }
  74. // sendClose signals the intent to close the channel.
  75. func (c *channel) sendClose() error {
  76. return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
  77. PeersId: c.remoteId,
  78. }))
  79. }
  80. // sendEOF sends EOF to the server. RFC 4254 Section 5.3
  81. func (c *channel) sendEOF() error {
  82. return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
  83. PeersId: c.remoteId,
  84. }))
  85. }
  86. func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string) error {
  87. reject := channelOpenFailureMsg{
  88. PeersId: c.remoteId,
  89. Reason: reason,
  90. Message: message,
  91. Language: "en",
  92. }
  93. return c.writePacket(marshal(msgChannelOpenFailure, reject))
  94. }
  95. type serverChan struct {
  96. channel
  97. // immutable once created
  98. chanType string
  99. extraData []byte
  100. serverConn *ServerConn
  101. myWindow, theirWindow uint32
  102. maxPacketSize uint32
  103. err error
  104. pendingRequests []ChannelRequest
  105. pendingData []byte
  106. head, length int
  107. // This lock is inferior to serverConn.lock
  108. cond *sync.Cond
  109. }
  110. func (c *serverChan) Accept() error {
  111. c.serverConn.lock.Lock()
  112. defer c.serverConn.lock.Unlock()
  113. if c.serverConn.err != nil {
  114. return c.serverConn.err
  115. }
  116. confirm := channelOpenConfirmMsg{
  117. PeersId: c.remoteId,
  118. MyId: c.localId,
  119. MyWindow: c.myWindow,
  120. MaxPacketSize: c.maxPacketSize,
  121. }
  122. return c.writePacket(marshal(msgChannelOpenConfirm, confirm))
  123. }
  124. func (c *serverChan) Reject(reason RejectionReason, message string) error {
  125. c.serverConn.lock.Lock()
  126. defer c.serverConn.lock.Unlock()
  127. if c.serverConn.err != nil {
  128. return c.serverConn.err
  129. }
  130. return c.sendChannelOpenFailure(reason, message)
  131. }
  132. func (c *serverChan) handlePacket(packet interface{}) {
  133. c.cond.L.Lock()
  134. defer c.cond.L.Unlock()
  135. switch packet := packet.(type) {
  136. case *channelRequestMsg:
  137. req := ChannelRequest{
  138. Request: packet.Request,
  139. WantReply: packet.WantReply,
  140. Payload: packet.RequestSpecificData,
  141. }
  142. c.pendingRequests = append(c.pendingRequests, req)
  143. c.cond.Signal()
  144. case *channelCloseMsg:
  145. c.theyClosed = true
  146. c.cond.Signal()
  147. case *channelEOFMsg:
  148. c.theySentEOF = true
  149. c.cond.Signal()
  150. case *windowAdjustMsg:
  151. c.theirWindow += packet.AdditionalBytes
  152. c.cond.Signal()
  153. default:
  154. panic("unknown packet type")
  155. }
  156. }
  157. func (c *serverChan) handleData(data []byte) {
  158. c.cond.L.Lock()
  159. defer c.cond.L.Unlock()
  160. // The other side should never send us more than our window.
  161. if len(data)+c.length > len(c.pendingData) {
  162. // TODO(agl): we should tear down the channel with a protocol
  163. // error.
  164. return
  165. }
  166. c.myWindow -= uint32(len(data))
  167. for i := 0; i < 2; i++ {
  168. tail := c.head + c.length
  169. if tail >= len(c.pendingData) {
  170. tail -= len(c.pendingData)
  171. }
  172. n := copy(c.pendingData[tail:], data)
  173. data = data[n:]
  174. c.length += n
  175. }
  176. c.cond.Signal()
  177. }
  178. func (c *serverChan) Stderr() io.Writer {
  179. return extendedDataChannel{c: c, t: extendedDataStderr}
  180. }
  181. // extendedDataChannel is an io.Writer that writes any data to c as extended
  182. // data of the given type.
  183. type extendedDataChannel struct {
  184. t extendedDataTypeCode
  185. c *serverChan
  186. }
  187. func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
  188. c := edc.c
  189. for len(data) > 0 {
  190. var space uint32
  191. if space, err = c.getWindowSpace(uint32(len(data))); err != nil {
  192. return 0, err
  193. }
  194. todo := data
  195. if uint32(len(todo)) > space {
  196. todo = todo[:space]
  197. }
  198. packet := make([]byte, 1+4+4+4+len(todo))
  199. packet[0] = msgChannelExtendedData
  200. marshalUint32(packet[1:], c.remoteId)
  201. marshalUint32(packet[5:], uint32(edc.t))
  202. marshalUint32(packet[9:], uint32(len(todo)))
  203. copy(packet[13:], todo)
  204. if err = c.writePacket(packet); err != nil {
  205. return
  206. }
  207. n += len(todo)
  208. data = data[len(todo):]
  209. }
  210. return
  211. }
  212. func (c *serverChan) Read(data []byte) (n int, err error) {
  213. n, err, windowAdjustment := c.read(data)
  214. if windowAdjustment > 0 {
  215. packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
  216. PeersId: c.remoteId,
  217. AdditionalBytes: windowAdjustment,
  218. })
  219. err = c.writePacket(packet)
  220. }
  221. return
  222. }
  223. func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint32) {
  224. c.cond.L.Lock()
  225. defer c.cond.L.Unlock()
  226. if c.err != nil {
  227. return 0, c.err, 0
  228. }
  229. for {
  230. if c.theySentEOF || c.theyClosed || c.dead {
  231. return 0, io.EOF, 0
  232. }
  233. if len(c.pendingRequests) > 0 {
  234. req := c.pendingRequests[0]
  235. if len(c.pendingRequests) == 1 {
  236. c.pendingRequests = nil
  237. } else {
  238. oldPendingRequests := c.pendingRequests
  239. c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
  240. copy(c.pendingRequests, oldPendingRequests[1:])
  241. }
  242. return 0, req, 0
  243. }
  244. if c.length > 0 {
  245. tail := min(c.head+c.length, len(c.pendingData))
  246. n = copy(data, c.pendingData[c.head:tail])
  247. c.head += n
  248. c.length -= n
  249. if c.head == len(c.pendingData) {
  250. c.head = 0
  251. }
  252. windowAdjustment = uint32(len(c.pendingData)-c.length) - c.myWindow
  253. if windowAdjustment < uint32(len(c.pendingData)/2) {
  254. windowAdjustment = 0
  255. }
  256. c.myWindow += windowAdjustment
  257. return
  258. }
  259. c.cond.Wait()
  260. }
  261. panic("unreachable")
  262. }
  263. // getWindowSpace takes, at most, max bytes of space from the peer's window. It
  264. // returns the number of bytes actually reserved.
  265. func (c *serverChan) getWindowSpace(max uint32) (uint32, error) {
  266. c.cond.L.Lock()
  267. defer c.cond.L.Unlock()
  268. for {
  269. if c.dead || c.weClosed {
  270. return 0, io.EOF
  271. }
  272. if c.theirWindow > 0 {
  273. break
  274. }
  275. c.cond.Wait()
  276. }
  277. taken := c.theirWindow
  278. if taken > max {
  279. taken = max
  280. }
  281. c.theirWindow -= taken
  282. return taken, nil
  283. }
  284. func (c *serverChan) Write(data []byte) (n int, err error) {
  285. for len(data) > 0 {
  286. var space uint32
  287. if space, err = c.getWindowSpace(uint32(len(data))); err != nil {
  288. return 0, err
  289. }
  290. todo := data
  291. if uint32(len(todo)) > space {
  292. todo = todo[:space]
  293. }
  294. packet := make([]byte, 1+4+4+len(todo))
  295. packet[0] = msgChannelData
  296. marshalUint32(packet[1:], c.remoteId)
  297. marshalUint32(packet[5:], uint32(len(todo)))
  298. copy(packet[9:], todo)
  299. if err = c.writePacket(packet); err != nil {
  300. return
  301. }
  302. n += len(todo)
  303. data = data[len(todo):]
  304. }
  305. return
  306. }
  307. func (c *serverChan) Close() error {
  308. c.serverConn.lock.Lock()
  309. defer c.serverConn.lock.Unlock()
  310. if c.serverConn.err != nil {
  311. return c.serverConn.err
  312. }
  313. if c.weClosed {
  314. return errors.New("ssh: channel already closed")
  315. }
  316. c.weClosed = true
  317. return c.sendClose()
  318. }
  319. func (c *serverChan) AckRequest(ok bool) error {
  320. c.serverConn.lock.Lock()
  321. defer c.serverConn.lock.Unlock()
  322. if c.serverConn.err != nil {
  323. return c.serverConn.err
  324. }
  325. if !ok {
  326. ack := channelRequestFailureMsg{
  327. PeersId: c.remoteId,
  328. }
  329. return c.writePacket(marshal(msgChannelFailure, ack))
  330. }
  331. ack := channelRequestSuccessMsg{
  332. PeersId: c.remoteId,
  333. }
  334. return c.writePacket(marshal(msgChannelSuccess, ack))
  335. }
  336. func (c *serverChan) ChannelType() string {
  337. return c.chanType
  338. }
  339. func (c *serverChan) ExtraData() []byte {
  340. return c.extraData
  341. }