瀏覽代碼

x/crypto/otr: clear key slots when handshaking.

The OTR implementation had a bug where key slots would not be marked as
unused after a rehandshake. Since handshaking resets the key ids, some
key slots would be left over with much higher key ids. These key slots
would lead to an error when the code ran out of slots.

Fixes agl/xmpp-client#96.

Change-Id: I013bbc4eaf0616373ab52f14b7f7757c353983ca
Reviewed-on: https://go-review.googlesource.com/16934
Reviewed-by: Andrew Gerrand <adg@golang.org>
Adam Langley 10 年之前
父節點
當前提交
d438f321d3
共有 2 個文件被更改,包括 123 次插入66 次删除
  1. 13 6
      otr/otr.go
  2. 110 60
      otr/otr_test.go

+ 13 - 6
otr/otr.go

@@ -277,7 +277,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
 		in = in[len(msgPrefix) : len(in)-1]
 		in = in[len(msgPrefix) : len(in)-1]
 	} else if version := isQuery(in); version > 0 {
 	} else if version := isQuery(in); version > 0 {
 		c.authState = authStateAwaitingDHKey
 		c.authState = authStateAwaitingDHKey
-		c.myKeyId = 0
+		c.reset()
 		toSend = c.encode(c.generateDHCommit())
 		toSend = c.encode(c.generateDHCommit())
 		return
 		return
 	} else {
 	} else {
@@ -311,7 +311,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
 			if err = c.processDHCommit(msg); err != nil {
 			if err = c.processDHCommit(msg); err != nil {
 				return
 				return
 			}
 			}
-			c.myKeyId = 0
+			c.reset()
 			toSend = c.encode(c.generateDHKey())
 			toSend = c.encode(c.generateDHKey())
 			return
 			return
 		case authStateAwaitingDHKey:
 		case authStateAwaitingDHKey:
@@ -330,7 +330,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
 				if err = c.processDHCommit(msg); err != nil {
 				if err = c.processDHCommit(msg); err != nil {
 					return
 					return
 				}
 				}
-				c.myKeyId = 0
+				c.reset()
 				toSend = c.encode(c.generateDHKey())
 				toSend = c.encode(c.generateDHKey())
 				return
 				return
 			}
 			}
@@ -343,7 +343,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
 			if err = c.processDHCommit(msg); err != nil {
 			if err = c.processDHCommit(msg); err != nil {
 				return
 				return
 			}
 			}
-			c.myKeyId = 0
+			c.reset()
 			toSend = c.encode(c.generateDHKey())
 			toSend = c.encode(c.generateDHKey())
 			c.authState = authStateAwaitingRevealSig
 			c.authState = authStateAwaitingRevealSig
 		default:
 		default:
@@ -1036,8 +1036,7 @@ func (c *Conversation) calcDataKeys(myKeyId, theirKeyId uint32) (slot *keySlot,
 		}
 		}
 	}
 	}
 	if slot == nil {
 	if slot == nil {
-		err = errors.New("otr: internal error: no key slots")
-		return
+		return nil, errors.New("otr: internal error: no more key slots")
 	}
 	}
 
 
 	var myPriv, myPub, theirPub *big.Int
 	var myPriv, myPub, theirPub *big.Int
@@ -1163,6 +1162,14 @@ func (c *Conversation) encode(msg []byte) [][]byte {
 	return ret
 	return ret
 }
 }
 
 
+func (c *Conversation) reset() {
+	c.myKeyId = 0
+
+	for i := range c.keySlots {
+		c.keySlots[i].used = false
+	}
+}
+
 type PublicKey struct {
 type PublicKey struct {
 	dsa.PublicKey
 	dsa.PublicKey
 }
 }

+ 110 - 60
otr/otr_test.go

@@ -121,11 +121,12 @@ func TestSignVerify(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func TestConversation(t *testing.T) {
+func setupConversation(t *testing.T) (alice, bob *Conversation) {
 	alicePrivateKey, _ := hex.DecodeString(alicePrivateKeyHex)
 	alicePrivateKey, _ := hex.DecodeString(alicePrivateKeyHex)
 	bobPrivateKey, _ := hex.DecodeString(bobPrivateKeyHex)
 	bobPrivateKey, _ := hex.DecodeString(bobPrivateKeyHex)
 
 
-	var alice, bob Conversation
+	alice, bob = new(Conversation), new(Conversation)
+
 	alice.PrivateKey = new(PrivateKey)
 	alice.PrivateKey = new(PrivateKey)
 	bob.PrivateKey = new(PrivateKey)
 	bob.PrivateKey = new(PrivateKey)
 	alice.PrivateKey.Parse(alicePrivateKey)
 	alice.PrivateKey.Parse(alicePrivateKey)
@@ -133,12 +134,6 @@ func TestConversation(t *testing.T) {
 	alice.FragmentSize = 100
 	alice.FragmentSize = 100
 	bob.FragmentSize = 100
 	bob.FragmentSize = 100
 
 
-	var alicesMessage, bobsMessage [][]byte
-	var out []byte
-	var aliceChange, bobChange SecurityChange
-	var err error
-	alicesMessage = append(alicesMessage, []byte(QueryMessage))
-
 	if alice.IsEncrypted() {
 	if alice.IsEncrypted() {
 		t.Error("Alice believes that the conversation is secure before we've started")
 		t.Error("Alice believes that the conversation is secure before we've started")
 	}
 	}
@@ -146,6 +141,17 @@ func TestConversation(t *testing.T) {
 		t.Error("Bob believes that the conversation is secure before we've started")
 		t.Error("Bob believes that the conversation is secure before we've started")
 	}
 	}
 
 
+	performHandshake(t, alice, bob)
+	return alice, bob
+}
+
+func performHandshake(t *testing.T, alice, bob *Conversation) {
+	var alicesMessage, bobsMessage [][]byte
+	var out []byte
+	var aliceChange, bobChange SecurityChange
+	var err error
+	alicesMessage = append(alicesMessage, []byte(QueryMessage))
+
 	for round := 0; len(alicesMessage) > 0 || len(bobsMessage) > 0; round++ {
 	for round := 0; len(alicesMessage) > 0 || len(bobsMessage) > 0; round++ {
 		bobsMessage = nil
 		bobsMessage = nil
 		for i, msg := range alicesMessage {
 		for i, msg := range alicesMessage {
@@ -193,80 +199,109 @@ func TestConversation(t *testing.T) {
 	if !bob.IsEncrypted() {
 	if !bob.IsEncrypted() {
 		t.Error("Bob doesn't believe that the conversation is secure")
 		t.Error("Bob doesn't believe that the conversation is secure")
 	}
 	}
+}
+
+const (
+	firstRoundTrip = iota
+	subsequentRoundTrip
+	noMACKeyCheck
+)
+
+func roundTrip(t *testing.T, alice, bob *Conversation, message []byte, macKeyCheck int) {
+	alicesMessage, err := alice.Send(message)
+	if err != nil {
+		t.Errorf("Error from Alice sending message: %s", err)
+	}
 
 
-	var testMessages = [][]byte{
-		[]byte("hello"), []byte("bye"),
+	if len(alice.oldMACs) != 0 {
+		t.Errorf("Alice has not revealed all MAC keys")
 	}
 	}
 
 
-	for j, testMessage := range testMessages {
-		alicesMessage, err = alice.Send(testMessage)
+	for i, msg := range alicesMessage {
+		out, encrypted, _, _, err := bob.Receive(msg)
 
 
-		if len(alice.oldMACs) != 0 {
-			t.Errorf("Alice has not revealed all MAC keys")
+		if err != nil {
+			t.Errorf("Error generated while processing test message: %s", err.Error())
 		}
 		}
-
-		for i, msg := range alicesMessage {
-			out, encrypted, _, _, err := bob.Receive(msg)
-
-			if err != nil {
-				t.Errorf("Error generated while processing test message: %s", err.Error())
+		if len(out) > 0 {
+			if i != len(alicesMessage)-1 {
+				t.Fatal("Bob produced a message while processing a fragment of Alice's")
 			}
 			}
-			if len(out) > 0 {
-				if i != len(alicesMessage)-1 {
-					t.Fatal("Bob produced a message while processing a fragment of Alice's")
-				}
-				if !encrypted {
-					t.Errorf("Message was not marked as encrypted")
-				}
-				if !bytes.Equal(out, testMessage) {
-					t.Errorf("Message corrupted: got %x, want %x", out, testMessage)
-				}
+			if !encrypted {
+				t.Errorf("Message was not marked as encrypted")
 			}
 			}
-		}
-
-		if j == 0 {
-			if len(bob.oldMACs) != 0 {
-				t.Errorf("Bob should not have MAC keys to reveal")
+			if !bytes.Equal(out, message) {
+				t.Errorf("Message corrupted: got %x, want %x", out, message)
 			}
 			}
-		} else if len(bob.oldMACs) != 40 {
-			t.Errorf("Bob does not have MAC keys to reveal")
 		}
 		}
+	}
 
 
-		bobsMessage, err = bob.Send(testMessage)
-
+	switch macKeyCheck {
+	case firstRoundTrip:
 		if len(bob.oldMACs) != 0 {
 		if len(bob.oldMACs) != 0 {
-			t.Errorf("Bob has not revealed all MAC keys")
+			t.Errorf("Bob should not have MAC keys to reveal")
 		}
 		}
+	case subsequentRoundTrip:
+		if len(bob.oldMACs) != 40 {
+			t.Errorf("Bob has %d bytes of MAC keys to reveal, but should have 40", len(bob.oldMACs))
+		}
+	}
 
 
-		for i, msg := range bobsMessage {
-			out, encrypted, _, _, err := alice.Receive(msg)
+	bobsMessage, err := bob.Send(message)
+	if err != nil {
+		t.Errorf("Error from Bob sending message: %s", err)
+	}
 
 
-			if err != nil {
-				t.Errorf("Error generated while processing test message: %s", err.Error())
+	if len(bob.oldMACs) != 0 {
+		t.Errorf("Bob has not revealed all MAC keys")
+	}
+
+	for i, msg := range bobsMessage {
+		out, encrypted, _, _, err := alice.Receive(msg)
+
+		if err != nil {
+			t.Errorf("Error generated while processing test message: %s", err.Error())
+		}
+		if len(out) > 0 {
+			if i != len(bobsMessage)-1 {
+				t.Fatal("Alice produced a message while processing a fragment of Bob's")
 			}
 			}
-			if len(out) > 0 {
-				if i != len(bobsMessage)-1 {
-					t.Fatal("Alice produced a message while processing a fragment of Bob's")
-				}
-				if !encrypted {
-					t.Errorf("Message was not marked as encrypted")
-				}
-				if !bytes.Equal(out, testMessage) {
-					t.Errorf("Message corrupted: got %x, want %x", out, testMessage)
-				}
+			if !encrypted {
+				t.Errorf("Message was not marked as encrypted")
+			}
+			if !bytes.Equal(out, message) {
+				t.Errorf("Message corrupted: got %x, want %x", out, message)
 			}
 			}
 		}
 		}
+	}
 
 
-		if j == 0 {
-			if len(alice.oldMACs) != 20 {
-				t.Errorf("Alice does not have MAC keys to reveal")
-			}
-		} else if len(alice.oldMACs) != 40 {
-			t.Errorf("Alice does not have MAC keys to reveal")
+	switch macKeyCheck {
+	case firstRoundTrip:
+		if len(alice.oldMACs) != 20 {
+			t.Errorf("Alice has %d bytes of MAC keys to reveal, but should have 20", len(alice.oldMACs))
+		}
+	case subsequentRoundTrip:
+		if len(alice.oldMACs) != 40 {
+			t.Errorf("Alice has %d bytes of MAC keys to reveal, but should have 40", len(alice.oldMACs))
 		}
 		}
 	}
 	}
 }
 }
 
 
+func TestConversation(t *testing.T) {
+	alice, bob := setupConversation(t)
+
+	var testMessages = [][]byte{
+		[]byte("hello"), []byte("bye"),
+	}
+
+	roundTripType := firstRoundTrip
+
+	for _, testMessage := range testMessages {
+		roundTrip(t, alice, bob, testMessage, roundTripType)
+		roundTripType = subsequentRoundTrip
+	}
+}
+
 func TestGoodSMP(t *testing.T) {
 func TestGoodSMP(t *testing.T) {
 	var alice, bob Conversation
 	var alice, bob Conversation
 
 
@@ -348,6 +383,21 @@ func TestBadSMP(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestRehandshaking(t *testing.T) {
+	alice, bob := setupConversation(t)
+	roundTrip(t, alice, bob, []byte("test"), firstRoundTrip)
+	roundTrip(t, alice, bob, []byte("test 2"), subsequentRoundTrip)
+	roundTrip(t, alice, bob, []byte("test 3"), subsequentRoundTrip)
+	roundTrip(t, alice, bob, []byte("test 4"), subsequentRoundTrip)
+	roundTrip(t, alice, bob, []byte("test 5"), subsequentRoundTrip)
+	roundTrip(t, alice, bob, []byte("test 6"), subsequentRoundTrip)
+	roundTrip(t, alice, bob, []byte("test 7"), subsequentRoundTrip)
+	roundTrip(t, alice, bob, []byte("test 8"), subsequentRoundTrip)
+	performHandshake(t, alice, bob)
+	roundTrip(t, alice, bob, []byte("test"), noMACKeyCheck)
+	roundTrip(t, alice, bob, []byte("test 2"), noMACKeyCheck)
+}
+
 func TestAgainstLibOTR(t *testing.T) {
 func TestAgainstLibOTR(t *testing.T) {
 	// This test requires otr.c.test to be built as /tmp/a.out.
 	// This test requires otr.c.test to be built as /tmp/a.out.
 	// If enabled, this tests runs forever performing OTR handshakes in a
 	// If enabled, this tests runs forever performing OTR handshakes in a