channel.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "errors"
  7. "fmt"
  8. "io"
  9. "sync"
  10. )
  // extendedDataTypeCode identifies an SSH extended data type. See RFC 4254,
// section 5.2.
type extendedDataTypeCode uint32

const (
	// extendedDataStderr is the extended data type that is used for stderr
	// (SSH_EXTENDED_DATA_STDERR in RFC 4254, section 5.2).
	extendedDataStderr extendedDataTypeCode = 1

	// minPacketLength is the smallest maximum-packet-size accepted from a
	// peer. It equals the size of a channel data message header: 1 byte
	// message type, 4 bytes channel id and 4 bytes data length.
	minPacketLength = 9
)
  20. // A Channel is an ordered, reliable, duplex stream that is multiplexed over an
  21. // SSH connection. Channel.Read can return a ChannelRequest as an error.
  22. type Channel interface {
  23. // Accept accepts the channel creation request.
  24. Accept() error
  25. // Reject rejects the channel creation request. After calling this, no
  26. // other methods on the Channel may be called. If they are then the
  27. // peer is likely to signal a protocol error and drop the connection.
  28. Reject(reason RejectionReason, message string) error
  29. // Read may return a ChannelRequest as an error.
  30. Read(data []byte) (int, error)
  31. Write(data []byte) (int, error)
  32. Close() error
  33. // Stderr returns an io.Writer that writes to this channel with the
  34. // extended data type set to stderr.
  35. Stderr() io.Writer
  36. // AckRequest either sends an ack or nack to the channel request.
  37. AckRequest(ok bool) error
  38. // ChannelType returns the type of the channel, as supplied by the
  39. // client.
  40. ChannelType() string
  41. // ExtraData returns the arbitary payload for this channel, as supplied
  42. // by the client. This data is specific to the channel type.
  43. ExtraData() []byte
  44. }
  // ChannelRequest describes an out-of-band request that arrived on a channel,
// as opposed to ordinary stream data. It implements error so that
// Channel.Read can hand it to the caller through its error result.
type ChannelRequest struct {
	Request   string // name of the request
	WantReply bool   // the request's want-reply flag as sent by the peer
	Payload   []byte // request-specific data
}

// Error implements the error interface. The message is constant; inspect the
// struct fields to find out which request was received.
func (ChannelRequest) Error() string {
	const msg = "ssh: channel request received"
	return msg
}
  // RejectionReason is an enumeration used when rejecting channel creation
// requests. Its values are the reason codes of RFC 4254, section 5.1.
type RejectionReason uint32

// Reason codes accepted by Channel.Reject. The RFC numbers them from 1.
const (
	Prohibited         RejectionReason = 1 // SSH_OPEN_ADMINISTRATIVELY_PROHIBITED
	ConnectionFailed   RejectionReason = 2 // SSH_OPEN_CONNECT_FAILED
	UnknownChannelType RejectionReason = 3 // SSH_OPEN_UNKNOWN_CHANNEL_TYPE
	ResourceShortage   RejectionReason = 4 // SSH_OPEN_RESOURCE_SHORTAGE
)
  64. type channel struct {
  65. conn // the underlying transport
  66. localId, remoteId uint32
  67. remoteWin window
  68. maxPacket uint32
  69. theyClosed bool // indicates the close msg has been received from the remote side
  70. weClosed bool // incidates the close msg has been sent from our side
  71. theySentEOF bool // used by serverChan
  72. dead bool // used by ServerChan to force close
  73. }
  74. func (c *channel) sendWindowAdj(n int) error {
  75. msg := windowAdjustMsg{
  76. PeersId: c.remoteId,
  77. AdditionalBytes: uint32(n),
  78. }
  79. return c.writePacket(marshal(msgChannelWindowAdjust, msg))
  80. }
  81. // sendClose signals the intent to close the channel.
  82. func (c *channel) sendClose() error {
  83. return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
  84. PeersId: c.remoteId,
  85. }))
  86. }
  87. // sendEOF sends EOF to the server. RFC 4254 Section 5.3
  88. func (c *channel) sendEOF() error {
  89. return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
  90. PeersId: c.remoteId,
  91. }))
  92. }
  93. func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string) error {
  94. reject := channelOpenFailureMsg{
  95. PeersId: c.remoteId,
  96. Reason: reason,
  97. Message: message,
  98. Language: "en",
  99. }
  100. return c.writePacket(marshal(msgChannelOpenFailure, reject))
  101. }
  102. func (c *channel) writePacket(b []byte) error {
  103. if uint32(len(b)) > c.maxPacket {
  104. return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
  105. }
  106. return c.conn.writePacket(b)
  107. }
  108. type serverChan struct {
  109. channel
  110. // immutable once created
  111. chanType string
  112. extraData []byte
  113. serverConn *ServerConn
  114. myWindow uint32
  115. err error
  116. pendingRequests []ChannelRequest
  117. pendingData []byte
  118. head, length int
  119. // This lock is inferior to serverConn.lock
  120. cond *sync.Cond
  121. }
  122. func (c *serverChan) Accept() error {
  123. c.serverConn.lock.Lock()
  124. defer c.serverConn.lock.Unlock()
  125. if c.serverConn.err != nil {
  126. return c.serverConn.err
  127. }
  128. confirm := channelOpenConfirmMsg{
  129. PeersId: c.remoteId,
  130. MyId: c.localId,
  131. MyWindow: c.myWindow,
  132. MaxPacketSize: c.maxPacket,
  133. }
  134. return c.writePacket(marshal(msgChannelOpenConfirm, confirm))
  135. }
  136. func (c *serverChan) Reject(reason RejectionReason, message string) error {
  137. c.serverConn.lock.Lock()
  138. defer c.serverConn.lock.Unlock()
  139. if c.serverConn.err != nil {
  140. return c.serverConn.err
  141. }
  142. return c.sendChannelOpenFailure(reason, message)
  143. }
  144. func (c *serverChan) handlePacket(packet interface{}) {
  145. c.cond.L.Lock()
  146. defer c.cond.L.Unlock()
  147. switch packet := packet.(type) {
  148. case *channelRequestMsg:
  149. req := ChannelRequest{
  150. Request: packet.Request,
  151. WantReply: packet.WantReply,
  152. Payload: packet.RequestSpecificData,
  153. }
  154. c.pendingRequests = append(c.pendingRequests, req)
  155. c.cond.Signal()
  156. case *channelCloseMsg:
  157. c.theyClosed = true
  158. c.cond.Signal()
  159. case *channelEOFMsg:
  160. c.theySentEOF = true
  161. c.cond.Signal()
  162. case *windowAdjustMsg:
  163. if !c.remoteWin.add(packet.AdditionalBytes) {
  164. panic("illegal window update")
  165. }
  166. default:
  167. panic("unknown packet type")
  168. }
  169. }
  170. func (c *serverChan) handleData(data []byte) {
  171. c.cond.L.Lock()
  172. defer c.cond.L.Unlock()
  173. // The other side should never send us more than our window.
  174. if len(data)+c.length > len(c.pendingData) {
  175. // TODO(agl): we should tear down the channel with a protocol
  176. // error.
  177. return
  178. }
  179. c.myWindow -= uint32(len(data))
  180. for i := 0; i < 2; i++ {
  181. tail := c.head + c.length
  182. if tail >= len(c.pendingData) {
  183. tail -= len(c.pendingData)
  184. }
  185. n := copy(c.pendingData[tail:], data)
  186. data = data[n:]
  187. c.length += n
  188. }
  189. c.cond.Signal()
  190. }
  191. func (c *serverChan) Stderr() io.Writer {
  192. return extendedDataChannel{c: c, t: extendedDataStderr}
  193. }
  194. // extendedDataChannel is an io.Writer that writes any data to c as extended
  195. // data of the given type.
  196. type extendedDataChannel struct {
  197. t extendedDataTypeCode
  198. c *serverChan
  199. }
  200. func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
  201. c := edc.c
  202. for len(data) > 0 {
  203. var space uint32
  204. if space, err = c.getWindowSpace(uint32(len(data))); err != nil {
  205. return 0, err
  206. }
  207. todo := data
  208. if uint32(len(todo)) > space {
  209. todo = todo[:space]
  210. }
  211. packet := make([]byte, 1+4+4+4+len(todo))
  212. packet[0] = msgChannelExtendedData
  213. marshalUint32(packet[1:], c.remoteId)
  214. marshalUint32(packet[5:], uint32(edc.t))
  215. marshalUint32(packet[9:], uint32(len(todo)))
  216. copy(packet[13:], todo)
  217. if err = c.writePacket(packet); err != nil {
  218. return
  219. }
  220. n += len(todo)
  221. data = data[len(todo):]
  222. }
  223. return
  224. }
  225. func (c *serverChan) Read(data []byte) (n int, err error) {
  226. n, err, windowAdjustment := c.read(data)
  227. if windowAdjustment > 0 {
  228. packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
  229. PeersId: c.remoteId,
  230. AdditionalBytes: windowAdjustment,
  231. })
  232. err = c.writePacket(packet)
  233. }
  234. return
  235. }
  236. func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint32) {
  237. c.cond.L.Lock()
  238. defer c.cond.L.Unlock()
  239. if c.err != nil {
  240. return 0, c.err, 0
  241. }
  242. for {
  243. if c.theySentEOF || c.theyClosed || c.dead {
  244. return 0, io.EOF, 0
  245. }
  246. if len(c.pendingRequests) > 0 {
  247. req := c.pendingRequests[0]
  248. if len(c.pendingRequests) == 1 {
  249. c.pendingRequests = nil
  250. } else {
  251. oldPendingRequests := c.pendingRequests
  252. c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
  253. copy(c.pendingRequests, oldPendingRequests[1:])
  254. }
  255. return 0, req, 0
  256. }
  257. if c.length > 0 {
  258. tail := min(c.head+c.length, len(c.pendingData))
  259. n = copy(data, c.pendingData[c.head:tail])
  260. c.head += n
  261. c.length -= n
  262. if c.head == len(c.pendingData) {
  263. c.head = 0
  264. }
  265. windowAdjustment = uint32(len(c.pendingData)-c.length) - c.myWindow
  266. if windowAdjustment < uint32(len(c.pendingData)/2) {
  267. windowAdjustment = 0
  268. }
  269. c.myWindow += windowAdjustment
  270. return
  271. }
  272. c.cond.Wait()
  273. }
  274. panic("unreachable")
  275. }
  276. // getWindowSpace takes, at most, max bytes of space from the peer's window. It
  277. // returns the number of bytes actually reserved.
  278. func (c *serverChan) getWindowSpace(max uint32) (uint32, error) {
  279. var err error
  280. // TODO(dfc) This lock and check of c.weClosed is necessary because unlike
  281. // clientChan, c.weClosed is observed by more than one goroutine.
  282. c.cond.L.Lock()
  283. if c.dead || c.weClosed {
  284. err = io.EOF
  285. }
  286. c.cond.L.Unlock()
  287. if err != nil {
  288. return 0, err
  289. }
  290. return c.remoteWin.reserve(max), nil
  291. }
  292. func (c *serverChan) Write(data []byte) (n int, err error) {
  293. for len(data) > 0 {
  294. var space uint32
  295. if space, err = c.getWindowSpace(uint32(len(data))); err != nil {
  296. return 0, err
  297. }
  298. todo := data
  299. if uint32(len(todo)) > space {
  300. todo = todo[:space]
  301. }
  302. packet := make([]byte, 1+4+4+len(todo))
  303. packet[0] = msgChannelData
  304. marshalUint32(packet[1:], c.remoteId)
  305. marshalUint32(packet[5:], uint32(len(todo)))
  306. copy(packet[9:], todo)
  307. if err = c.writePacket(packet); err != nil {
  308. return
  309. }
  310. n += len(todo)
  311. data = data[len(todo):]
  312. }
  313. return
  314. }
  315. func (c *serverChan) Close() error {
  316. c.serverConn.lock.Lock()
  317. defer c.serverConn.lock.Unlock()
  318. if c.serverConn.err != nil {
  319. return c.serverConn.err
  320. }
  321. if c.weClosed {
  322. return errors.New("ssh: channel already closed")
  323. }
  324. c.weClosed = true
  325. return c.sendClose()
  326. }
  327. func (c *serverChan) AckRequest(ok bool) error {
  328. c.serverConn.lock.Lock()
  329. defer c.serverConn.lock.Unlock()
  330. if c.serverConn.err != nil {
  331. return c.serverConn.err
  332. }
  333. if !ok {
  334. ack := channelRequestFailureMsg{
  335. PeersId: c.remoteId,
  336. }
  337. return c.writePacket(marshal(msgChannelFailure, ack))
  338. }
  339. ack := channelRequestSuccessMsg{
  340. PeersId: c.remoteId,
  341. }
  342. return c.writePacket(marshal(msgChannelSuccess, ack))
  343. }
  344. func (c *serverChan) ChannelType() string {
  345. return c.chanType
  346. }
  347. func (c *serverChan) ExtraData() []byte {
  348. return c.extraData
  349. }
  350. // A clientChan represents a single RFC 4254 channel multiplexed
  351. // over a SSH connection.
  352. type clientChan struct {
  353. channel
  354. stdin *chanWriter
  355. stdout *chanReader
  356. stderr *chanReader
  357. msg chan interface{}
  358. }
  359. // newClientChan returns a partially constructed *clientChan
  360. // using the local id provided. To be usable clientChan.remoteId
  361. // needs to be assigned once known.
  362. func newClientChan(cc conn, id uint32) *clientChan {
  363. c := &clientChan{
  364. channel: channel{
  365. conn: cc,
  366. localId: id,
  367. remoteWin: window{Cond: newCond()},
  368. },
  369. msg: make(chan interface{}, 16),
  370. }
  371. c.stdin = &chanWriter{
  372. channel: &c.channel,
  373. }
  374. c.stdout = &chanReader{
  375. channel: &c.channel,
  376. buffer: newBuffer(),
  377. }
  378. c.stderr = &chanReader{
  379. channel: &c.channel,
  380. buffer: newBuffer(),
  381. }
  382. return c
  383. }
  384. // waitForChannelOpenResponse, if successful, fills out
  385. // the remoteId and records any initial window advertisement.
  386. func (c *clientChan) waitForChannelOpenResponse() error {
  387. switch msg := (<-c.msg).(type) {
  388. case *channelOpenConfirmMsg:
  389. if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
  390. return errors.New("ssh: invalid MaxPacketSize from peer")
  391. }
  392. // fixup remoteId field
  393. c.remoteId = msg.MyId
  394. c.maxPacket = msg.MaxPacketSize
  395. c.remoteWin.add(msg.MyWindow)
  396. return nil
  397. case *channelOpenFailureMsg:
  398. return errors.New(safeString(msg.Message))
  399. }
  400. return errors.New("ssh: unexpected packet")
  401. }
  402. // Close closes the channel. This does not close the underlying connection.
  403. func (c *clientChan) Close() error {
  404. if !c.weClosed {
  405. c.weClosed = true
  406. return c.sendClose()
  407. }
  408. return nil
  409. }
  410. // A chanWriter represents the stdin of a remote process.
  411. type chanWriter struct {
  412. *channel
  413. }
  414. // Write writes data to the remote process's standard input.
  415. func (w *chanWriter) Write(data []byte) (written int, err error) {
  416. const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
  417. for len(data) > 0 {
  418. // never send more data than maxPacket even if
  419. // there is sufficent window.
  420. n := min(int(w.maxPacket-headerLength), len(data))
  421. n = int(w.remoteWin.reserve(uint32(n)))
  422. remoteId := w.remoteId
  423. packet := []byte{
  424. msgChannelData,
  425. byte(remoteId >> 24), byte(remoteId >> 16), byte(remoteId >> 8), byte(remoteId),
  426. byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
  427. }
  428. if err = w.writePacket(append(packet, data[:n]...)); err != nil {
  429. break
  430. }
  431. data = data[n:]
  432. written += n
  433. }
  434. return
  435. }
  // min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
  442. func (w *chanWriter) Close() error {
  443. return w.sendEOF()
  444. }
  445. // A chanReader represents stdout or stderr of a remote process.
  446. type chanReader struct {
  447. *channel // the channel backing this reader
  448. *buffer
  449. }
  450. // Read reads data from the remote process's stdout or stderr.
  451. func (r *chanReader) Read(buf []byte) (int, error) {
  452. n, err := r.buffer.Read(buf)
  453. if err != nil {
  454. if err == io.EOF {
  455. return n, err
  456. }
  457. return 0, err
  458. }
  459. return n, r.sendWindowAdj(n)
  460. }