channel.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "errors"
  7. "fmt"
  8. "io"
  9. "sync"
  10. "sync/atomic"
  11. )
  // extendedDataTypeCode identifies an SSH extended data type, as carried in
  // SSH_MSG_CHANNEL_EXTENDED_DATA. See RFC 4254, section 5.2.
  type extendedDataTypeCode uint32

  const (
  	// extendedDataStderr is the extended data type that is used for stderr
  	// (SSH_EXTENDED_DATA_STDERR in RFC 4254, section 5.2).
  	extendedDataStderr extendedDataTypeCode = 1

  	// minPacketLength is the smallest MaxPacketSize accepted from a peer
  	// when it confirms a channel open request.
  	minPacketLength = 9
  )
  21. // A Channel is an ordered, reliable, duplex stream that is multiplexed over an
  22. // SSH connection. Channel.Read can return a ChannelRequest as an error.
  23. type Channel interface {
  24. // Accept accepts the channel creation request.
  25. Accept() error
  26. // Reject rejects the channel creation request. After calling this, no
  27. // other methods on the Channel may be called. If they are then the
  28. // peer is likely to signal a protocol error and drop the connection.
  29. Reject(reason RejectionReason, message string) error
  30. // Read may return a ChannelRequest as an error.
  31. Read(data []byte) (int, error)
  32. Write(data []byte) (int, error)
  33. Close() error
  34. // Stderr returns an io.Writer that writes to this channel with the
  35. // extended data type set to stderr.
  36. Stderr() io.Writer
  37. // AckRequest either sends an ack or nack to the channel request.
  38. AckRequest(ok bool) error
  39. // ChannelType returns the type of the channel, as supplied by the
  40. // client.
  41. ChannelType() string
  42. // ExtraData returns the arbitary payload for this channel, as supplied
  43. // by the client. This data is specific to the channel type.
  44. ExtraData() []byte
  45. }
  // ChannelRequest represents a request sent on a channel, outside of the normal
  // stream of bytes. It may result from calling Read on a Channel, which is why
  // it satisfies the error interface.
  type ChannelRequest struct {
  	Request   string // request name, copied from the wire message
  	WantReply bool   // set when the peer asked for a reply
  	Payload   []byte // request-specific data
  }

  // Error implements the error interface so that a ChannelRequest can be
  // returned from Read. The text is the same for every request.
  func (req ChannelRequest) Error() string {
  	const text = "ssh: channel request received"
  	return text
  }
  // RejectionReason is an enumeration used when rejecting channel creation
  // requests. See RFC 4254, section 5.1.
  type RejectionReason uint32

  // Rejection reason codes, numbered as in RFC 4254, section 5.1.
  const (
  	// Prohibited corresponds to SSH_OPEN_ADMINISTRATIVELY_PROHIBITED.
  	Prohibited RejectionReason = 1
  	// ConnectionFailed corresponds to SSH_OPEN_CONNECT_FAILED.
  	ConnectionFailed RejectionReason = 2
  	// UnknownChannelType corresponds to SSH_OPEN_UNKNOWN_CHANNEL_TYPE.
  	UnknownChannelType RejectionReason = 3
  	// ResourceShortage corresponds to SSH_OPEN_RESOURCE_SHORTAGE.
  	ResourceShortage RejectionReason = 4
  )
  65. type channel struct {
  66. conn // the underlying transport
  67. localId, remoteId uint32
  68. remoteWin window
  69. maxPacket uint32
  70. isclosed uint32 // atomic bool, non zero if true
  71. }
  72. func (c *channel) sendWindowAdj(n int) error {
  73. msg := windowAdjustMsg{
  74. PeersId: c.remoteId,
  75. AdditionalBytes: uint32(n),
  76. }
  77. return c.writePacket(marshal(msgChannelWindowAdjust, msg))
  78. }
  79. // sendEOF sends EOF to the server. RFC 4254 Section 5.3
  80. func (c *channel) sendEOF() error {
  81. return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
  82. PeersId: c.remoteId,
  83. }))
  84. }
  85. func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string) error {
  86. reject := channelOpenFailureMsg{
  87. PeersId: c.remoteId,
  88. Reason: reason,
  89. Message: message,
  90. Language: "en",
  91. }
  92. return c.writePacket(marshal(msgChannelOpenFailure, reject))
  93. }
  94. func (c *channel) writePacket(b []byte) error {
  95. if c.closed() {
  96. return io.EOF
  97. }
  98. if uint32(len(b)) > c.maxPacket {
  99. return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
  100. }
  101. return c.conn.writePacket(b)
  102. }
  103. func (c *channel) closed() bool {
  104. return atomic.LoadUint32(&c.isclosed) > 0
  105. }
  106. func (c *channel) setClosed() bool {
  107. return atomic.CompareAndSwapUint32(&c.isclosed, 0, 1)
  108. }
  109. type serverChan struct {
  110. channel
  111. // immutable once created
  112. chanType string
  113. extraData []byte
  114. serverConn *ServerConn
  115. myWindow uint32
  116. weClosed bool // incidates the close msg has been sent from our side
  117. theyClosed bool // indicates the close msg has been received from the remote side
  118. theySentEOF bool
  119. dead bool
  120. err error
  121. pendingRequests []ChannelRequest
  122. pendingData []byte
  123. head, length int
  124. // This lock is inferior to serverConn.lock
  125. cond *sync.Cond
  126. }
  127. func (c *serverChan) Accept() error {
  128. c.serverConn.lock.Lock()
  129. defer c.serverConn.lock.Unlock()
  130. if c.serverConn.err != nil {
  131. return c.serverConn.err
  132. }
  133. confirm := channelOpenConfirmMsg{
  134. PeersId: c.remoteId,
  135. MyId: c.localId,
  136. MyWindow: c.myWindow,
  137. MaxPacketSize: c.maxPacket,
  138. }
  139. return c.writePacket(marshal(msgChannelOpenConfirm, confirm))
  140. }
  141. func (c *serverChan) Reject(reason RejectionReason, message string) error {
  142. c.serverConn.lock.Lock()
  143. defer c.serverConn.lock.Unlock()
  144. if c.serverConn.err != nil {
  145. return c.serverConn.err
  146. }
  147. return c.sendChannelOpenFailure(reason, message)
  148. }
  149. func (c *serverChan) handlePacket(packet interface{}) {
  150. c.cond.L.Lock()
  151. defer c.cond.L.Unlock()
  152. switch packet := packet.(type) {
  153. case *channelRequestMsg:
  154. req := ChannelRequest{
  155. Request: packet.Request,
  156. WantReply: packet.WantReply,
  157. Payload: packet.RequestSpecificData,
  158. }
  159. c.pendingRequests = append(c.pendingRequests, req)
  160. c.cond.Signal()
  161. case *channelCloseMsg:
  162. c.theyClosed = true
  163. c.cond.Signal()
  164. case *channelEOFMsg:
  165. c.theySentEOF = true
  166. c.cond.Signal()
  167. case *windowAdjustMsg:
  168. if !c.remoteWin.add(packet.AdditionalBytes) {
  169. panic("illegal window update")
  170. }
  171. default:
  172. panic("unknown packet type")
  173. }
  174. }
  175. func (c *serverChan) handleData(data []byte) {
  176. c.cond.L.Lock()
  177. defer c.cond.L.Unlock()
  178. // The other side should never send us more than our window.
  179. if len(data)+c.length > len(c.pendingData) {
  180. // TODO(agl): we should tear down the channel with a protocol
  181. // error.
  182. return
  183. }
  184. c.myWindow -= uint32(len(data))
  185. for i := 0; i < 2; i++ {
  186. tail := c.head + c.length
  187. if tail >= len(c.pendingData) {
  188. tail -= len(c.pendingData)
  189. }
  190. n := copy(c.pendingData[tail:], data)
  191. data = data[n:]
  192. c.length += n
  193. }
  194. c.cond.Signal()
  195. }
  196. func (c *serverChan) Stderr() io.Writer {
  197. return extendedDataChannel{c: c, t: extendedDataStderr}
  198. }
  199. // extendedDataChannel is an io.Writer that writes any data to c as extended
  200. // data of the given type.
  201. type extendedDataChannel struct {
  202. t extendedDataTypeCode
  203. c *serverChan
  204. }
  205. func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
  206. const headerLength = 13 // 1 byte message type, 4 bytes remoteId, 4 bytes extended message type, 4 bytes data length
  207. c := edc.c
  208. for len(data) > 0 {
  209. space := uint32(min(int(c.maxPacket-headerLength), len(data)))
  210. if space, err = c.getWindowSpace(space); err != nil {
  211. return 0, err
  212. }
  213. todo := data
  214. if uint32(len(todo)) > space {
  215. todo = todo[:space]
  216. }
  217. packet := make([]byte, headerLength+len(todo))
  218. packet[0] = msgChannelExtendedData
  219. marshalUint32(packet[1:], c.remoteId)
  220. marshalUint32(packet[5:], uint32(edc.t))
  221. marshalUint32(packet[9:], uint32(len(todo)))
  222. copy(packet[13:], todo)
  223. if err = c.writePacket(packet); err != nil {
  224. return
  225. }
  226. n += len(todo)
  227. data = data[len(todo):]
  228. }
  229. return
  230. }
  231. func (c *serverChan) Read(data []byte) (n int, err error) {
  232. n, err, windowAdjustment := c.read(data)
  233. if windowAdjustment > 0 {
  234. packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
  235. PeersId: c.remoteId,
  236. AdditionalBytes: windowAdjustment,
  237. })
  238. err = c.writePacket(packet)
  239. }
  240. return
  241. }
  242. func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint32) {
  243. c.cond.L.Lock()
  244. defer c.cond.L.Unlock()
  245. if c.err != nil {
  246. return 0, c.err, 0
  247. }
  248. for {
  249. if c.theySentEOF || c.theyClosed || c.dead {
  250. return 0, io.EOF, 0
  251. }
  252. if len(c.pendingRequests) > 0 {
  253. req := c.pendingRequests[0]
  254. if len(c.pendingRequests) == 1 {
  255. c.pendingRequests = nil
  256. } else {
  257. oldPendingRequests := c.pendingRequests
  258. c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
  259. copy(c.pendingRequests, oldPendingRequests[1:])
  260. }
  261. return 0, req, 0
  262. }
  263. if c.length > 0 {
  264. tail := min(c.head+c.length, len(c.pendingData))
  265. n = copy(data, c.pendingData[c.head:tail])
  266. c.head += n
  267. c.length -= n
  268. if c.head == len(c.pendingData) {
  269. c.head = 0
  270. }
  271. windowAdjustment = uint32(len(c.pendingData)-c.length) - c.myWindow
  272. if windowAdjustment < uint32(len(c.pendingData)/2) {
  273. windowAdjustment = 0
  274. }
  275. c.myWindow += windowAdjustment
  276. return
  277. }
  278. c.cond.Wait()
  279. }
  280. panic("unreachable")
  281. }
  282. // getWindowSpace takes, at most, max bytes of space from the peer's window. It
  283. // returns the number of bytes actually reserved.
  284. func (c *serverChan) getWindowSpace(max uint32) (uint32, error) {
  285. var err error
  286. // TODO(dfc) This lock and check of c.weClosed is necessary because unlike
  287. // clientChan, c.weClosed is observed by more than one goroutine.
  288. c.cond.L.Lock()
  289. if c.dead || c.weClosed {
  290. err = io.EOF
  291. }
  292. c.cond.L.Unlock()
  293. if err != nil {
  294. return 0, err
  295. }
  296. return c.remoteWin.reserve(max), nil
  297. }
  298. func (c *serverChan) Write(data []byte) (n int, err error) {
  299. const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
  300. for len(data) > 0 {
  301. space := uint32(min(int(c.maxPacket-headerLength), len(data)))
  302. if space, err = c.getWindowSpace(space); err != nil {
  303. return 0, err
  304. }
  305. todo := data
  306. if uint32(len(todo)) > space {
  307. todo = todo[:space]
  308. }
  309. packet := make([]byte, headerLength+len(todo))
  310. packet[0] = msgChannelData
  311. marshalUint32(packet[1:], c.remoteId)
  312. marshalUint32(packet[5:], uint32(len(todo)))
  313. copy(packet[9:], todo)
  314. if err = c.writePacket(packet); err != nil {
  315. return
  316. }
  317. n += len(todo)
  318. data = data[len(todo):]
  319. }
  320. return
  321. }
  322. func (c *serverChan) Close() error {
  323. c.serverConn.lock.Lock()
  324. defer c.serverConn.lock.Unlock()
  325. if c.serverConn.err != nil {
  326. return c.serverConn.err
  327. }
  328. if c.weClosed {
  329. return errors.New("ssh: channel already closed")
  330. }
  331. c.weClosed = true
  332. return c.sendClose()
  333. }
  334. // sendClose signals the intent to close the channel.
  335. func (c *serverChan) sendClose() error {
  336. return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
  337. PeersId: c.remoteId,
  338. }))
  339. }
  340. func (c *serverChan) AckRequest(ok bool) error {
  341. c.serverConn.lock.Lock()
  342. defer c.serverConn.lock.Unlock()
  343. if c.serverConn.err != nil {
  344. return c.serverConn.err
  345. }
  346. if !ok {
  347. ack := channelRequestFailureMsg{
  348. PeersId: c.remoteId,
  349. }
  350. return c.writePacket(marshal(msgChannelFailure, ack))
  351. }
  352. ack := channelRequestSuccessMsg{
  353. PeersId: c.remoteId,
  354. }
  355. return c.writePacket(marshal(msgChannelSuccess, ack))
  356. }
  357. func (c *serverChan) ChannelType() string {
  358. return c.chanType
  359. }
  360. func (c *serverChan) ExtraData() []byte {
  361. return c.extraData
  362. }
  363. // A clientChan represents a single RFC 4254 channel multiplexed
  364. // over a SSH connection.
  365. type clientChan struct {
  366. channel
  367. stdin *chanWriter
  368. stdout *chanReader
  369. stderr *chanReader
  370. msg chan interface{}
  371. }
  372. // newClientChan returns a partially constructed *clientChan
  373. // using the local id provided. To be usable clientChan.remoteId
  374. // needs to be assigned once known.
  375. func newClientChan(cc conn, id uint32) *clientChan {
  376. c := &clientChan{
  377. channel: channel{
  378. conn: cc,
  379. localId: id,
  380. remoteWin: window{Cond: newCond()},
  381. },
  382. msg: make(chan interface{}, 16),
  383. }
  384. c.stdin = &chanWriter{
  385. channel: &c.channel,
  386. }
  387. c.stdout = &chanReader{
  388. channel: &c.channel,
  389. buffer: newBuffer(),
  390. }
  391. c.stderr = &chanReader{
  392. channel: &c.channel,
  393. buffer: newBuffer(),
  394. }
  395. return c
  396. }
  397. // waitForChannelOpenResponse, if successful, fills out
  398. // the remoteId and records any initial window advertisement.
  399. func (c *clientChan) waitForChannelOpenResponse() error {
  400. switch msg := (<-c.msg).(type) {
  401. case *channelOpenConfirmMsg:
  402. if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
  403. return errors.New("ssh: invalid MaxPacketSize from peer")
  404. }
  405. // fixup remoteId field
  406. c.remoteId = msg.MyId
  407. c.maxPacket = msg.MaxPacketSize
  408. c.remoteWin.add(msg.MyWindow)
  409. return nil
  410. case *channelOpenFailureMsg:
  411. return errors.New(safeString(msg.Message))
  412. }
  413. return errors.New("ssh: unexpected packet")
  414. }
  415. func (c *clientChan) Close() error {
  416. if !c.setClosed() {
  417. return errors.New("ssh: channel already closed")
  418. }
  419. c.stdout.eof()
  420. c.stderr.eof()
  421. close(c.msg)
  422. // TODO(dfc) step around channel.writePacket() because closed() is now true
  423. return c.channel.conn.writePacket(marshal(msgChannelClose, channelCloseMsg{
  424. PeersId: c.remoteId,
  425. }))
  426. }
  427. // A chanWriter represents the stdin of a remote process.
  428. type chanWriter struct {
  429. *channel
  430. }
  431. // Write writes data to the remote process's standard input.
  432. func (w *chanWriter) Write(data []byte) (written int, err error) {
  433. const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
  434. for len(data) > 0 {
  435. // never send more data than maxPacket even if
  436. // there is sufficent window.
  437. n := min(int(w.maxPacket-headerLength), len(data))
  438. n = int(w.remoteWin.reserve(uint32(n)))
  439. remoteId := w.remoteId
  440. packet := []byte{
  441. msgChannelData,
  442. byte(remoteId >> 24), byte(remoteId >> 16), byte(remoteId >> 8), byte(remoteId),
  443. byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
  444. }
  445. if err = w.writePacket(append(packet, data[:n]...)); err != nil {
  446. break
  447. }
  448. data = data[n:]
  449. written += n
  450. }
  451. return
  452. }
  // min returns the smaller of a and b.
  func min(a, b int) int {
  	if b < a {
  		return b
  	}
  	return a
  }
  459. func (w *chanWriter) Close() error {
  460. return w.sendEOF()
  461. }
  462. // A chanReader represents stdout or stderr of a remote process.
  463. type chanReader struct {
  464. *channel // the channel backing this reader
  465. *buffer
  466. }
  467. // Read reads data from the remote process's stdout or stderr.
  468. func (r *chanReader) Read(buf []byte) (int, error) {
  469. n, err := r.buffer.Read(buf)
  470. if err != nil {
  471. if err == io.EOF {
  472. return n, err
  473. }
  474. return 0, err
  475. }
  476. return n, r.sendWindowAdj(n)
  477. }