transport.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "bufio"
  7. "crypto"
  8. "crypto/cipher"
  9. "crypto/hmac"
  10. "crypto/sha1"
  11. "crypto/subtle"
  12. "errors"
  13. "hash"
  14. "io"
  15. "net"
  16. "sync"
  17. )
  // Framing limits for SSH binary packets (RFC 4253, section 6).
const (
	// packetSizeMultiple is the size that outgoing packets are padded up to
	// a multiple of in writePacket.
	// TODO(huin) this should be determined by the cipher.
	packetSizeMultiple = 16

	// minPacketSize is the minimum packet size.
	// NOTE(review): not referenced in this part of the file — confirm where
	// it is enforced.
	minPacketSize = 16

	// maxPacketSize is the largest packet_length readOnePacket accepts.
	maxPacketSize = 36000

	// minPaddingSize is the minimum number of random padding bytes; RFC 4253
	// section 6 requires at least four.
	// TODO(huin) should this be configurable?
	minPaddingSize = 4
)
  // filteredConn reduces the set of methods exposed when embedding a net.Conn
// inside ssh.transport. Only Close, LocalAddr and RemoteAddr are promoted, so
// the connection's own Read and Write are not reachable through transport.
// Traffic has to pass through the transport's buffered reader and writer.
// TODO(dfc) suggestions for a better name will be warmly received.
type filteredConn interface {
	Close() error         // Close closes the connection.
	LocalAddr() net.Addr  // LocalAddr returns the local network address.
	RemoteAddr() net.Addr // RemoteAddr returns the remote network address.
}

// packetWriter is implemented by types that can send packets to an SSH peer.
type packetWriter interface {
	// writePacket encrypts packet and sends it to the remote peer.
	writePacket(packet []byte) error
}
  41. // transport represents the SSH connection to the remote peer.
  42. type transport struct {
  43. reader
  44. writer
  45. filteredConn
  46. }
  // Per-direction connection state embedded by transport.
type (
	// reader represents the incoming connection state: where raw bytes come
	// from, and the cipher state used to decrypt and authenticate them.
	reader struct {
		io.Reader
		common
	}

	// writer represents the outgoing connection state.
	writer struct {
		*sync.Mutex // protects the embedded *bufio.Writer from concurrent writes
		*bufio.Writer
		rand io.Reader // source of the random padding bytes in writePacket
		common
	}

	// common represents the cipher state needed to process messages in a
	// single direction.
	common struct {
		seqNum          uint32    // packet sequence number; fed into the MAC
		mac             hash.Hash // nil means no MAC is computed or checked
		cipher          cipher.Stream
		cipherAlgo      string
		macAlgo         string
		compressionAlgo string
	}
)
  69. // Read and decrypt a single packet from the remote peer.
  70. func (r *reader) readOnePacket() ([]byte, error) {
  71. var lengthBytes = make([]byte, 5)
  72. var macSize uint32
  73. if _, err := io.ReadFull(r, lengthBytes); err != nil {
  74. return nil, err
  75. }
  76. r.cipher.XORKeyStream(lengthBytes, lengthBytes)
  77. if r.mac != nil {
  78. r.mac.Reset()
  79. seqNumBytes := []byte{
  80. byte(r.seqNum >> 24),
  81. byte(r.seqNum >> 16),
  82. byte(r.seqNum >> 8),
  83. byte(r.seqNum),
  84. }
  85. r.mac.Write(seqNumBytes)
  86. r.mac.Write(lengthBytes)
  87. macSize = uint32(r.mac.Size())
  88. }
  89. length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 | uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
  90. paddingLength := uint32(lengthBytes[4])
  91. if length <= paddingLength+1 {
  92. return nil, errors.New("invalid packet length")
  93. }
  94. if length > maxPacketSize {
  95. return nil, errors.New("packet too large")
  96. }
  97. packet := make([]byte, length-1+macSize)
  98. if _, err := io.ReadFull(r, packet); err != nil {
  99. return nil, err
  100. }
  101. mac := packet[length-1:]
  102. r.cipher.XORKeyStream(packet, packet[:length-1])
  103. if r.mac != nil {
  104. r.mac.Write(packet[:length-1])
  105. if subtle.ConstantTimeCompare(r.mac.Sum(nil), mac) != 1 {
  106. return nil, errors.New("ssh: MAC failure")
  107. }
  108. }
  109. r.seqNum++
  110. return packet[:length-paddingLength-1], nil
  111. }
  112. // Read and decrypt next packet discarding debug and noop messages.
  113. func (t *transport) readPacket() ([]byte, error) {
  114. for {
  115. packet, err := t.readOnePacket()
  116. if err != nil {
  117. return nil, err
  118. }
  119. if packet[0] != msgIgnore && packet[0] != msgDebug {
  120. return packet, nil
  121. }
  122. }
  123. panic("unreachable")
  124. }
  125. // Encrypt and send a packet of data to the remote peer.
  126. func (w *writer) writePacket(packet []byte) error {
  127. w.Mutex.Lock()
  128. defer w.Mutex.Unlock()
  129. paddingLength := packetSizeMultiple - (5+len(packet))%packetSizeMultiple
  130. if paddingLength < 4 {
  131. paddingLength += packetSizeMultiple
  132. }
  133. length := len(packet) + 1 + paddingLength
  134. lengthBytes := []byte{
  135. byte(length >> 24),
  136. byte(length >> 16),
  137. byte(length >> 8),
  138. byte(length),
  139. byte(paddingLength),
  140. }
  141. padding := make([]byte, paddingLength)
  142. _, err := io.ReadFull(w.rand, padding)
  143. if err != nil {
  144. return err
  145. }
  146. if w.mac != nil {
  147. w.mac.Reset()
  148. seqNumBytes := []byte{
  149. byte(w.seqNum >> 24),
  150. byte(w.seqNum >> 16),
  151. byte(w.seqNum >> 8),
  152. byte(w.seqNum),
  153. }
  154. w.mac.Write(seqNumBytes)
  155. w.mac.Write(lengthBytes)
  156. w.mac.Write(packet)
  157. w.mac.Write(padding)
  158. }
  159. // TODO(dfc) lengthBytes, packet and padding should be
  160. // subslices of a single buffer
  161. w.cipher.XORKeyStream(lengthBytes, lengthBytes)
  162. w.cipher.XORKeyStream(packet, packet)
  163. w.cipher.XORKeyStream(padding, padding)
  164. if _, err := w.Write(lengthBytes); err != nil {
  165. return err
  166. }
  167. if _, err := w.Write(packet); err != nil {
  168. return err
  169. }
  170. if _, err := w.Write(padding); err != nil {
  171. return err
  172. }
  173. if w.mac != nil {
  174. if _, err := w.Write(w.mac.Sum(nil)); err != nil {
  175. return err
  176. }
  177. }
  178. if err := w.Flush(); err != nil {
  179. return err
  180. }
  181. w.seqNum++
  182. return err
  183. }
  184. // Send a message to the remote peer
  185. func (t *transport) sendMessage(typ uint8, msg interface{}) error {
  186. packet := marshal(typ, msg)
  187. return t.writePacket(packet)
  188. }
  189. func newTransport(conn net.Conn, rand io.Reader) *transport {
  190. return &transport{
  191. reader: reader{
  192. Reader: bufio.NewReader(conn),
  193. common: common{
  194. cipher: noneCipher{},
  195. },
  196. },
  197. writer: writer{
  198. Writer: bufio.NewWriter(conn),
  199. rand: rand,
  200. Mutex: new(sync.Mutex),
  201. common: common{
  202. cipher: noneCipher{},
  203. },
  204. },
  205. filteredConn: conn,
  206. }
  207. }
  // direction holds the RFC 4253 section 7.2 letters that generateKeyMaterial
// uses to derive the IV, encryption key and MAC key for one direction.
type direction struct {
	ivTag     []byte // letter for the initial IV
	keyTag    []byte // letter for the encryption key
	macKeyTag []byte // letter for the integrity (MAC) key
}

// The key-derivation letters for each direction (RFC 4253, section 7.2).
// These are variables rather than constants because Go has no constant form
// for byte slices; they must not be modified.
var (
	// serverKeys derives the server->client keys.
	serverKeys = direction{ivTag: []byte{'B'}, keyTag: []byte{'D'}, macKeyTag: []byte{'F'}}
	// clientKeys derives the client->server keys.
	clientKeys = direction{ivTag: []byte{'A'}, keyTag: []byte{'C'}, macKeyTag: []byte{'E'}}
)
  218. // setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
  219. // described in RFC 4253, section 6.4. direction should either be serverKeys
  220. // (to setup server->client keys) or clientKeys (for client->server keys).
  221. func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) error {
  222. cipherMode := cipherModes[c.cipherAlgo]
  223. macKeySize := 20
  224. iv := make([]byte, cipherMode.ivSize)
  225. key := make([]byte, cipherMode.keySize)
  226. macKey := make([]byte, macKeySize)
  227. h := hashFunc.New()
  228. generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h)
  229. generateKeyMaterial(key, d.keyTag, K, H, sessionId, h)
  230. generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h)
  231. c.mac = truncatingMAC{12, hmac.New(sha1.New, macKey)}
  232. cipher, err := cipherMode.createCipher(key, iv)
  233. if err != nil {
  234. return err
  235. }
  236. c.cipher = cipher
  237. return nil
  238. }
  // generateKeyMaterial fills out with key material derived from tag, K, H and
// sessionId, as specified in RFC 4253, section 7.2. The first digest is
// HASH(K || H || tag || sessionId). Each later digest is HASH(K || H || every
// earlier digest), and digests are copied into out until it is full. h is
// reset before each digest.
// NOTE(review): K is hashed exactly as passed; the RFC expects it mpint-encoded
// — confirm callers encode it.
func generateKeyMaterial(out, tag []byte, K, H, sessionId []byte, h hash.Hash) {
	var accumulated []byte
	for remaining := out; len(remaining) > 0; {
		h.Reset()
		h.Write(K)
		h.Write(H)
		if len(accumulated) > 0 {
			h.Write(accumulated)
		} else {
			h.Write(tag)
			h.Write(sessionId)
		}
		block := h.Sum(nil)
		remaining = remaining[copy(remaining, block):]
		if len(remaining) == 0 {
			break
		}
		accumulated = append(accumulated, block...)
	}
}
  // truncatingMAC wraps a hash.Hash and truncates its digest to the first
// length bytes. setupKeys uses it with HMAC-SHA1 and length 12.
type truncatingMAC struct {
	length int       // number of digest bytes to keep
	hmac   hash.Hash // underlying, untruncated MAC
}

// Compile-time check that truncatingMAC satisfies hash.Hash.
var _ hash.Hash = truncatingMAC{}

// Write feeds data to the underlying MAC.
func (t truncatingMAC) Write(data []byte) (int, error) { return t.hmac.Write(data) }

// Sum appends the first t.length bytes of the underlying digest to in and
// returns the result.
func (t truncatingMAC) Sum(in []byte) []byte {
	return t.hmac.Sum(in)[:len(in)+t.length]
}

// Reset resets the underlying MAC.
func (t truncatingMAC) Reset() { t.hmac.Reset() }

// Size reports the truncated digest length in bytes.
func (t truncatingMAC) Size() int { return t.length }

// BlockSize reports the block size of the underlying MAC.
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() }
  // maxVersionStringBytes is the maximum number of bytes that we'll accept as a
// version string. If the peer speaks a different protocol it may never send a
// newline, so a limit keeps the search for the end of the version handshake
// from consuming unbounded memory.
const maxVersionStringBytes = 1024

// readVersion reads a version line as specified by RFC 4253, section 4.2. It
// reads one byte at a time, so nothing past the terminating '\n' is consumed.
// The returned line excludes the '\n' and any single trailing '\r'. Read
// errors are returned unchanged. A line of maxVersionStringBytes or more
// bytes without a newline is an error.
func readVersion(r io.Reader) ([]byte, error) {
	line := make([]byte, 0, 64)
	var b [1]byte
	for {
		if len(line) >= maxVersionStringBytes {
			return nil, errors.New("ssh: failed to read version string")
		}
		if _, err := io.ReadFull(r, b[:]); err != nil {
			return nil, err
		}
		// The RFC says the version should end with \r\n, but several SSH
		// servers actually send only \n.
		if b[0] == '\n' {
			break
		}
		line = append(line, b[0])
	}
	if n := len(line); n > 0 && line[n-1] == '\r' {
		line = line[:n-1]
	}
	return line, nil
}