|
|
@@ -5,47 +5,65 @@
|
|
|
package ssh
|
|
|
|
|
|
import (
|
|
|
- "bufio"
|
|
|
"bytes"
|
|
|
+ "strings"
|
|
|
"testing"
|
|
|
)
|
|
|
|
|
|
func TestReadVersion(t *testing.T) {
|
|
|
- buf := serverVersion
|
|
|
- result, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf)))
|
|
|
- if err != nil {
|
|
|
- t.Errorf("readVersion didn't read version correctly: %s", err)
|
|
|
+ longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
|
|
|
+ cases := map[string]string{
|
|
|
+ "SSH-2.0-bla\r\n": "SSH-2.0-bla",
|
|
|
+ "SSH-2.0-bla\n": "SSH-2.0-bla",
|
|
|
+ longversion + "\r\n": longversion,
|
|
|
}
|
|
|
- if !bytes.Equal(buf[:len(buf)-2], result) {
|
|
|
- t.Error("version read did not match expected")
|
|
|
+
|
|
|
+ for in, want := range cases {
|
|
|
+ result, err := readVersion(bytes.NewBufferString(in))
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("readVersion(%q): %s", in, err)
|
|
|
+ }
|
|
|
+ got := string(result)
|
|
|
+ if got != want {
|
|
|
+ t.Errorf("got %q, want %q", got, want)
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func TestReadVersionWithJustLF(t *testing.T) {
|
|
|
- var buf []byte
|
|
|
- buf = append(buf, serverVersion...)
|
|
|
- buf = buf[:len(buf)-1]
|
|
|
- buf[len(buf)-1] = '\n'
|
|
|
- result, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf)))
|
|
|
- if err != nil {
|
|
|
- t.Error("readVersion failed to handle just a \n")
|
|
|
+func TestReadVersionError(t *testing.T) {
|
|
|
+ longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
|
|
|
+ cases := []string{
|
|
|
+ longversion + "too-long\r\n",
|
|
|
}
|
|
|
- if !bytes.Equal(buf[:len(buf)-1], result) {
|
|
|
- t.Errorf("version read did not match expected: got %x, want %x", result, buf[:len(buf)-1])
|
|
|
+ for _, in := range cases {
|
|
|
+ if _, err := readVersion(bytes.NewBufferString(in)); err == nil {
|
|
|
+ t.Errorf("readVersion(%q) should have failed", in)
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func TestReadVersionTooLong(t *testing.T) {
|
|
|
- buf := make([]byte, maxVersionStringBytes+1)
|
|
|
- if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil {
|
|
|
- t.Errorf("readVersion consumed %d bytes without error", len(buf))
|
|
|
+func TestExchangeVersionsBasic(t *testing.T) {
|
|
|
+ v := "SSH-2.0-bla"
|
|
|
+ buf := bytes.NewBufferString(v + "\r\n")
|
|
|
+ them, err := exchangeVersions(buf, []byte("xyz"))
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("exchangeVersions: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if want := "SSH-2.0-bla"; string(them) != want {
|
|
|
+ t.Errorf("got %q want %q for our version", them, want)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func TestReadVersionWithoutCRLF(t *testing.T) {
|
|
|
- buf := serverVersion
|
|
|
- buf = buf[:len(buf)-1]
|
|
|
- if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil {
|
|
|
- t.Error("readVersion did not notice \\n was missing")
|
|
|
+func TestExchangeVersions(t *testing.T) {
|
|
|
+ cases := []string{
|
|
|
+ "not\x000allowed",
|
|
|
+ "not allowed\n",
|
|
|
+ }
|
|
|
+ for _, c := range cases {
|
|
|
+ buf := bytes.NewBufferString("SSH-2.0-bla\r\n")
|
|
|
+ if _, err := exchangeVersions(buf, []byte(c)); err == nil {
|
|
|
+ t.Errorf("exchangeVersions(%q): should have failed", c)
|
|
|
+ }
|
|
|
}
|
|
|
}
|