| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396 |
- // Copyright 2011 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package ssh
- import (
- "errors"
- "io"
- "sync"
- )
// extendedDataTypeCode identifies an SSH extended data type (the data_type_code
// field of SSH_MSG_CHANNEL_EXTENDED_DATA). See RFC 4254, section 5.2.
type extendedDataTypeCode uint32

// extendedDataStderr is the extended data type that is used for stderr
// (SSH_EXTENDED_DATA_STDERR).
const extendedDataStderr extendedDataTypeCode = 1
- // A Channel is an ordered, reliable, duplex stream that is multiplexed over an
- // SSH connection. Channel.Read can return a ChannelRequest as an error.
- type Channel interface {
- // Accept accepts the channel creation request.
- Accept() error
- // Reject rejects the channel creation request. After calling this, no
- // other methods on the Channel may be called. If they are then the
- // peer is likely to signal a protocol error and drop the connection.
- Reject(reason RejectionReason, message string) error
- // Read may return a ChannelRequest as an error.
- Read(data []byte) (int, error)
- Write(data []byte) (int, error)
- Close() error
- // Stderr returns an io.Writer that writes to this channel with the
- // extended data type set to stderr.
- Stderr() io.Writer
- // AckRequest either sends an ack or nack to the channel request.
- AckRequest(ok bool) error
- // ChannelType returns the type of the channel, as supplied by the
- // client.
- ChannelType() string
- // ExtraData returns the arbitary payload for this channel, as supplied
- // by the client. This data is specific to the channel type.
- ExtraData() []byte
- }
// ChannelRequest is an out-of-band request the peer sent on a channel, as
// opposed to ordinary stream data. Channel.Read hands one back as its error
// value when a request is pending.
type ChannelRequest struct {
	Request   string // request type name
	WantReply bool   // whether the peer expects a success/failure reply
	Payload   []byte // request-specific data
}

// Error implements the error interface so that a ChannelRequest can be
// returned through Read's error result.
func (ChannelRequest) Error() string {
	const msg = "ssh: channel request received"
	return msg
}
// RejectionReason is the reason code sent to the peer when a channel creation
// request is refused. The values are the reason codes of RFC 4254,
// section 5.1.
type RejectionReason int

const (
	// Prohibited is SSH_OPEN_ADMINISTRATIVELY_PROHIBITED.
	Prohibited RejectionReason = 1
	// ConnectionFailed is SSH_OPEN_CONNECT_FAILED.
	ConnectionFailed RejectionReason = 2
	// UnknownChannelType is SSH_OPEN_UNKNOWN_CHANNEL_TYPE.
	UnknownChannelType RejectionReason = 3
	// ResourceShortage is SSH_OPEN_RESOURCE_SHORTAGE.
	ResourceShortage RejectionReason = 4
)
// channel implements Channel. Outgoing packets in this file are written
// with serverConn.lock held; the mutable per-channel state is read and
// written under lock.
type channel struct {
	// immutable once created
	// NOTE(review): chanType, extraData, serverConn, myId, theirId and
	// maxPacketSize are never modified in this file; the other fields are
	// mutable state despite appearing under this heading.

	// chanType is the channel type reported by ChannelType.
	chanType string
	// extraData is the type-specific payload reported by ExtraData.
	extraData []byte

	// theyClosed is set by handlePacket when the peer sends a channel close.
	theyClosed bool
	// theySentEOF is set by handlePacket when the peer sends EOF.
	theySentEOF bool
	// weClosed is set by Close; once set, getWindowSpace fails with io.EOF.
	weClosed bool
	// dead makes Read return io.EOF and getWindowSpace fail. It is not
	// assigned in this file (presumably set by the connection on shutdown —
	// TODO confirm).
	dead bool

	// serverConn carries this channel; every writePacket call in this file
	// is made with serverConn.lock held.
	serverConn *ServerConn

	// myId and theirId are the local and remote channel numbers (sent as
	// MyId and PeersId in the open confirmation).
	myId, theirId uint32

	// myWindow is the receive window advertised to the peer: reduced by
	// handleData and replenished by Read via window-adjust packets.
	// theirWindow is the peer's receive window: consumed by getWindowSpace
	// and grown when the peer sends a window adjust.
	myWindow, theirWindow uint32

	// maxPacketSize is sent to the peer in Accept's open confirmation.
	// NOTE(review): Write and Stderr writes do not cap packets by it.
	maxPacketSize uint32

	// err, if non-nil, is returned by Read. It is not assigned in this file.
	err error

	// pendingRequests holds requests queued by handlePacket until Read
	// returns them, in arrival order.
	pendingRequests []ChannelRequest

	// pendingData is a ring buffer of received, unread data: length bytes
	// are in use, starting at index head and wrapping around the end.
	pendingData  []byte
	head, length int

	// This lock is inferior to serverConn.lock
	// NOTE(review): presumably serverConn.lock must be acquired first when
	// both are held.
	lock sync.Mutex
	// cond is waited on with lock held (Read, getWindowSpace), so it is
	// presumably created with lock as its Locker — TODO confirm.
	cond *sync.Cond
}
- func (c *channel) Accept() error {
- c.serverConn.lock.Lock()
- defer c.serverConn.lock.Unlock()
- if c.serverConn.err != nil {
- return c.serverConn.err
- }
- confirm := channelOpenConfirmMsg{
- PeersId: c.theirId,
- MyId: c.myId,
- MyWindow: c.myWindow,
- MaxPacketSize: c.maxPacketSize,
- }
- return c.serverConn.writePacket(marshal(msgChannelOpenConfirm, confirm))
- }
- func (c *channel) Reject(reason RejectionReason, message string) error {
- c.serverConn.lock.Lock()
- defer c.serverConn.lock.Unlock()
- if c.serverConn.err != nil {
- return c.serverConn.err
- }
- reject := channelOpenFailureMsg{
- PeersId: c.theirId,
- Reason: reason,
- Message: message,
- Language: "en",
- }
- return c.serverConn.writePacket(marshal(msgChannelOpenFailure, reject))
- }
- func (c *channel) handlePacket(packet interface{}) {
- c.lock.Lock()
- defer c.lock.Unlock()
- switch packet := packet.(type) {
- case *channelRequestMsg:
- req := ChannelRequest{
- Request: packet.Request,
- WantReply: packet.WantReply,
- Payload: packet.RequestSpecificData,
- }
- c.pendingRequests = append(c.pendingRequests, req)
- c.cond.Signal()
- case *channelCloseMsg:
- c.theyClosed = true
- c.cond.Signal()
- case *channelEOFMsg:
- c.theySentEOF = true
- c.cond.Signal()
- case *windowAdjustMsg:
- c.theirWindow += packet.AdditionalBytes
- c.cond.Signal()
- default:
- panic("unknown packet type")
- }
- }
- func (c *channel) handleData(data []byte) {
- c.lock.Lock()
- defer c.lock.Unlock()
- // The other side should never send us more than our window.
- if len(data)+c.length > len(c.pendingData) {
- // TODO(agl): we should tear down the channel with a protocol
- // error.
- return
- }
- c.myWindow -= uint32(len(data))
- for i := 0; i < 2; i++ {
- tail := c.head + c.length
- if tail >= len(c.pendingData) {
- tail -= len(c.pendingData)
- }
- n := copy(c.pendingData[tail:], data)
- data = data[n:]
- c.length += n
- }
- c.cond.Signal()
- }
- func (c *channel) Stderr() io.Writer {
- return extendedDataChannel{c: c, t: extendedDataStderr}
- }
- // extendedDataChannel is an io.Writer that writes any data to c as extended
- // data of the given type.
- type extendedDataChannel struct {
- t extendedDataTypeCode
- c *channel
- }
- func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
- c := edc.c
- for len(data) > 0 {
- var space uint32
- if space, err = c.getWindowSpace(uint32(len(data))); err != nil {
- return 0, err
- }
- todo := data
- if uint32(len(todo)) > space {
- todo = todo[:space]
- }
- packet := make([]byte, 1+4+4+4+len(todo))
- packet[0] = msgChannelExtendedData
- marshalUint32(packet[1:], c.theirId)
- marshalUint32(packet[5:], uint32(edc.t))
- marshalUint32(packet[9:], uint32(len(todo)))
- copy(packet[13:], todo)
- c.serverConn.lock.Lock()
- err = c.serverConn.writePacket(packet)
- c.serverConn.lock.Unlock()
- if err != nil {
- return
- }
- n += len(todo)
- data = data[len(todo):]
- }
- return
- }
- func (c *channel) Read(data []byte) (n int, err error) {
- c.lock.Lock()
- defer c.lock.Unlock()
- if c.err != nil {
- return 0, c.err
- }
- for {
- if c.theySentEOF || c.theyClosed || c.dead {
- return 0, io.EOF
- }
- if len(c.pendingRequests) > 0 {
- req := c.pendingRequests[0]
- if len(c.pendingRequests) == 1 {
- c.pendingRequests = nil
- } else {
- oldPendingRequests := c.pendingRequests
- c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
- copy(c.pendingRequests, oldPendingRequests[1:])
- }
- return 0, req
- }
- if c.length > 0 {
- tail := min(c.head+c.length, len(c.pendingData))
- n = copy(data, c.pendingData[c.head:tail])
- c.head += n
- c.length -= n
- if c.head == len(c.pendingData) {
- c.head = 0
- }
- windowAdjustment := uint32(len(c.pendingData)-c.length) - c.myWindow
- if windowAdjustment >= uint32(len(c.pendingData)/2) {
- packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
- PeersId: c.theirId,
- AdditionalBytes: windowAdjustment,
- })
- c.serverConn.lock.Lock()
- err = c.serverConn.writePacket(packet)
- c.serverConn.lock.Unlock()
- if err != nil {
- return
- }
- c.myWindow += windowAdjustment
- }
- return
- }
- c.cond.Wait()
- }
- panic("unreachable")
- }
- // getWindowSpace takes, at most, max bytes of space from the peer's window. It
- // returns the number of bytes actually reserved.
- func (c *channel) getWindowSpace(max uint32) (uint32, error) {
- c.lock.Lock()
- defer c.lock.Unlock()
- for {
- if c.dead || c.weClosed {
- return 0, io.EOF
- }
- if c.theirWindow > 0 {
- break
- }
- c.cond.Wait()
- }
- taken := c.theirWindow
- if taken > max {
- taken = max
- }
- c.theirWindow -= taken
- return taken, nil
- }
- func (c *channel) Write(data []byte) (n int, err error) {
- for len(data) > 0 {
- var space uint32
- if space, err = c.getWindowSpace(uint32(len(data))); err != nil {
- return 0, err
- }
- todo := data
- if uint32(len(todo)) > space {
- todo = todo[:space]
- }
- packet := make([]byte, 1+4+4+len(todo))
- packet[0] = msgChannelData
- marshalUint32(packet[1:], c.theirId)
- marshalUint32(packet[5:], uint32(len(todo)))
- copy(packet[9:], todo)
- c.serverConn.lock.Lock()
- err = c.serverConn.writePacket(packet)
- c.serverConn.lock.Unlock()
- if err != nil {
- return
- }
- n += len(todo)
- data = data[len(todo):]
- }
- return
- }
- func (c *channel) Close() error {
- c.serverConn.lock.Lock()
- defer c.serverConn.lock.Unlock()
- if c.serverConn.err != nil {
- return c.serverConn.err
- }
- if c.weClosed {
- return errors.New("ssh: channel already closed")
- }
- c.weClosed = true
- closeMsg := channelCloseMsg{
- PeersId: c.theirId,
- }
- return c.serverConn.writePacket(marshal(msgChannelClose, closeMsg))
- }
- func (c *channel) AckRequest(ok bool) error {
- c.serverConn.lock.Lock()
- defer c.serverConn.lock.Unlock()
- if c.serverConn.err != nil {
- return c.serverConn.err
- }
- if !ok {
- ack := channelRequestFailureMsg{
- PeersId: c.theirId,
- }
- return c.serverConn.writePacket(marshal(msgChannelFailure, ack))
- }
- ack := channelRequestSuccessMsg{
- PeersId: c.theirId,
- }
- return c.serverConn.writePacket(marshal(msgChannelSuccess, ack))
- }
- func (c *channel) ChannelType() string {
- return c.chanType
- }
- func (c *channel) ExtraData() []byte {
- return c.extraData
- }
|