Browse Source

golang.org/x/crypto/openssh: don't loop forever after a bad password.

SymmetricKeyEncrypted cached the results of decryption so, if a bad
password was given, ReadMessage would prompt forever because a later,
correct password wouldn't override the cached decryption.

The SymmetricKeyEncrypted object can't know whether a given passphrase
is correct so it should never have been a mutable object in the first
place. This change makes it so that it doesn't cache anything.

Fixes #9315

Change-Id: Ic2b75f7f60a575e2182ac7e5c5d4198597c5d0a2
Reviewed-on: https://go-review.googlesource.com/14038
Reviewed-by: Andrew Gerrand <adg@golang.org>
Reviewed-by: Adam Langley <agl@golang.org>
Adam Langley 10 years ago
parent
commit
0c93e1ff9f

+ 28 - 37
openpgp/packet/symmetric_key_encrypted.go

@@ -22,20 +22,17 @@ const maxSessionKeySizeInBytes = 64
 // 4880, section 5.3.
 type SymmetricKeyEncrypted struct {
 	CipherFunc   CipherFunction
-	Encrypted    bool
-	Key          []byte // Empty unless Encrypted is false.
 	s2k          func(out, in []byte)
 	encryptedKey []byte
 }
 
 const symmetricKeyEncryptedVersion = 4
 
-func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) {
+func (ske *SymmetricKeyEncrypted) parse(r io.Reader) error {
 	// RFC 4880, section 5.3.
 	var buf [2]byte
-	_, err = readFull(r, buf[:])
-	if err != nil {
-		return
+	if _, err := readFull(r, buf[:]); err != nil {
+		return err
 	}
 	if buf[0] != symmetricKeyEncryptedVersion {
 		return errors.UnsupportedError("SymmetricKeyEncrypted version")
@@ -46,9 +43,10 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) {
 		return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1])))
 	}
 
+	var err error
 	ske.s2k, err = s2k.Parse(r)
 	if err != nil {
-		return
+		return err
 	}
 
 	encryptedKey := make([]byte, maxSessionKeySizeInBytes)
@@ -56,9 +54,9 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) {
 	// out. If it exists then we limit it to maxSessionKeySizeInBytes.
 	n, err := readFull(r, encryptedKey)
 	if err != nil && err != io.ErrUnexpectedEOF {
-		return
+		return err
 	}
-	err = nil
+
 	if n != 0 {
 		if n == maxSessionKeySizeInBytes {
 			return errors.UnsupportedError("oversized encrypted session key")
@@ -66,42 +64,35 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) {
 		ske.encryptedKey = encryptedKey[:n]
 	}
 
-	ske.Encrypted = true
-
-	return
+	return nil
 }
 
-// Decrypt attempts to decrypt an encrypted session key. If it returns nil,
-// ske.Key will contain the session key.
-func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) error {
-	if !ske.Encrypted {
-		return nil
-	}
-
+// Decrypt attempts to decrypt an encrypted session key and returns the key and
+// the cipher to use when decrypting a subsequent Symmetrically Encrypted Data
+// packet.
+func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) ([]byte, CipherFunction, error) {
 	key := make([]byte, ske.CipherFunc.KeySize())
 	ske.s2k(key, passphrase)
 
 	if len(ske.encryptedKey) == 0 {
-		ske.Key = key
-	} else {
-		// the IV is all zeros
-		iv := make([]byte, ske.CipherFunc.blockSize())
-		c := cipher.NewCFBDecrypter(ske.CipherFunc.new(key), iv)
-		c.XORKeyStream(ske.encryptedKey, ske.encryptedKey)
-		ske.CipherFunc = CipherFunction(ske.encryptedKey[0])
-		if ske.CipherFunc.blockSize() == 0 {
-			return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(ske.CipherFunc)))
-		}
-		ske.CipherFunc = CipherFunction(ske.encryptedKey[0])
-		ske.Key = ske.encryptedKey[1:]
-		if len(ske.Key)%ske.CipherFunc.blockSize() != 0 {
-			ske.Key = nil
-			return errors.StructuralError("length of decrypted key not a multiple of block size")
-		}
+		return key, ske.CipherFunc, nil
 	}
 
-	ske.Encrypted = false
-	return nil
+	// the IV is all zeros
+	iv := make([]byte, ske.CipherFunc.blockSize())
+	c := cipher.NewCFBDecrypter(ske.CipherFunc.new(key), iv)
+	plaintextKey := make([]byte, len(ske.encryptedKey))
+	c.XORKeyStream(plaintextKey, ske.encryptedKey)
+	cipherFunc := CipherFunction(plaintextKey[0])
+	if cipherFunc.blockSize() == 0 {
+		return nil, ske.CipherFunc, errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(ske.CipherFunc)))
+	}
+	plaintextKey = plaintextKey[1:]
+	if l := len(plaintextKey); l == 0 || l%cipherFunc.blockSize() != 0 {
+		return nil, cipherFunc, errors.StructuralError("length of decrypted key not a multiple of block size")
+	}
+
+	return plaintextKey, cipherFunc, nil
 }
 
 // SerializeSymmetricKeyEncrypted serializes a symmetric key packet to w. The

+ 10 - 9
openpgp/packet/symmetric_key_encrypted_test.go

@@ -24,7 +24,7 @@ func TestSymmetricKeyEncrypted(t *testing.T) {
 		t.Error("didn't find SymmetricKeyEncrypted packet")
 		return
 	}
-	err = ske.Decrypt([]byte("password"))
+	key, cipherFunc, err := ske.Decrypt([]byte("password"))
 	if err != nil {
 		t.Error(err)
 		return
@@ -40,7 +40,7 @@ func TestSymmetricKeyEncrypted(t *testing.T) {
 		t.Error("didn't find SymmetricallyEncrypted packet")
 		return
 	}
-	r, err := se.Decrypt(ske.CipherFunc, ske.Key)
+	r, err := se.Decrypt(cipherFunc, key)
 	if err != nil {
 		t.Error(err)
 		return
@@ -64,8 +64,9 @@ const symmetricallyEncryptedContentsHex = "cb1062004d14c4df636f6e74656e74732e0a"
 func TestSerializeSymmetricKeyEncrypted(t *testing.T) {
 	buf := bytes.NewBuffer(nil)
 	passphrase := []byte("testing")
+	const cipherFunc = CipherAES128
 	config := &Config{
-		DefaultCipher: CipherAES128,
+		DefaultCipher: cipherFunc,
 	}
 
 	key, err := SerializeSymmetricKeyEncrypted(buf, passphrase, config)
@@ -85,18 +86,18 @@ func TestSerializeSymmetricKeyEncrypted(t *testing.T) {
 		return
 	}
 
-	if !ske.Encrypted {
-		t.Errorf("SKE not encrypted but should be")
-	}
 	if ske.CipherFunc != config.DefaultCipher {
 		t.Errorf("SKE cipher function is %d (expected %d)", ske.CipherFunc, config.DefaultCipher)
 	}
-	err = ske.Decrypt(passphrase)
+	parsedKey, parsedCipherFunc, err := ske.Decrypt(passphrase)
 	if err != nil {
 		t.Errorf("failed to decrypt reparsed SKE: %s", err)
 		return
 	}
-	if !bytes.Equal(key, ske.Key) {
-		t.Errorf("keys don't match after Decrpyt: %x (original) vs %x (parsed)", key, ske.Key)
+	if !bytes.Equal(key, parsedKey) {
+		t.Errorf("keys don't match after Decrypt: %x (original) vs %x (parsed)", key, parsedKey)
+	}
+	if parsedCipherFunc != cipherFunc {
+		t.Errorf("cipher function doesn't match after Decrypt: %d (original) vs %d (parsed)", cipherFunc, parsedCipherFunc)
 	}
 }

+ 3 - 3
openpgp/read.go

@@ -196,9 +196,9 @@ FindKey:
 		// Try the symmetric passphrase first
 		if len(symKeys) != 0 && passphrase != nil {
 			for _, s := range symKeys {
-				err = s.Decrypt(passphrase)
-				if err == nil && !s.Encrypted {
-					decrypted, err = se.Decrypt(s.CipherFunc, s.Key)
+				key, cipherFunc, err := s.Decrypt(passphrase)
+				if err == nil {
+					decrypted, err = se.Decrypt(cipherFunc, key)
 					if err != nil && err != errors.ErrKeyIncorrect {
 						return nil, err
 					}

+ 7 - 1
openpgp/read_test.go

@@ -243,7 +243,7 @@ func TestUnspecifiedRecipient(t *testing.T) {
 }
 
 func TestSymmetricallyEncrypted(t *testing.T) {
-	expected := "Symmetrically encrypted.\n"
+	firstTimeCalled := true
 
 	prompt := func(keys []Key, symmetric bool) ([]byte, error) {
 		if len(keys) != 0 {
@@ -254,6 +254,11 @@ func TestSymmetricallyEncrypted(t *testing.T) {
 			t.Errorf("symmetric is not set")
 		}
 
+		if firstTimeCalled {
+			firstTimeCalled = false
+			return []byte("wrongpassword"), nil
+		}
+
 		return []byte("password"), nil
 	}
 
@@ -273,6 +278,7 @@ func TestSymmetricallyEncrypted(t *testing.T) {
 		t.Errorf("LiteralData.Time is %d, want %d", md.LiteralData.Time, expectedCreationTime)
 	}
 
+	const expected = "Symmetrically encrypted.\n"
 	if string(contents) != expected {
 		t.Errorf("contents got: %s want: %s", string(contents), expected)
 	}