channel.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "errors"
  7. "fmt"
  8. "io"
  9. "sync"
  10. "sync/atomic"
  11. )
  // extendedDataTypeCode identifies an SSH extended data type. See RFC 4254,
  // section 5.2.
  type extendedDataTypeCode uint32

  const (
  	// extendedDataStderr is the extended data type that is used for stderr
  	// (SSH_EXTENDED_DATA_STDERR in RFC 4254, section 5.2).
  	extendedDataStderr extendedDataTypeCode = 1

  	// minPacketLength is the smallest maximum packet size accepted from a
  	// peer. Note that it equals the SSH_MSG_CHANNEL_DATA header size
  	// (1 byte type + 4 bytes channel id + 4 bytes length), so a peer
  	// advertising exactly this value leaves no room for payload.
  	minPacketLength = 9

  	// channelMaxPacketSize defines the maximum packet size advertised in open messages
  	channelMaxPacketSize = 1 << 15 // RFC 4253 6.1, minimum 32 KiB

  	// channelWindowSize defines the window size advertised in open messages
  	// (64 maximum-size packets, i.e. 2 MiB).
  	channelWindowSize = 64 * channelMaxPacketSize // Like OpenSSH
  )
  25. // A Channel is an ordered, reliable, duplex stream that is multiplexed over an
  26. // SSH connection. Channel.Read can return a ChannelRequest as an error.
  27. type Channel interface {
  28. // Accept accepts the channel creation request.
  29. Accept() error
  30. // Reject rejects the channel creation request. After calling this, no
  31. // other methods on the Channel may be called. If they are then the
  32. // peer is likely to signal a protocol error and drop the connection.
  33. Reject(reason RejectionReason, message string) error
  34. // Read may return a ChannelRequest as an error.
  35. Read(data []byte) (int, error)
  36. Write(data []byte) (int, error)
  37. Close() error
  38. // Stderr returns an io.Writer that writes to this channel with the
  39. // extended data type set to stderr.
  40. Stderr() io.Writer
  41. // AckRequest either sends an ack or nack to the channel request.
  42. AckRequest(ok bool) error
  43. // ChannelType returns the type of the channel, as supplied by the
  44. // client.
  45. ChannelType() string
  46. // ExtraData returns the arbitrary payload for this channel, as supplied
  47. // by the client. This data is specific to the channel type.
  48. ExtraData() []byte
  49. }
  // ChannelRequest is an out-of-band request received on a channel, separate
  // from the channel's byte stream. Read may return one as its error value.
  type ChannelRequest struct {
  	Request   string // request type name
  	WantReply bool   // whether the peer asked for a reply
  	Payload   []byte // request-specific data
  }

  // Error makes ChannelRequest satisfy the error interface so that it can be
  // delivered through Read.
  func (ChannelRequest) Error() string {
  	const text = "ssh: channel request received"
  	return text
  }
  // RejectionReason is the reason code sent when refusing a channel creation
  // request. See RFC 4254, section 5.1.
  type RejectionReason uint32

  const (
  	// Prohibited is SSH_OPEN_ADMINISTRATIVELY_PROHIBITED (1).
  	Prohibited RejectionReason = iota + 1
  	// ConnectionFailed is SSH_OPEN_CONNECT_FAILED (2).
  	ConnectionFailed
  	// UnknownChannelType is SSH_OPEN_UNKNOWN_CHANNEL_TYPE (3).
  	UnknownChannelType
  	// ResourceShortage is SSH_OPEN_RESOURCE_SHORTAGE (4).
  	ResourceShortage
  )

  // String returns a human readable description of r. Codes outside the
  // known range are rendered as "unknown reason N".
  func (r RejectionReason) String() string {
  	descriptions := [...]string{
  		"administratively prohibited", // Prohibited
  		"connect failed",              // ConnectionFailed
  		"unknown channel type",        // UnknownChannelType
  		"resource shortage",           // ResourceShortage
  	}
  	if r >= Prohibited && r <= ResourceShortage {
  		return descriptions[r-Prohibited]
  	}
  	return fmt.Sprintf("unknown reason %d", int(r))
  }
  83. type channel struct {
  84. packetConn // the underlying transport
  85. localId, remoteId uint32
  86. remoteWin window
  87. maxPacket uint32
  88. isClosed uint32 // atomic bool, non zero if true
  89. }
  90. func (c *channel) sendWindowAdj(n int) error {
  91. msg := windowAdjustMsg{
  92. PeersId: c.remoteId,
  93. AdditionalBytes: uint32(n),
  94. }
  95. return c.writePacket(marshal(msgChannelWindowAdjust, msg))
  96. }
  97. // sendEOF sends EOF to the remote side. RFC 4254 Section 5.3
  98. func (c *channel) sendEOF() error {
  99. return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
  100. PeersId: c.remoteId,
  101. }))
  102. }
  103. // sendClose informs the remote side of our intent to close the channel.
  104. func (c *channel) sendClose() error {
  105. return c.packetConn.writePacket(marshal(msgChannelClose, channelCloseMsg{
  106. PeersId: c.remoteId,
  107. }))
  108. }
  109. func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string) error {
  110. reject := channelOpenFailureMsg{
  111. PeersId: c.remoteId,
  112. Reason: reason,
  113. Message: message,
  114. Language: "en",
  115. }
  116. return c.writePacket(marshal(msgChannelOpenFailure, reject))
  117. }
  118. func (c *channel) writePacket(b []byte) error {
  119. if c.closed() {
  120. return io.EOF
  121. }
  122. if uint32(len(b)) > c.maxPacket {
  123. return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
  124. }
  125. return c.packetConn.writePacket(b)
  126. }
  127. func (c *channel) closed() bool {
  128. return atomic.LoadUint32(&c.isClosed) > 0
  129. }
  130. func (c *channel) setClosed() bool {
  131. return atomic.CompareAndSwapUint32(&c.isClosed, 0, 1)
  132. }
  133. type serverChan struct {
  134. channel
  135. // immutable once created
  136. chanType string
  137. extraData []byte
  138. serverConn *ServerConn
  139. myWindow uint32
  140. theyClosed bool // indicates the close msg has been received from the remote side
  141. theySentEOF bool
  142. isDead uint32
  143. err error
  144. pendingRequests []ChannelRequest
  145. pendingData []byte
  146. head, length int
  147. // This lock is inferior to serverConn.lock
  148. cond *sync.Cond
  149. }
  150. func (c *serverChan) Accept() error {
  151. c.serverConn.lock.Lock()
  152. defer c.serverConn.lock.Unlock()
  153. if c.serverConn.err != nil {
  154. return c.serverConn.err
  155. }
  156. confirm := channelOpenConfirmMsg{
  157. PeersId: c.remoteId,
  158. MyId: c.localId,
  159. MyWindow: c.myWindow,
  160. MaxPacketSize: c.maxPacket,
  161. }
  162. return c.writePacket(marshal(msgChannelOpenConfirm, confirm))
  163. }
  164. func (c *serverChan) Reject(reason RejectionReason, message string) error {
  165. c.serverConn.lock.Lock()
  166. defer c.serverConn.lock.Unlock()
  167. if c.serverConn.err != nil {
  168. return c.serverConn.err
  169. }
  170. return c.sendChannelOpenFailure(reason, message)
  171. }
  172. func (c *serverChan) handlePacket(packet interface{}) {
  173. c.cond.L.Lock()
  174. defer c.cond.L.Unlock()
  175. switch packet := packet.(type) {
  176. case *channelRequestMsg:
  177. req := ChannelRequest{
  178. Request: packet.Request,
  179. WantReply: packet.WantReply,
  180. Payload: packet.RequestSpecificData,
  181. }
  182. c.pendingRequests = append(c.pendingRequests, req)
  183. c.cond.Signal()
  184. case *channelCloseMsg:
  185. c.theyClosed = true
  186. c.cond.Signal()
  187. case *channelEOFMsg:
  188. c.theySentEOF = true
  189. c.cond.Signal()
  190. case *windowAdjustMsg:
  191. if !c.remoteWin.add(packet.AdditionalBytes) {
  192. panic("illegal window update")
  193. }
  194. default:
  195. panic("unknown packet type")
  196. }
  197. }
  198. func (c *serverChan) handleData(data []byte) {
  199. c.cond.L.Lock()
  200. defer c.cond.L.Unlock()
  201. // The other side should never send us more than our window.
  202. if len(data)+c.length > len(c.pendingData) {
  203. // TODO(agl): we should tear down the channel with a protocol
  204. // error.
  205. return
  206. }
  207. c.myWindow -= uint32(len(data))
  208. for i := 0; i < 2; i++ {
  209. tail := c.head + c.length
  210. if tail >= len(c.pendingData) {
  211. tail -= len(c.pendingData)
  212. }
  213. n := copy(c.pendingData[tail:], data)
  214. data = data[n:]
  215. c.length += n
  216. }
  217. c.cond.Signal()
  218. }
  219. func (c *serverChan) Stderr() io.Writer {
  220. return extendedDataChannel{c: c, t: extendedDataStderr}
  221. }
  222. // extendedDataChannel is an io.Writer that writes any data to c as extended
  223. // data of the given type.
  224. type extendedDataChannel struct {
  225. t extendedDataTypeCode
  226. c *serverChan
  227. }
  228. func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
  229. const headerLength = 13 // 1 byte message type, 4 bytes remoteId, 4 bytes extended message type, 4 bytes data length
  230. c := edc.c
  231. for len(data) > 0 {
  232. space := min(c.maxPacket-headerLength, len(data))
  233. if space, err = c.getWindowSpace(space); err != nil {
  234. return 0, err
  235. }
  236. todo := data
  237. if uint32(len(todo)) > space {
  238. todo = todo[:space]
  239. }
  240. packet := make([]byte, headerLength+len(todo))
  241. packet[0] = msgChannelExtendedData
  242. marshalUint32(packet[1:], c.remoteId)
  243. marshalUint32(packet[5:], uint32(edc.t))
  244. marshalUint32(packet[9:], uint32(len(todo)))
  245. copy(packet[13:], todo)
  246. if err = c.writePacket(packet); err != nil {
  247. return
  248. }
  249. n += len(todo)
  250. data = data[len(todo):]
  251. }
  252. return
  253. }
  254. func (c *serverChan) Read(data []byte) (n int, err error) {
  255. n, err, windowAdjustment := c.read(data)
  256. if windowAdjustment > 0 {
  257. packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
  258. PeersId: c.remoteId,
  259. AdditionalBytes: windowAdjustment,
  260. })
  261. err = c.writePacket(packet)
  262. }
  263. return
  264. }
  265. func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint32) {
  266. c.cond.L.Lock()
  267. defer c.cond.L.Unlock()
  268. if c.err != nil {
  269. return 0, c.err, 0
  270. }
  271. for {
  272. if c.theySentEOF || c.theyClosed || c.dead() {
  273. return 0, io.EOF, 0
  274. }
  275. if len(c.pendingRequests) > 0 {
  276. req := c.pendingRequests[0]
  277. if len(c.pendingRequests) == 1 {
  278. c.pendingRequests = nil
  279. } else {
  280. oldPendingRequests := c.pendingRequests
  281. c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
  282. copy(c.pendingRequests, oldPendingRequests[1:])
  283. }
  284. return 0, req, 0
  285. }
  286. if c.length > 0 {
  287. tail := min(uint32(c.head+c.length), len(c.pendingData))
  288. n = copy(data, c.pendingData[c.head:tail])
  289. c.head += n
  290. c.length -= n
  291. if c.head == len(c.pendingData) {
  292. c.head = 0
  293. }
  294. windowAdjustment = uint32(len(c.pendingData)-c.length) - c.myWindow
  295. if windowAdjustment < uint32(len(c.pendingData)/2) {
  296. windowAdjustment = 0
  297. }
  298. c.myWindow += windowAdjustment
  299. return
  300. }
  301. c.cond.Wait()
  302. }
  303. panic("unreachable")
  304. }
  305. // getWindowSpace takes, at most, max bytes of space from the peer's window. It
  306. // returns the number of bytes actually reserved.
  307. func (c *serverChan) getWindowSpace(max uint32) (uint32, error) {
  308. if c.dead() || c.closed() {
  309. return 0, io.EOF
  310. }
  311. return c.remoteWin.reserve(max), nil
  312. }
  313. func (c *serverChan) dead() bool {
  314. return atomic.LoadUint32(&c.isDead) > 0
  315. }
  316. func (c *serverChan) setDead() {
  317. atomic.StoreUint32(&c.isDead, 1)
  318. }
  319. func (c *serverChan) Write(data []byte) (n int, err error) {
  320. const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
  321. for len(data) > 0 {
  322. space := min(c.maxPacket-headerLength, len(data))
  323. if space, err = c.getWindowSpace(space); err != nil {
  324. return 0, err
  325. }
  326. todo := data
  327. if uint32(len(todo)) > space {
  328. todo = todo[:space]
  329. }
  330. packet := make([]byte, headerLength+len(todo))
  331. packet[0] = msgChannelData
  332. marshalUint32(packet[1:], c.remoteId)
  333. marshalUint32(packet[5:], uint32(len(todo)))
  334. copy(packet[9:], todo)
  335. if err = c.writePacket(packet); err != nil {
  336. return
  337. }
  338. n += len(todo)
  339. data = data[len(todo):]
  340. }
  341. return
  342. }
  343. // Close signals the intent to close the channel.
  344. func (c *serverChan) Close() error {
  345. c.serverConn.lock.Lock()
  346. defer c.serverConn.lock.Unlock()
  347. if c.serverConn.err != nil {
  348. return c.serverConn.err
  349. }
  350. if !c.setClosed() {
  351. return errors.New("ssh: channel already closed")
  352. }
  353. return c.sendClose()
  354. }
  355. func (c *serverChan) AckRequest(ok bool) error {
  356. c.serverConn.lock.Lock()
  357. defer c.serverConn.lock.Unlock()
  358. if c.serverConn.err != nil {
  359. return c.serverConn.err
  360. }
  361. if !ok {
  362. ack := channelRequestFailureMsg{
  363. PeersId: c.remoteId,
  364. }
  365. return c.writePacket(marshal(msgChannelFailure, ack))
  366. }
  367. ack := channelRequestSuccessMsg{
  368. PeersId: c.remoteId,
  369. }
  370. return c.writePacket(marshal(msgChannelSuccess, ack))
  371. }
  372. func (c *serverChan) ChannelType() string {
  373. return c.chanType
  374. }
  375. func (c *serverChan) ExtraData() []byte {
  376. return c.extraData
  377. }
  378. // A clientChan represents a single RFC 4254 channel multiplexed
  379. // over a SSH connection.
  380. type clientChan struct {
  381. channel
  382. stdin *chanWriter
  383. stdout *chanReader
  384. stderr *chanReader
  385. msg chan interface{}
  386. }
  387. // newClientChan returns a partially constructed *clientChan
  388. // using the local id provided. To be usable clientChan.remoteId
  389. // needs to be assigned once known.
  390. func newClientChan(cc packetConn, id uint32) *clientChan {
  391. c := &clientChan{
  392. channel: channel{
  393. packetConn: cc,
  394. localId: id,
  395. remoteWin: window{Cond: newCond()},
  396. },
  397. msg: make(chan interface{}, 16),
  398. }
  399. c.stdin = &chanWriter{
  400. channel: &c.channel,
  401. }
  402. c.stdout = &chanReader{
  403. channel: &c.channel,
  404. buffer: newBuffer(),
  405. }
  406. c.stderr = &chanReader{
  407. channel: &c.channel,
  408. buffer: newBuffer(),
  409. }
  410. return c
  411. }
  412. // waitForChannelOpenResponse, if successful, fills out
  413. // the remoteId and records any initial window advertisement.
  414. func (c *clientChan) waitForChannelOpenResponse() error {
  415. switch msg := (<-c.msg).(type) {
  416. case *channelOpenConfirmMsg:
  417. if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
  418. return errors.New("ssh: invalid MaxPacketSize from peer")
  419. }
  420. // fixup remoteId field
  421. c.remoteId = msg.MyId
  422. c.maxPacket = msg.MaxPacketSize
  423. c.remoteWin.add(msg.MyWindow)
  424. return nil
  425. case *channelOpenFailureMsg:
  426. return errors.New(safeString(msg.Message))
  427. }
  428. return errors.New("ssh: unexpected packet")
  429. }
  430. // Close signals the intent to close the channel.
  431. func (c *clientChan) Close() error {
  432. if !c.setClosed() {
  433. return errors.New("ssh: channel already closed")
  434. }
  435. c.stdout.eof()
  436. c.stderr.eof()
  437. return c.sendClose()
  438. }
  439. // A chanWriter represents the stdin of a remote process.
  440. type chanWriter struct {
  441. *channel
  442. // indicates the writer has been closed. eof is owned by the
  443. // caller of Write/Close.
  444. eof bool
  445. }
  446. // Write writes data to the remote process's standard input.
  447. func (w *chanWriter) Write(data []byte) (written int, err error) {
  448. const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
  449. for len(data) > 0 {
  450. if w.eof || w.closed() {
  451. err = io.EOF
  452. return
  453. }
  454. // never send more data than maxPacket even if
  455. // there is sufficient window.
  456. n := min(w.maxPacket-headerLength, len(data))
  457. r := w.remoteWin.reserve(n)
  458. n = r
  459. remoteId := w.remoteId
  460. packet := []byte{
  461. msgChannelData,
  462. byte(remoteId >> 24), byte(remoteId >> 16), byte(remoteId >> 8), byte(remoteId),
  463. byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
  464. }
  465. if err = w.writePacket(append(packet, data[:n]...)); err != nil {
  466. break
  467. }
  468. data = data[n:]
  469. written += int(n)
  470. }
  471. return
  472. }
  // min returns the smaller of a and b. b is converted to uint32 before the
  // comparison, so a negative b wraps around as any Go conversion would.
  func min(a uint32, b int) uint32 {
  	if ub := uint32(b); ub <= a {
  		return ub
  	}
  	return a
  }
  479. func (w *chanWriter) Close() error {
  480. w.eof = true
  481. return w.sendEOF()
  482. }
  483. // A chanReader represents stdout or stderr of a remote process.
  484. type chanReader struct {
  485. *channel // the channel backing this reader
  486. *buffer
  487. }
  488. // Read reads data from the remote process's stdout or stderr.
  489. func (r *chanReader) Read(buf []byte) (int, error) {
  490. n, err := r.buffer.Read(buf)
  491. if err != nil {
  492. if err == io.EOF {
  493. return n, err
  494. }
  495. return 0, err
  496. }
  497. err = r.sendWindowAdj(n)
  498. if err == io.EOF && n > 0 {
  499. // sendWindowAdjust can return io.EOF if the remote peer has
  500. // closed the connection, however we want to defer forwarding io.EOF to the
  501. // caller of Read until the buffer has been drained.
  502. err = nil
  503. }
  504. return n, err
  505. }