Browse Source

x/crypto/otr: clear key slots when handshaking.

The OTR implementation had a bug where key slots would not be marked as
unused after a rehandshake. Since handshaking resets the key ids, some
key slots would be left over with much higher key ids. These key slots
would lead to an error when the code ran out of slots.

Fixes agl/xmpp-client#96.

Change-Id: I013bbc4eaf0616373ab52f14b7f7757c353983ca
Reviewed-on: https://go-review.googlesource.com/16934
Reviewed-by: Andrew Gerrand <adg@golang.org>
Adam Langley 10 years ago
parent
commit
d438f321d3
2 changed files with 123 additions and 66 deletions
  1. 13 6
      otr/otr.go
  2. 110 60
      otr/otr_test.go

+ 13 - 6
otr/otr.go

@@ -277,7 +277,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
 		in = in[len(msgPrefix) : len(in)-1]
 	} else if version := isQuery(in); version > 0 {
 		c.authState = authStateAwaitingDHKey
-		c.myKeyId = 0
+		c.reset()
 		toSend = c.encode(c.generateDHCommit())
 		return
 	} else {
@@ -311,7 +311,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
 			if err = c.processDHCommit(msg); err != nil {
 				return
 			}
-			c.myKeyId = 0
+			c.reset()
 			toSend = c.encode(c.generateDHKey())
 			return
 		case authStateAwaitingDHKey:
@@ -330,7 +330,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
 				if err = c.processDHCommit(msg); err != nil {
 					return
 				}
-				c.myKeyId = 0
+				c.reset()
 				toSend = c.encode(c.generateDHKey())
 				return
 			}
@@ -343,7 +343,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
 			if err = c.processDHCommit(msg); err != nil {
 				return
 			}
-			c.myKeyId = 0
+			c.reset()
 			toSend = c.encode(c.generateDHKey())
 			c.authState = authStateAwaitingRevealSig
 		default:
@@ -1036,8 +1036,7 @@ func (c *Conversation) calcDataKeys(myKeyId, theirKeyId uint32) (slot *keySlot,
 		}
 	}
 	if slot == nil {
-		err = errors.New("otr: internal error: no key slots")
-		return
+		return nil, errors.New("otr: internal error: no more key slots")
 	}
 
 	var myPriv, myPub, theirPub *big.Int
@@ -1163,6 +1162,14 @@ func (c *Conversation) encode(msg []byte) [][]byte {
 	return ret
 }
 
+func (c *Conversation) reset() {
+	c.myKeyId = 0
+
+	for i := range c.keySlots {
+		c.keySlots[i].used = false
+	}
+}
+
 type PublicKey struct {
 	dsa.PublicKey
 }

+ 110 - 60
otr/otr_test.go

@@ -121,11 +121,12 @@ func TestSignVerify(t *testing.T) {
 	}
 }
 
-func TestConversation(t *testing.T) {
+func setupConversation(t *testing.T) (alice, bob *Conversation) {
 	alicePrivateKey, _ := hex.DecodeString(alicePrivateKeyHex)
 	bobPrivateKey, _ := hex.DecodeString(bobPrivateKeyHex)
 
-	var alice, bob Conversation
+	alice, bob = new(Conversation), new(Conversation)
+
 	alice.PrivateKey = new(PrivateKey)
 	bob.PrivateKey = new(PrivateKey)
 	alice.PrivateKey.Parse(alicePrivateKey)
@@ -133,12 +134,6 @@ func TestConversation(t *testing.T) {
 	alice.FragmentSize = 100
 	bob.FragmentSize = 100
 
-	var alicesMessage, bobsMessage [][]byte
-	var out []byte
-	var aliceChange, bobChange SecurityChange
-	var err error
-	alicesMessage = append(alicesMessage, []byte(QueryMessage))
-
 	if alice.IsEncrypted() {
 		t.Error("Alice believes that the conversation is secure before we've started")
 	}
@@ -146,6 +141,17 @@ func TestConversation(t *testing.T) {
 		t.Error("Bob believes that the conversation is secure before we've started")
 	}
 
+	performHandshake(t, alice, bob)
+	return alice, bob
+}
+
+func performHandshake(t *testing.T, alice, bob *Conversation) {
+	var alicesMessage, bobsMessage [][]byte
+	var out []byte
+	var aliceChange, bobChange SecurityChange
+	var err error
+	alicesMessage = append(alicesMessage, []byte(QueryMessage))
+
 	for round := 0; len(alicesMessage) > 0 || len(bobsMessage) > 0; round++ {
 		bobsMessage = nil
 		for i, msg := range alicesMessage {
@@ -193,80 +199,109 @@ func TestConversation(t *testing.T) {
 	if !bob.IsEncrypted() {
 		t.Error("Bob doesn't believe that the conversation is secure")
 	}
+}
+
+const (
+	firstRoundTrip = iota
+	subsequentRoundTrip
+	noMACKeyCheck
+)
+
+func roundTrip(t *testing.T, alice, bob *Conversation, message []byte, macKeyCheck int) {
+	alicesMessage, err := alice.Send(message)
+	if err != nil {
+		t.Errorf("Error from Alice sending message: %s", err)
+	}
 
-	var testMessages = [][]byte{
-		[]byte("hello"), []byte("bye"),
+	if len(alice.oldMACs) != 0 {
+		t.Errorf("Alice has not revealed all MAC keys")
 	}
 
-	for j, testMessage := range testMessages {
-		alicesMessage, err = alice.Send(testMessage)
+	for i, msg := range alicesMessage {
+		out, encrypted, _, _, err := bob.Receive(msg)
 
-		if len(alice.oldMACs) != 0 {
-			t.Errorf("Alice has not revealed all MAC keys")
+		if err != nil {
+			t.Errorf("Error generated while processing test message: %s", err.Error())
 		}
-
-		for i, msg := range alicesMessage {
-			out, encrypted, _, _, err := bob.Receive(msg)
-
-			if err != nil {
-				t.Errorf("Error generated while processing test message: %s", err.Error())
+		if len(out) > 0 {
+			if i != len(alicesMessage)-1 {
+				t.Fatal("Bob produced a message while processing a fragment of Alice's")
 			}
-			if len(out) > 0 {
-				if i != len(alicesMessage)-1 {
-					t.Fatal("Bob produced a message while processing a fragment of Alice's")
-				}
-				if !encrypted {
-					t.Errorf("Message was not marked as encrypted")
-				}
-				if !bytes.Equal(out, testMessage) {
-					t.Errorf("Message corrupted: got %x, want %x", out, testMessage)
-				}
+			if !encrypted {
+				t.Errorf("Message was not marked as encrypted")
 			}
-		}
-
-		if j == 0 {
-			if len(bob.oldMACs) != 0 {
-				t.Errorf("Bob should not have MAC keys to reveal")
+			if !bytes.Equal(out, message) {
+				t.Errorf("Message corrupted: got %x, want %x", out, message)
 			}
-		} else if len(bob.oldMACs) != 40 {
-			t.Errorf("Bob does not have MAC keys to reveal")
 		}
+	}
 
-		bobsMessage, err = bob.Send(testMessage)
-
+	switch macKeyCheck {
+	case firstRoundTrip:
 		if len(bob.oldMACs) != 0 {
-			t.Errorf("Bob has not revealed all MAC keys")
+			t.Errorf("Bob should not have MAC keys to reveal")
 		}
+	case subsequentRoundTrip:
+		if len(bob.oldMACs) != 40 {
+			t.Errorf("Bob has %d bytes of MAC keys to reveal, but should have 40", len(bob.oldMACs))
+		}
+	}
 
-		for i, msg := range bobsMessage {
-			out, encrypted, _, _, err := alice.Receive(msg)
+	bobsMessage, err := bob.Send(message)
+	if err != nil {
+		t.Errorf("Error from Bob sending message: %s", err)
+	}
 
-			if err != nil {
-				t.Errorf("Error generated while processing test message: %s", err.Error())
+	if len(bob.oldMACs) != 0 {
+		t.Errorf("Bob has not revealed all MAC keys")
+	}
+
+	for i, msg := range bobsMessage {
+		out, encrypted, _, _, err := alice.Receive(msg)
+
+		if err != nil {
+			t.Errorf("Error generated while processing test message: %s", err.Error())
+		}
+		if len(out) > 0 {
+			if i != len(bobsMessage)-1 {
+				t.Fatal("Alice produced a message while processing a fragment of Bob's")
 			}
-			if len(out) > 0 {
-				if i != len(bobsMessage)-1 {
-					t.Fatal("Alice produced a message while processing a fragment of Bob's")
-				}
-				if !encrypted {
-					t.Errorf("Message was not marked as encrypted")
-				}
-				if !bytes.Equal(out, testMessage) {
-					t.Errorf("Message corrupted: got %x, want %x", out, testMessage)
-				}
+			if !encrypted {
+				t.Errorf("Message was not marked as encrypted")
+			}
+			if !bytes.Equal(out, message) {
+				t.Errorf("Message corrupted: got %x, want %x", out, message)
 			}
 		}
+	}
 
-		if j == 0 {
-			if len(alice.oldMACs) != 20 {
-				t.Errorf("Alice does not have MAC keys to reveal")
-			}
-		} else if len(alice.oldMACs) != 40 {
-			t.Errorf("Alice does not have MAC keys to reveal")
+	switch macKeyCheck {
+	case firstRoundTrip:
+		if len(alice.oldMACs) != 20 {
+			t.Errorf("Alice has %d bytes of MAC keys to reveal, but should have 20", len(alice.oldMACs))
+		}
+	case subsequentRoundTrip:
+		if len(alice.oldMACs) != 40 {
+			t.Errorf("Alice has %d bytes of MAC keys to reveal, but should have 40", len(alice.oldMACs))
 		}
 	}
 }
 
+func TestConversation(t *testing.T) {
+	alice, bob := setupConversation(t)
+
+	var testMessages = [][]byte{
+		[]byte("hello"), []byte("bye"),
+	}
+
+	roundTripType := firstRoundTrip
+
+	for _, testMessage := range testMessages {
+		roundTrip(t, alice, bob, testMessage, roundTripType)
+		roundTripType = subsequentRoundTrip
+	}
+}
+
 func TestGoodSMP(t *testing.T) {
 	var alice, bob Conversation
 
@@ -348,6 +383,21 @@ func TestBadSMP(t *testing.T) {
 	}
 }
 
+func TestRehandshaking(t *testing.T) {
+	alice, bob := setupConversation(t)
+	roundTrip(t, alice, bob, []byte("test"), firstRoundTrip)
+	roundTrip(t, alice, bob, []byte("test 2"), subsequentRoundTrip)
+	roundTrip(t, alice, bob, []byte("test 3"), subsequentRoundTrip)
+	roundTrip(t, alice, bob, []byte("test 4"), subsequentRoundTrip)
+	roundTrip(t, alice, bob, []byte("test 5"), subsequentRoundTrip)
+	roundTrip(t, alice, bob, []byte("test 6"), subsequentRoundTrip)
+	roundTrip(t, alice, bob, []byte("test 7"), subsequentRoundTrip)
+	roundTrip(t, alice, bob, []byte("test 8"), subsequentRoundTrip)
+	performHandshake(t, alice, bob)
+	roundTrip(t, alice, bob, []byte("test"), noMACKeyCheck)
+	roundTrip(t, alice, bob, []byte("test 2"), noMACKeyCheck)
+}
+
 func TestAgainstLibOTR(t *testing.T) {
 	// This test requires otr.c.test to be built as /tmp/a.out.
 	// If enabled, this tests runs forever performing OTR handshakes in a