channel.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "errors"
  7. "fmt"
  8. "io"
  9. "sync"
  10. "sync/atomic"
  11. )
  12. // extendedDataTypeCode identifies an OpenSSL extended data type. See RFC 4254,
  13. // section 5.2.
  14. type extendedDataTypeCode uint32
  15. const (
  16. // extendedDataStderr is the extended data type that is used for stderr.
  17. extendedDataStderr extendedDataTypeCode = 1
  18. // minPacketLength defines the smallest valid packet
  19. minPacketLength = 9
  20. )
  21. // A Channel is an ordered, reliable, duplex stream that is multiplexed over an
  22. // SSH connection. Channel.Read can return a ChannelRequest as an error.
  23. type Channel interface {
  24. // Accept accepts the channel creation request.
  25. Accept() error
  26. // Reject rejects the channel creation request. After calling this, no
  27. // other methods on the Channel may be called. If they are then the
  28. // peer is likely to signal a protocol error and drop the connection.
  29. Reject(reason RejectionReason, message string) error
  30. // Read may return a ChannelRequest as an error.
  31. Read(data []byte) (int, error)
  32. Write(data []byte) (int, error)
  33. Close() error
  34. // Stderr returns an io.Writer that writes to this channel with the
  35. // extended data type set to stderr.
  36. Stderr() io.Writer
  37. // AckRequest either sends an ack or nack to the channel request.
  38. AckRequest(ok bool) error
  39. // ChannelType returns the type of the channel, as supplied by the
  40. // client.
  41. ChannelType() string
  42. // ExtraData returns the arbitrary payload for this channel, as supplied
  43. // by the client. This data is specific to the channel type.
  44. ExtraData() []byte
  45. }
  46. // ChannelRequest represents a request sent on a channel, outside of the normal
  47. // stream of bytes. It may result from calling Read on a Channel.
  48. type ChannelRequest struct {
  49. Request string
  50. WantReply bool
  51. Payload []byte
  52. }
  53. func (c ChannelRequest) Error() string {
  54. return "ssh: channel request received"
  55. }
  56. // RejectionReason is an enumeration used when rejecting channel creation
  57. // requests. See RFC 4254, section 5.1.
  58. type RejectionReason uint32
  59. const (
  60. Prohibited RejectionReason = iota + 1
  61. ConnectionFailed
  62. UnknownChannelType
  63. ResourceShortage
  64. )
  65. // String converts the rejection reason to human readable form.
  66. func (r RejectionReason) String() string {
  67. switch r {
  68. case Prohibited:
  69. return "administratively prohibited"
  70. case ConnectionFailed:
  71. return "connect failed"
  72. case UnknownChannelType:
  73. return "unknown channel type"
  74. case ResourceShortage:
  75. return "resource shortage"
  76. }
  77. return fmt.Sprintf("unknown reason %d", int(r))
  78. }
  79. type channel struct {
  80. packetConn // the underlying transport
  81. localId, remoteId uint32
  82. remoteWin window
  83. maxPacket uint32
  84. isClosed uint32 // atomic bool, non zero if true
  85. }
  86. func (c *channel) sendWindowAdj(n int) error {
  87. msg := windowAdjustMsg{
  88. PeersId: c.remoteId,
  89. AdditionalBytes: uint32(n),
  90. }
  91. return c.writePacket(marshal(msgChannelWindowAdjust, msg))
  92. }
  93. // sendEOF sends EOF to the remote side. RFC 4254 Section 5.3
  94. func (c *channel) sendEOF() error {
  95. return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
  96. PeersId: c.remoteId,
  97. }))
  98. }
  99. // sendClose informs the remote side of our intent to close the channel.
  100. func (c *channel) sendClose() error {
  101. return c.packetConn.writePacket(marshal(msgChannelClose, channelCloseMsg{
  102. PeersId: c.remoteId,
  103. }))
  104. }
  105. func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string) error {
  106. reject := channelOpenFailureMsg{
  107. PeersId: c.remoteId,
  108. Reason: reason,
  109. Message: message,
  110. Language: "en",
  111. }
  112. return c.writePacket(marshal(msgChannelOpenFailure, reject))
  113. }
  114. func (c *channel) writePacket(b []byte) error {
  115. if c.closed() {
  116. return io.EOF
  117. }
  118. if uint32(len(b)) > c.maxPacket {
  119. return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
  120. }
  121. return c.packetConn.writePacket(b)
  122. }
  123. func (c *channel) closed() bool {
  124. return atomic.LoadUint32(&c.isClosed) > 0
  125. }
  126. func (c *channel) setClosed() bool {
  127. return atomic.CompareAndSwapUint32(&c.isClosed, 0, 1)
  128. }
  129. type serverChan struct {
  130. channel
  131. // immutable once created
  132. chanType string
  133. extraData []byte
  134. serverConn *ServerConn
  135. myWindow uint32
  136. theyClosed bool // indicates the close msg has been received from the remote side
  137. theySentEOF bool
  138. isDead uint32
  139. err error
  140. pendingRequests []ChannelRequest
  141. pendingData []byte
  142. head, length int
  143. // This lock is inferior to serverConn.lock
  144. cond *sync.Cond
  145. }
  146. func (c *serverChan) Accept() error {
  147. c.serverConn.lock.Lock()
  148. defer c.serverConn.lock.Unlock()
  149. if c.serverConn.err != nil {
  150. return c.serverConn.err
  151. }
  152. confirm := channelOpenConfirmMsg{
  153. PeersId: c.remoteId,
  154. MyId: c.localId,
  155. MyWindow: c.myWindow,
  156. MaxPacketSize: c.maxPacket,
  157. }
  158. return c.writePacket(marshal(msgChannelOpenConfirm, confirm))
  159. }
  160. func (c *serverChan) Reject(reason RejectionReason, message string) error {
  161. c.serverConn.lock.Lock()
  162. defer c.serverConn.lock.Unlock()
  163. if c.serverConn.err != nil {
  164. return c.serverConn.err
  165. }
  166. return c.sendChannelOpenFailure(reason, message)
  167. }
  168. func (c *serverChan) handlePacket(packet interface{}) {
  169. c.cond.L.Lock()
  170. defer c.cond.L.Unlock()
  171. switch packet := packet.(type) {
  172. case *channelRequestMsg:
  173. req := ChannelRequest{
  174. Request: packet.Request,
  175. WantReply: packet.WantReply,
  176. Payload: packet.RequestSpecificData,
  177. }
  178. c.pendingRequests = append(c.pendingRequests, req)
  179. c.cond.Signal()
  180. case *channelCloseMsg:
  181. c.theyClosed = true
  182. c.cond.Signal()
  183. case *channelEOFMsg:
  184. c.theySentEOF = true
  185. c.cond.Signal()
  186. case *windowAdjustMsg:
  187. if !c.remoteWin.add(packet.AdditionalBytes) {
  188. panic("illegal window update")
  189. }
  190. default:
  191. panic("unknown packet type")
  192. }
  193. }
  194. func (c *serverChan) handleData(data []byte) {
  195. c.cond.L.Lock()
  196. defer c.cond.L.Unlock()
  197. // The other side should never send us more than our window.
  198. if len(data)+c.length > len(c.pendingData) {
  199. // TODO(agl): we should tear down the channel with a protocol
  200. // error.
  201. return
  202. }
  203. c.myWindow -= uint32(len(data))
  204. for i := 0; i < 2; i++ {
  205. tail := c.head + c.length
  206. if tail >= len(c.pendingData) {
  207. tail -= len(c.pendingData)
  208. }
  209. n := copy(c.pendingData[tail:], data)
  210. data = data[n:]
  211. c.length += n
  212. }
  213. c.cond.Signal()
  214. }
  215. func (c *serverChan) Stderr() io.Writer {
  216. return extendedDataChannel{c: c, t: extendedDataStderr}
  217. }
  218. // extendedDataChannel is an io.Writer that writes any data to c as extended
  219. // data of the given type.
  220. type extendedDataChannel struct {
  221. t extendedDataTypeCode
  222. c *serverChan
  223. }
  224. func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
  225. const headerLength = 13 // 1 byte message type, 4 bytes remoteId, 4 bytes extended message type, 4 bytes data length
  226. c := edc.c
  227. for len(data) > 0 {
  228. space := min(c.maxPacket-headerLength, len(data))
  229. if space, err = c.getWindowSpace(space); err != nil {
  230. return 0, err
  231. }
  232. todo := data
  233. if uint32(len(todo)) > space {
  234. todo = todo[:space]
  235. }
  236. packet := make([]byte, headerLength+len(todo))
  237. packet[0] = msgChannelExtendedData
  238. marshalUint32(packet[1:], c.remoteId)
  239. marshalUint32(packet[5:], uint32(edc.t))
  240. marshalUint32(packet[9:], uint32(len(todo)))
  241. copy(packet[13:], todo)
  242. if err = c.writePacket(packet); err != nil {
  243. return
  244. }
  245. n += len(todo)
  246. data = data[len(todo):]
  247. }
  248. return
  249. }
  250. func (c *serverChan) Read(data []byte) (n int, err error) {
  251. n, err, windowAdjustment := c.read(data)
  252. if windowAdjustment > 0 {
  253. packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
  254. PeersId: c.remoteId,
  255. AdditionalBytes: windowAdjustment,
  256. })
  257. err = c.writePacket(packet)
  258. }
  259. return
  260. }
  261. func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint32) {
  262. c.cond.L.Lock()
  263. defer c.cond.L.Unlock()
  264. if c.err != nil {
  265. return 0, c.err, 0
  266. }
  267. for {
  268. if c.theySentEOF || c.theyClosed || c.dead() {
  269. return 0, io.EOF, 0
  270. }
  271. if len(c.pendingRequests) > 0 {
  272. req := c.pendingRequests[0]
  273. if len(c.pendingRequests) == 1 {
  274. c.pendingRequests = nil
  275. } else {
  276. oldPendingRequests := c.pendingRequests
  277. c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
  278. copy(c.pendingRequests, oldPendingRequests[1:])
  279. }
  280. return 0, req, 0
  281. }
  282. if c.length > 0 {
  283. tail := min(uint32(c.head+c.length), len(c.pendingData))
  284. n = copy(data, c.pendingData[c.head:tail])
  285. c.head += n
  286. c.length -= n
  287. if c.head == len(c.pendingData) {
  288. c.head = 0
  289. }
  290. windowAdjustment = uint32(len(c.pendingData)-c.length) - c.myWindow
  291. if windowAdjustment < uint32(len(c.pendingData)/2) {
  292. windowAdjustment = 0
  293. }
  294. c.myWindow += windowAdjustment
  295. return
  296. }
  297. c.cond.Wait()
  298. }
  299. panic("unreachable")
  300. }
  301. // getWindowSpace takes, at most, max bytes of space from the peer's window. It
  302. // returns the number of bytes actually reserved.
  303. func (c *serverChan) getWindowSpace(max uint32) (uint32, error) {
  304. if c.dead() || c.closed() {
  305. return 0, io.EOF
  306. }
  307. return c.remoteWin.reserve(max), nil
  308. }
  309. func (c *serverChan) dead() bool {
  310. return atomic.LoadUint32(&c.isDead) > 0
  311. }
  312. func (c *serverChan) setDead() {
  313. atomic.StoreUint32(&c.isDead, 1)
  314. }
  315. func (c *serverChan) Write(data []byte) (n int, err error) {
  316. const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
  317. for len(data) > 0 {
  318. space := min(c.maxPacket-headerLength, len(data))
  319. if space, err = c.getWindowSpace(space); err != nil {
  320. return 0, err
  321. }
  322. todo := data
  323. if uint32(len(todo)) > space {
  324. todo = todo[:space]
  325. }
  326. packet := make([]byte, headerLength+len(todo))
  327. packet[0] = msgChannelData
  328. marshalUint32(packet[1:], c.remoteId)
  329. marshalUint32(packet[5:], uint32(len(todo)))
  330. copy(packet[9:], todo)
  331. if err = c.writePacket(packet); err != nil {
  332. return
  333. }
  334. n += len(todo)
  335. data = data[len(todo):]
  336. }
  337. return
  338. }
  339. // Close signals the intent to close the channel.
  340. func (c *serverChan) Close() error {
  341. c.serverConn.lock.Lock()
  342. defer c.serverConn.lock.Unlock()
  343. if c.serverConn.err != nil {
  344. return c.serverConn.err
  345. }
  346. if !c.setClosed() {
  347. return errors.New("ssh: channel already closed")
  348. }
  349. return c.sendClose()
  350. }
  351. func (c *serverChan) AckRequest(ok bool) error {
  352. c.serverConn.lock.Lock()
  353. defer c.serverConn.lock.Unlock()
  354. if c.serverConn.err != nil {
  355. return c.serverConn.err
  356. }
  357. if !ok {
  358. ack := channelRequestFailureMsg{
  359. PeersId: c.remoteId,
  360. }
  361. return c.writePacket(marshal(msgChannelFailure, ack))
  362. }
  363. ack := channelRequestSuccessMsg{
  364. PeersId: c.remoteId,
  365. }
  366. return c.writePacket(marshal(msgChannelSuccess, ack))
  367. }
  368. func (c *serverChan) ChannelType() string {
  369. return c.chanType
  370. }
  371. func (c *serverChan) ExtraData() []byte {
  372. return c.extraData
  373. }
  374. // A clientChan represents a single RFC 4254 channel multiplexed
  375. // over a SSH connection.
  376. type clientChan struct {
  377. channel
  378. stdin *chanWriter
  379. stdout *chanReader
  380. stderr *chanReader
  381. msg chan interface{}
  382. }
  383. // newClientChan returns a partially constructed *clientChan
  384. // using the local id provided. To be usable clientChan.remoteId
  385. // needs to be assigned once known.
  386. func newClientChan(cc packetConn, id uint32) *clientChan {
  387. c := &clientChan{
  388. channel: channel{
  389. packetConn: cc,
  390. localId: id,
  391. remoteWin: window{Cond: newCond()},
  392. },
  393. msg: make(chan interface{}, 16),
  394. }
  395. c.stdin = &chanWriter{
  396. channel: &c.channel,
  397. }
  398. c.stdout = &chanReader{
  399. channel: &c.channel,
  400. buffer: newBuffer(),
  401. }
  402. c.stderr = &chanReader{
  403. channel: &c.channel,
  404. buffer: newBuffer(),
  405. }
  406. return c
  407. }
  408. // waitForChannelOpenResponse, if successful, fills out
  409. // the remoteId and records any initial window advertisement.
  410. func (c *clientChan) waitForChannelOpenResponse() error {
  411. switch msg := (<-c.msg).(type) {
  412. case *channelOpenConfirmMsg:
  413. if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
  414. return errors.New("ssh: invalid MaxPacketSize from peer")
  415. }
  416. // fixup remoteId field
  417. c.remoteId = msg.MyId
  418. c.maxPacket = msg.MaxPacketSize
  419. c.remoteWin.add(msg.MyWindow)
  420. return nil
  421. case *channelOpenFailureMsg:
  422. return errors.New(safeString(msg.Message))
  423. }
  424. return errors.New("ssh: unexpected packet")
  425. }
  426. // Close signals the intent to close the channel.
  427. func (c *clientChan) Close() error {
  428. if !c.setClosed() {
  429. return errors.New("ssh: channel already closed")
  430. }
  431. c.stdout.eof()
  432. c.stderr.eof()
  433. return c.sendClose()
  434. }
  435. // A chanWriter represents the stdin of a remote process.
  436. type chanWriter struct {
  437. *channel
  438. // indicates the writer has been closed. eof is owned by the
  439. // caller of Write/Close.
  440. eof bool
  441. }
  442. // Write writes data to the remote process's standard input.
  443. func (w *chanWriter) Write(data []byte) (written int, err error) {
  444. const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
  445. for len(data) > 0 {
  446. if w.eof || w.closed() {
  447. err = io.EOF
  448. return
  449. }
  450. // never send more data than maxPacket even if
  451. // there is sufficient window.
  452. n := min(w.maxPacket-headerLength, len(data))
  453. r := w.remoteWin.reserve(n)
  454. n = r
  455. remoteId := w.remoteId
  456. packet := []byte{
  457. msgChannelData,
  458. byte(remoteId >> 24), byte(remoteId >> 16), byte(remoteId >> 8), byte(remoteId),
  459. byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
  460. }
  461. if err = w.writePacket(append(packet, data[:n]...)); err != nil {
  462. break
  463. }
  464. data = data[n:]
  465. written += int(n)
  466. }
  467. return
  468. }
  469. func min(a uint32, b int) uint32 {
  470. if a < uint32(b) {
  471. return a
  472. }
  473. return uint32(b)
  474. }
  475. func (w *chanWriter) Close() error {
  476. w.eof = true
  477. return w.sendEOF()
  478. }
  479. // A chanReader represents stdout or stderr of a remote process.
  480. type chanReader struct {
  481. *channel // the channel backing this reader
  482. *buffer
  483. }
  484. // Read reads data from the remote process's stdout or stderr.
  485. func (r *chanReader) Read(buf []byte) (int, error) {
  486. n, err := r.buffer.Read(buf)
  487. if err != nil {
  488. if err == io.EOF {
  489. return n, err
  490. }
  491. return 0, err
  492. }
  493. err = r.sendWindowAdj(n)
  494. if err == io.EOF && n > 0 {
  495. // sendWindowAdjust can return io.EOF if the remote peer has
  496. // closed the connection, however we want to defer forwarding io.EOF to the
  497. // caller of Read until the buffer has been drained.
  498. err = nil
  499. }
  500. return n, err
  501. }