|
@@ -0,0 +1,180 @@
|
|
|
|
|
+// Copyright 2012 The Go Authors. All rights reserved.
|
|
|
|
|
+// Use of this source code is governed by a BSD-style
|
|
|
|
|
+// license that can be found in the LICENSE file.
|
|
|
|
|
+
|
|
|
|
|
+package ssh
|
|
|
|
|
+
|
|
|
|
|
+import (
|
|
|
|
|
+ "bytes"
|
|
|
|
|
+ crypto_rand "crypto/rand"
|
|
|
|
|
+ "io"
|
|
|
|
|
+ "math/rand"
|
|
|
|
|
+ "testing"
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+// windowTestBytes is the number of bytes that we'll send to the SSH server.
|
|
|
|
|
+const windowTestBytes = 16000 * 200
|
|
|
|
|
+
|
|
|
|
|
+// CopyNRandomly copies n bytes from src to dst. It uses a variable, and random,
|
|
|
|
|
+// buffer size to exercise more code paths.
|
|
|
|
|
+func CopyNRandomly(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
|
|
|
|
|
+ buf := make([]byte, 32*1024)
|
|
|
|
|
+ for written < n {
|
|
|
|
|
+ l := (rand.Intn(30) + 1) * 1024
|
|
|
|
|
+ if d := n - written; d < int64(l) {
|
|
|
|
|
+ l = int(d)
|
|
|
|
|
+ }
|
|
|
|
|
+ nr, er := src.Read(buf[0:l])
|
|
|
|
|
+ if nr > 0 {
|
|
|
|
|
+ nw, ew := dst.Write(buf[0:nr])
|
|
|
|
|
+ if nw > 0 {
|
|
|
|
|
+ written += int64(nw)
|
|
|
|
|
+ }
|
|
|
|
|
+ if ew != nil {
|
|
|
|
|
+ err = ew
|
|
|
|
|
+ break
|
|
|
|
|
+ }
|
|
|
|
|
+ if nr != nw {
|
|
|
|
|
+ err = io.ErrShortWrite
|
|
|
|
|
+ break
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ if er != nil {
|
|
|
|
|
+ err = er
|
|
|
|
|
+ break
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return written, err
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func TestServerWindow(t *testing.T) {
|
|
|
|
|
+ addr := startSSHServer(t)
|
|
|
|
|
+ runSSHClient(t, addr)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// runSSHClient writes random data to the server. The server is expected to echo
|
|
|
|
|
+// the same data back, which is compared against the original.
|
|
|
|
|
+func runSSHClient(t *testing.T, addr string) {
|
|
|
|
|
+ conn, err := Dial("tcp", addr, &ClientConfig{})
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Fatal(err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ session, err := conn.NewSession()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Fatal(err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
|
|
|
|
|
+ echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
|
|
|
|
|
+ io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
|
|
|
|
|
+ origBytes := origBuf.Bytes()
|
|
|
|
|
+
|
|
|
|
|
+ wait := make(chan bool)
|
|
|
|
|
+
|
|
|
|
|
+ // Read back the data from the server.
|
|
|
|
|
+ go func() {
|
|
|
|
|
+ defer session.Close()
|
|
|
|
|
+ serverStdout, err := session.StdoutPipe()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Fatal(err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ n, err := CopyNRandomly(echoedBuf, serverStdout, windowTestBytes)
|
|
|
|
|
+ if err != nil && err != io.EOF {
|
|
|
|
|
+ t.Fatal(err)
|
|
|
|
|
+ }
|
|
|
|
|
+ if n != windowTestBytes {
|
|
|
|
|
+ t.Fatalf("Read only %d bytes from server, expected %d", n, windowTestBytes)
|
|
|
|
|
+ }
|
|
|
|
|
+ wait <- true
|
|
|
|
|
+ }()
|
|
|
|
|
+
|
|
|
|
|
+ serverStdin, err := session.StdinPipe()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Fatal(err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ written, err := CopyNRandomly(serverStdin, origBuf, windowTestBytes)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Fatal(err)
|
|
|
|
|
+ }
|
|
|
|
|
+ if written != windowTestBytes {
|
|
|
|
|
+ t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ <-wait
|
|
|
|
|
+
|
|
|
|
|
+ if !bytes.Equal(origBytes, echoedBuf.Bytes()) {
|
|
|
|
|
+ t.Error("Echoed buffer differed from original")
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func startSSHServer(t *testing.T) (addr string) {
|
|
|
|
|
+ config := &ServerConfig{
|
|
|
|
|
+ NoClientAuth: true,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ err := config.SetRSAPrivateKey([]byte(testServerPrivateKey))
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Fatalf("Failed to parse private key: %s", err.Error())
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ listener, err := Listen("tcp", ":0", config)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Fatalf("Bind error: %s", err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ addr = listener.Addr().String()
|
|
|
|
|
+
|
|
|
|
|
+ go func() {
|
|
|
|
|
+ for {
|
|
|
|
|
+ sConn, err := listener.Accept()
|
|
|
|
|
+
|
|
|
|
|
+ err = sConn.Handshake()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ if err != io.EOF {
|
|
|
|
|
+ t.Fatalf("failed to handshake: %s", err)
|
|
|
|
|
+ }
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ go connRun(t, sConn)
|
|
|
|
|
+ }
|
|
|
|
|
+ }()
|
|
|
|
|
+
|
|
|
|
|
+ return
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func connRun(t *testing.T, sConn *ServerConn) {
|
|
|
|
|
+ for {
|
|
|
|
|
+ channel, err := sConn.Accept()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ if err == io.EOF {
|
|
|
|
|
+ break
|
|
|
|
|
+ }
|
|
|
|
|
+ t.Fatalf("ServerConn.Accept failed: %s", err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if channel.ChannelType() != "session" {
|
|
|
|
|
+ channel.Reject(UnknownChannelType, "unknown channel type")
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ err = channel.Accept()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Fatalf("Channel.Accept failed: %s", err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ go func() {
|
|
|
|
|
+ defer channel.Close()
|
|
|
|
|
+
|
|
|
|
|
+ n, err := CopyNRandomly(channel, channel, windowTestBytes)
|
|
|
|
|
+ if err != nil && err != io.EOF {
|
|
|
|
|
+ if err == io.ErrShortWrite {
|
|
|
|
|
+ t.Fatalf("short write, wrote %d, expected %d", n, windowTestBytes)
|
|
|
|
|
+ }
|
|
|
|
|
+ t.Fatal(err)
|
|
|
|
|
+ }
|
|
|
|
|
+ }()
|
|
|
|
|
+ }
|
|
|
|
|
+}
|