瀏覽代碼

ssh/terminal: fix line endings handling in ReadPassword

Fixes golang/go#16552

Change-Id: I18a9c9b42fe042c4871b3efb3f51bef7cca335d0
Reviewed-on: https://go-review.googlesource.com/25355
Reviewed-by: Adam Langley <alangley@gmail.com>
Reviewed-by: Adam Langley <agl@golang.org>
Alex Brainman 9 年之前
父節點
當前提交
13d9f6188e
共有 4 個文件被更改,包括 88 次插入51 次删除
  1. 28 0
      ssh/terminal/terminal.go
  2. 44 0
      ssh/terminal/terminal_test.go
  3. 8 24
      ssh/terminal/util.go
  4. 8 27
      ssh/terminal/util_windows.go

+ 28 - 0
ssh/terminal/terminal.go

@@ -920,3 +920,31 @@ func (s *stRingBuffer) NthPreviousEntry(n int) (value string, ok bool) {
 	}
 	}
 	return s.entries[index], true
 	return s.entries[index], true
 }
 }
+
+// readPasswordLine reads from reader until it finds \n or io.EOF.
+// The slice returned does not include the \n.
+// readPasswordLine also ignores any \r it finds.
+func readPasswordLine(reader io.Reader) ([]byte, error) {
+	var buf [1]byte
+	var ret []byte
+
+	for {
+		n, err := reader.Read(buf[:])
+		if err != nil {
+			if err == io.EOF && len(ret) > 0 {
+				return ret, nil
+			}
+			return ret, err
+		}
+		if n > 0 {
+			switch buf[0] {
+			case '\n':
+				return ret, nil
+			case '\r':
+				// remove \r from passwords on Windows
+			default:
+				ret = append(ret, buf[0])
+			}
+		}
+	}
+}

+ 44 - 0
ssh/terminal/terminal_test.go

@@ -270,6 +270,50 @@ func TestTerminalSetSize(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestReadPasswordLineEnd(t *testing.T) {
+	var tests = []struct {
+		input string
+		want  string
+	}{
+		{"\n", ""},
+		{"\r\n", ""},
+		{"test\r\n", "test"},
+		{"testtesttesttes\n", "testtesttesttes"},
+		{"testtesttesttes\r\n", "testtesttesttes"},
+		{"testtesttesttesttest\n", "testtesttesttesttest"},
+		{"testtesttesttesttest\r\n", "testtesttesttesttest"},
+	}
+	for _, test := range tests {
+		buf := new(bytes.Buffer)
+		if _, err := buf.WriteString(test.input); err != nil {
+			t.Fatal(err)
+		}
+
+		have, err := readPasswordLine(buf)
+		if err != nil {
+			t.Errorf("readPasswordLine(%q) failed: %v", test.input, err)
+			continue
+		}
+		if string(have) != test.want {
+			t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want)
+			continue
+		}
+
+		if _, err = buf.WriteString(test.input); err != nil {
+			t.Fatal(err)
+		}
+		have, err = readPasswordLine(buf)
+		if err != nil {
+			t.Errorf("readPasswordLine(%q) failed: %v", test.input, err)
+			continue
+		}
+		if string(have) != test.want {
+			t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want)
+			continue
+		}
+	}
+}
+
 func TestMakeRawState(t *testing.T) {
 func TestMakeRawState(t *testing.T) {
 	fd := int(os.Stdout.Fd())
 	fd := int(os.Stdout.Fd())
 	if !IsTerminal(fd) {
 	if !IsTerminal(fd) {

+ 8 - 24
ssh/terminal/util.go

@@ -17,7 +17,6 @@
 package terminal // import "golang.org/x/crypto/ssh/terminal"
 package terminal // import "golang.org/x/crypto/ssh/terminal"
 
 
 import (
 import (
-	"io"
 	"syscall"
 	"syscall"
 	"unsafe"
 	"unsafe"
 )
 )
@@ -88,6 +87,13 @@ func GetSize(fd int) (width, height int, err error) {
 	return int(dimensions[1]), int(dimensions[0]), nil
 	return int(dimensions[1]), int(dimensions[0]), nil
 }
 }
 
 
+// passwordReader is an io.Reader that reads from a specific file descriptor.
+type passwordReader int
+
+func (r passwordReader) Read(buf []byte) (int, error) {
+	return syscall.Read(int(r), buf)
+}
+
 // ReadPassword reads a line of input from a terminal without local echo.  This
 // ReadPassword reads a line of input from a terminal without local echo.  This
 // is commonly used for inputting passwords and other sensitive data. The slice
 // is commonly used for inputting passwords and other sensitive data. The slice
 // returned does not include the \n.
 // returned does not include the \n.
@@ -109,27 +115,5 @@ func ReadPassword(fd int) ([]byte, error) {
 		syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0)
 		syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0)
 	}()
 	}()
 
 
-	var buf [16]byte
-	var ret []byte
-	for {
-		n, err := syscall.Read(fd, buf[:])
-		if err != nil {
-			return nil, err
-		}
-		if n == 0 {
-			if len(ret) == 0 {
-				return nil, io.EOF
-			}
-			break
-		}
-		if buf[n-1] == '\n' {
-			n--
-		}
-		ret = append(ret, buf[:n]...)
-		if n < len(buf) {
-			break
-		}
-	}
-
-	return ret, nil
+	return readPasswordLine(passwordReader(fd))
 }
 }

+ 8 - 27
ssh/terminal/util_windows.go

@@ -17,7 +17,6 @@
 package terminal
 package terminal
 
 
 import (
 import (
-	"io"
 	"syscall"
 	"syscall"
 	"unsafe"
 	"unsafe"
 )
 )
@@ -123,6 +122,13 @@ func GetSize(fd int) (width, height int, err error) {
 	return int(info.size.x), int(info.size.y), nil
 	return int(info.size.x), int(info.size.y), nil
 }
 }
 
 
+// passwordReader is an io.Reader that reads from a specific Windows HANDLE.
+type passwordReader int
+
+func (r passwordReader) Read(buf []byte) (int, error) {
+	return syscall.Read(syscall.Handle(r), buf)
+}
+
 // ReadPassword reads a line of input from a terminal without local echo.  This
 // ReadPassword reads a line of input from a terminal without local echo.  This
 // is commonly used for inputting passwords and other sensitive data. The slice
 // is commonly used for inputting passwords and other sensitive data. The slice
 // returned does not include the \n.
 // returned does not include the \n.
@@ -145,30 +151,5 @@ func ReadPassword(fd int) ([]byte, error) {
 		syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0)
 		syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0)
 	}()
 	}()
 
 
-	var buf [16]byte
-	var ret []byte
-	for {
-		n, err := syscall.Read(syscall.Handle(fd), buf[:])
-		if err != nil {
-			return nil, err
-		}
-		if n == 0 {
-			if len(ret) == 0 {
-				return nil, io.EOF
-			}
-			break
-		}
-		if buf[n-1] == '\n' {
-			n--
-		}
-		if n > 0 && buf[n-1] == '\r' {
-			n--
-		}
-		ret = append(ret, buf[:n]...)
-		if n < len(buf) {
-			break
-		}
-	}
-
-	return ret, nil
+	return readPasswordLine(passwordReader(fd))
 }
 }