Ver Fonte

refactor of etype

Jonathan Turner há 9 anos atrás
pai
commit
4c6f172d5d

+ 67 - 14
crypto/aes/aes.go

@@ -2,10 +2,9 @@
 package aes
 
 import (
-	"bytes"
 	"crypto/aes"
 	"crypto/cipher"
-	"crypto/sha1"
+	"crypto/hmac"
 	"encoding/binary"
 	"encoding/hex"
 	"errors"
@@ -19,39 +18,93 @@ const (
 	s2kParamsZero = 4294967296
 )
 
-func stringToKey(secret, salt, s2kparams string, e etype.EType) ([]byte, error) {
+func s2kparamsToItertions(s2kparams string) (int, error) {
 	//process s2kparams string
 	//The parameter string is four octets indicating an unsigned
 	//number in big-endian order.  This is the number of iterations to be
 	//performed.  If the value is 00 00 00 00, the number of iterations to
 	//be performed is 4,294,967,296 (2**32).
-	var i int32
+	var i uint32
 	if len(s2kparams) != 8 {
-		return nil, errors.New("Invalid s2kparams length")
+		return s2kParamsZero, errors.New("Invalid s2kparams length")
 	}
 	b, err := hex.DecodeString(s2kparams)
 	if err != nil {
-		return nil, errors.New("Invalid s2kparams, cannot decode string to bytes")
+		return s2kParamsZero, errors.New("Invalid s2kparams, cannot decode string to bytes")
 	}
-	buf := bytes.NewBuffer(b)
-	err = binary.Read(buf, binary.BigEndian, &i)
+	i = binary.BigEndian.Uint32(b)
+	//buf := bytes.NewBuffer(b)
+	//err = binary.Read(buf, binary.BigEndian, &i)
 	if err != nil {
-		return nil, errors.New("Invalid s2kparams, cannot convert to big endian int32")
+		return s2kParamsZero, errors.New("Invalid s2kparams, cannot convert to big endian int32")
 	}
-	if i == 0 {
-		return stringToKeyIter(secret, salt, s2kParamsZero, e)
+	return int(i), nil
+}
+
+func IterationsToS2kparams(i int) string {
+	b := make([]byte, 4, 4)
+	binary.BigEndian.PutUint32(b, uint32(i))
+	return hex.EncodeToString(b)
+}
+
+func stringToKey(secret, salt, s2kparams string, e etype.EType) ([]byte, error) {
+	i, err := s2kparamsToItertions(s2kparams)
+	if err != nil {
+		return nil, err
 	}
 	return stringToKeyIter(secret, salt, int(i), e)
 }
 
+func stringToKeySHA2(secret, salt, s2kparams string, e etype.EType) ([]byte, error) {
+	i, err := s2kparamsToItertions(s2kparams)
+	if err != nil {
+		return nil, err
+	}
+	return stringToKeySHA2Iter(secret, salt, int(i), e), nil
+}
+
 func stringToPBKDF2(secret, salt string, iterations int, e etype.EType) []byte {
-	return pbkdf2.Key([]byte(secret), []byte(salt), iterations, e.GetKeyByteSize(), sha1.New)
+	return pbkdf2.Key([]byte(secret), []byte(salt), iterations, e.GetKeyByteSize(), e.GetHash())
 }
 
 func stringToKeyIter(secret, salt string, iterations int, e etype.EType) ([]byte, error) {
 	tkey := randomToKey(stringToPBKDF2(secret, salt, iterations, e))
-	key, err := deriveKey(tkey, []byte("kerberos"), e)
-	return key, err
+	return deriveKey(tkey, []byte("kerberos"), e)
+}
+
+func stringToKeySHA2Iter(secret, salt string, iterations int, e etype.EType) []byte {
+	tkey := randomToKey(stringToPBKDF2(secret, salt, iterations, e))
+	return deriveKeyKDF_HMAC_SHA2(tkey, []byte("kerberos"), e)
+}
+
+//https://tools.ietf.org/html/rfc8009#section-3
+func KDF_HMAC_SHA2(protocolKey, label, context []byte, kl int, e etype.EType) []byte {
+	//k: Length in bits of the key to be outputted, expressed in big-endian binary representation in 4 bytes.
+	k := make([]byte, 4, 4)
+	binary.BigEndian.PutUint32(k, uint32(kl))
+
+	c := make([]byte, 4, 4)
+	binary.BigEndian.PutUint32(c, uint32(1))
+	c = append(c, label...)
+	c = append(c, byte(uint8(0)))
+	if len(context) > 0 {
+		c = append(c, context...)
+	}
+	c = append(c, k...)
+
+	mac := hmac.New(e.GetHash(), protocolKey)
+	mac.Write(c)
+	return mac.Sum(nil)[:(kl / 8)]
+}
+
+func deriveKeyKDF_HMAC_SHA2(protocolKey, label []byte, e etype.EType) []byte {
+	var context []byte
+	return KDF_HMAC_SHA2(protocolKey, label, context, e.GetKeySeedBitLength(), e)
+}
+
+func deriveRandomKDF_HMAC_SHA2(protocolKey, usage []byte, e etype.EType) ([]byte, error) {
+	h := e.GetHash()()
+	return KDF_HMAC_SHA2(protocolKey, []byte("prf"), usage, h.Size(), e), nil
 }
 
 func randomToKey(b []byte) []byte {

+ 63 - 8
crypto/aes/aes128-cts-hmac-sha1-96.go

@@ -2,7 +2,10 @@ package aes
 
 import (
 	"crypto/aes"
+	"crypto/rand"
 	"crypto/sha1"
+	"errors"
+	"fmt"
 	"github.com/jcmturner/gokrb5/crypto/engine"
 	"github.com/jcmturner/gokrb5/iana/chksumtype"
 	"github.com/jcmturner/gokrb5/iana/etypeID"
@@ -74,8 +77,8 @@ func (e Aes128CtsHmacSha96) GetKeySeedBitLength() int {
 	return e.GetKeyByteSize() * 8
 }
 
-func (e Aes128CtsHmacSha96) GetHash() hash.Hash {
-	return sha1.New()
+func (e Aes128CtsHmacSha96) GetHash() func() hash.Hash {
+	return sha1.New
 }
 
 func (e Aes128CtsHmacSha96) GetMessageBlockByteSize() int {
@@ -106,13 +109,65 @@ func (e Aes128CtsHmacSha96) RandomToKey(b []byte) []byte {
 	return randomToKey(b)
 }
 
-func (e Aes128CtsHmacSha96) Encrypt(key, message []byte) ([]byte, []byte, error) {
+func (e Aes128CtsHmacSha96) EncryptData(key, data []byte) ([]byte, []byte, error) {
 	ivz := make([]byte, aes.BlockSize)
-	return encryptCTS(key, ivz, message, e)
-}
-
-func (e Aes128CtsHmacSha96) Decrypt(key, ciphertext []byte) ([]byte, error) {
-	return decryptCTS(key, ciphertext, e)
+	return encryptCTS(key, ivz, data, e)
+}
+
+func (e Aes128CtsHmacSha96) EncryptMessage(key, message []byte, usage uint32) ([]byte, []byte, error) {
+	//confounder
+	c := make([]byte, e.GetConfounderByteSize())
+	_, err := rand.Read(c)
+	if err != nil {
+		return []byte{}, []byte{}, fmt.Errorf("Could not generate random confounder: %v", err)
+	}
+	plainBytes := append(c, message...)
+
+	// Derive key for encryption from usage
+	var k []byte
+	if usage != 0 {
+		k, err = e.DeriveKey(key, engine.GetUsageKe(usage))
+		if err != nil {
+			return []byte{}, []byte{}, fmt.Errorf("Error deriving key for encryption: %v", err)
+		}
+	}
+
+	// Encrypt the data
+	iv, b, err := e.EncryptData(k, plainBytes)
+	if err != nil {
+		return iv, b, fmt.Errorf("Error encrypting data: %v", err)
+	}
+
+	// Generate and append integrity hash
+	ih, err := engine.GetIntegrityHash(plainBytes, key, usage, e)
+	if err != nil {
+		return iv, b, fmt.Errorf("Error encrypting data: %v", err)
+	}
+	b = append(b, ih...)
+	return iv, b, nil
+}
+
+func (e Aes128CtsHmacSha96) DecryptData(key, data []byte) ([]byte, error) {
+	return decryptCTS(key, data, e)
+}
+
+func (e Aes128CtsHmacSha96) DecryptMessage(key, ciphertext []byte, usage uint32) ([]byte, error) {
+	//Derive the key
+	k, err := e.DeriveKey(key, engine.GetUsageKe(usage))
+	if err != nil {
+		return nil, fmt.Errorf("Error deriving key: %v", err)
+	}
+	// Strip off the checksum from the end
+	b, err := e.DecryptData(k, ciphertext[:len(ciphertext)-e.GetHMACBitLength()/8])
+	if err != nil {
+		return nil, err
+	}
+	//Verify checksum
+	if !e.VerifyIntegrity(key, ciphertext, b, usage) {
+		return nil, errors.New("Integrity verification failed")
+	}
+	//Remove the confounder bytes
+	return b[e.GetConfounderByteSize():], nil
 }
 
 func (e Aes128CtsHmacSha96) DeriveKey(protocolKey, usage []byte) ([]byte, error) {

+ 94 - 0
crypto/aes/aes128-cts-hmac-sha256-128.go

@@ -0,0 +1,94 @@
+// +build disabled
+
+package aes
+
+import (
+	"crypto/aes"
+	"crypto/sha256"
+	"github.com/jcmturner/gokrb5/crypto/engine"
+	"github.com/jcmturner/gokrb5/iana/chksumtype"
+	"github.com/jcmturner/gokrb5/iana/etypeID"
+	"hash"
+)
+
+// RFC https://tools.ietf.org/html/rfc8009
+
+type Aes128CtsHmacSha256128 struct {
+}
+
+func (e Aes128CtsHmacSha256128) GetETypeID() int {
+	return etypeID.AES128_CTS_HMAC_SHA256_128
+}
+
+func (e Aes128CtsHmacSha256128) GetHashID() int {
+	return chksumtype.HMAC_SHA256_128_AES128
+}
+
+func (e Aes128CtsHmacSha256128) GetKeyByteSize() int {
+	return 128 / 8
+}
+
+func (e Aes128CtsHmacSha256128) GetKeySeedBitLength() int {
+	return e.GetKeyByteSize() * 8
+}
+
+func (e Aes128CtsHmacSha256128) GetHash() func() hash.Hash {
+	return sha256.New
+}
+
+func (e Aes128CtsHmacSha256128) GetMessageBlockByteSize() int {
+	return 1
+}
+
+func (e Aes128CtsHmacSha256128) GetDefaultStringToKeyParams() string {
+	return "00008000"
+}
+
+func (e Aes128CtsHmacSha256128) GetConfounderByteSize() int {
+	return aes.BlockSize
+}
+
+func (e Aes128CtsHmacSha256128) GetHMACBitLength() int {
+	return 128
+}
+
+func (e Aes128CtsHmacSha256128) GetCypherBlockBitLength() int {
+	return aes.BlockSize * 8
+}
+
+func (e Aes128CtsHmacSha256128) StringToKey(secret string, salt string, s2kparams string) ([]byte, error) {
+	saltp := e.getSaltP(salt)
+	return stringToKeySHA2(secret, saltp, s2kparams, e)
+}
+
+func (e Aes128CtsHmacSha256128) getSaltP(salt string) string {
+	b := []byte("aes128-cts-hmac-sha256-128")
+	b = append(b, byte(uint8(0)))
+	b = append(b, []byte(salt)...)
+	return string(b)
+}
+
+func (e Aes128CtsHmacSha256128) RandomToKey(b []byte) []byte {
+	return randomToKey(b)
+}
+
+func (e Aes128CtsHmacSha256128) EncryptData(key, message []byte) ([]byte, []byte, error) {
+	ivz := make([]byte, aes.BlockSize)
+	return encryptCTS(key, ivz, message, e)
+}
+
+func (e Aes128CtsHmacSha256128) DecryptData(key, ciphertext []byte) ([]byte, error) {
+	return decryptCTS(key, ciphertext, e)
+}
+
+func (e Aes128CtsHmacSha256128) DeriveKey(protocolKey, usage []byte) ([]byte, error) {
+	return deriveKeyKDF_HMAC_SHA2(protocolKey, usage, e), nil
+}
+
+func (e Aes128CtsHmacSha256128) DeriveRandom(protocolKey, usage []byte) ([]byte, error) {
+	return deriveRandomKDF_HMAC_SHA2(protocolKey, usage, e)
+}
+
+func (e Aes128CtsHmacSha256128) VerifyIntegrity(protocolKey, ct, pt []byte, usage uint32) bool {
+	return engine.VerifyIntegrity(protocolKey, ct, pt, usage, e)
+}

+ 140 - 0
crypto/aes/aes128-cts-hmac-sha256-128_test.go

@@ -0,0 +1,140 @@
+// +build disabled
+
+package aes
+
+import (
+	"encoding/hex"
+	"github.com/jcmturner/gokrb5/crypto/engine"
+	"github.com/stretchr/testify/assert"
+	"testing"
+)
+
+//func TestAesCtsHmacSha196_Encrypt_Decrypt(t *testing.T) {
+//	iv := make([]byte, 16)
+//	key, _ := hex.DecodeString("636869636b656e207465726979616b69")
+//	var tests = []struct {
+//		plain  string
+//		cipher string
+//		nextIV string
+//	}{
+//		//Test vectors from RFC 3962 Appendix B
+//		{"4920776f756c64206c696b652074686520", "c6353568f2bf8cb4d8a580362da7ff7f97", "c6353568f2bf8cb4d8a580362da7ff7f"},
+//		{"4920776f756c64206c696b65207468652047656e6572616c20476175277320", "fc00783e0efdb2c1d445d4c8eff7ed2297687268d6ecccc0c07b25e25ecfe5", "fc00783e0efdb2c1d445d4c8eff7ed22"},
+//		{"4920776f756c64206c696b65207468652047656e6572616c2047617527732043", "39312523a78662d5be7fcbcc98ebf5a897687268d6ecccc0c07b25e25ecfe584", "39312523a78662d5be7fcbcc98ebf5a8"},
+//		{"4920776f756c64206c696b65207468652047656e6572616c20476175277320436869636b656e2c20706c656173652c", "97687268d6ecccc0c07b25e25ecfe584b3fffd940c16a18c1b5549d2f838029e39312523a78662d5be7fcbcc98ebf5", "b3fffd940c16a18c1b5549d2f838029e"},
+//		{"4920776f756c64206c696b65207468652047656e6572616c20476175277320436869636b656e2c20706c656173652c20", "97687268d6ecccc0c07b25e25ecfe5849dad8bbb96c4cdc03bc103e1a194bbd839312523a78662d5be7fcbcc98ebf5a8", "9dad8bbb96c4cdc03bc103e1a194bbd8"},
+//		{"4920776f756c64206c696b65207468652047656e6572616c20476175277320436869636b656e2c20706c656173652c20616e6420776f6e746f6e20736f75702e", "97687268d6ecccc0c07b25e25ecfe58439312523a78662d5be7fcbcc98ebf5a84807efe836ee89a526730dbc2f7bc8409dad8bbb96c4cdc03bc103e1a194bbd8", "4807efe836ee89a526730dbc2f7bc840"},
+//	}
+//	var e Aes128CtsHmacSha96
+//	for i, test := range tests {
+//		m, _ := hex.DecodeString(test.plain)
+//		niv, c, err := encryptCTS(key, iv, m, e)
+//		if err != nil {
+//			t.Errorf("Encryption failed for test %v: %v", i+1, err)
+//		}
+//		assert.Equal(t, test.cipher, hex.EncodeToString(c), "Encrypted result not as expected")
+//		assert.Equal(t, test.nextIV, hex.EncodeToString(niv), "Next state IV not as expected")
+//	}
+//	//t.Log("AES CTS Encryption tests finished")
+//	for i, test := range tests {
+//		b, _ := hex.DecodeString(test.cipher)
+//		p, err := decryptCTS(key, b, e)
+//		if err != nil {
+//			t.Errorf("Decryption failed for test %v: %v", i+1, err)
+//		}
+//		assert.Equal(t, test.plain, hex.EncodeToString(p), "Decrypted result not as expected")
+//	}
+//	//t.Log("AES CTS Decryption tests finished")
+//}
+
+func TestAes128CtsHmacSha256128_StringToKey(t *testing.T) {
+	// Test vectors from RFC 8009 Appendix A
+	// Random 16bytes in test vector as string
+	r, _ := hex.DecodeString("10DF9DD783E5BC8ACEA1730E74355F61")
+	s := string(r)
+	var tests = []struct {
+		iterations int
+		phrase     string
+		salt       string
+		saltp      string
+		key        string
+	}{
+		{32768, "password", s + "ATHENA.MIT.EDUraeburn", "6165733132382d6374732d686d61632d7368613235362d3132380010df9dd783e5bc8acea1730e74355f61415448454e412e4d49542e4544557261656275726e", "089bca48b105ea6ea77ca5d2f39dc5e7"},
+	}
+	var e Aes128CtsHmacSha256128
+	for _, test := range tests {
+		saltp := e.getSaltP(test.salt)
+		assert.Equal(t, test.saltp, hex.EncodeToString(([]byte(saltp))), "SaltP not as expected")
+
+		k, _ := e.StringToKey(test.phrase, test.salt, IterationsToS2kparams(test.iterations))
+		assert.Equal(t, test.key, hex.EncodeToString(k), "String to Key not as expected")
+
+	}
+}
+
+func TestAes128CtsHmacSha256128_DeriveKey(t *testing.T) {
+	// Test vectors from RFC 8009 Appendix A
+	protocolBaseKey, _ := hex.DecodeString("3705d96080c17728a0e800eab6e0d23c")
+	testUsage := uint32(2)
+	var e Aes128CtsHmacSha256128
+	k, err := e.DeriveKey(protocolBaseKey, engine.GetUsageKc(testUsage))
+	if err != nil {
+		t.Fatalf("Error deriving checksum key: %v", err)
+	}
+	assert.Equal(t, "b31a018a48f54776f403e9a396325dc3", hex.EncodeToString(k), "Checksum derived key not as epxected")
+	k, err = e.DeriveKey(protocolBaseKey, engine.GetUsageKe(testUsage))
+	if err != nil {
+		t.Fatalf("Error deriving encryption key: %v", err)
+	}
+	assert.Equal(t, "9b197dd1e8c5609d6e67c3e37c62c72e", hex.EncodeToString(k), "Encryption derived key not as epxected")
+	k, err = e.DeriveKey(protocolBaseKey, engine.GetUsageKi(testUsage))
+	if err != nil {
+		t.Fatalf("Error deriving integrity key: %v", err)
+	}
+	assert.Equal(t, "9fda0e56ab2d85e1569a688696c26a6c", hex.EncodeToString(k), "Integrity derived key not as epxected")
+}
+
+func TestAes128CtsHmacSha256128_Cypto(t *testing.T) {
+	protocolBaseKey, _ := hex.DecodeString("3705d96080c17728a0e800eab6e0d23c")
+	testUsage := uint32(2)
+	var tests = []struct {
+		plain      string
+		confounder string
+		ke         string
+		ki         string
+		encrypted  string // AESOutput
+		hash       string // TruncatedHMACOutput
+		cipher     string // Ciphertext(AESOutput|HMACOutput)
+	}{
+		// Test vectors from RFC 8009 Appendix A
+		{"", "7e5895eaf2672435bad817f545a37148", "9b197dd1e8c5609d6e67c3e37c62c72e", "9fda0e56ab2d85e1569a688696c26a6c", "ef85fb890bb8472f4dab20394dca781d", "ad877eda39d50c870c0d5a0a8e48c718", "ef85fb890bb8472f4dab20394dca781dad877eda39d50c870c0d5a0a8e48c718"},
+		{"000102030405", "7bca285e2fd4130fb55b1a5c83bc5b24", "9b197dd1e8c5609d6e67c3e37c62c72e", "9fda0e56ab2d85e1569a688696c26a6c", "84d7f30754ed987bab0bf3506beb09cfb55402cef7e6", "877ce99e247e52d16ed4421dfdf8976c", "84d7f30754ed987bab0bf3506beb09cfb55402cef7e6877ce99e247e52d16ed4421dfdf8976c"},
+	}
+	var e Aes128CtsHmacSha256128
+	for i, test := range tests {
+		m, _ := hex.DecodeString(test.plain)
+		b, _ := hex.DecodeString(test.encrypted)
+		ke, _ := hex.DecodeString(test.ke)
+		cf, _ := hex.DecodeString(test.confounder)
+		cfm := append(cf, m...)
+
+		_, c, err := e.Encrypt(ke, cfm)
+		if err != nil {
+			t.Errorf("Encryption failed for test %v: %v", i+1, err)
+		}
+		assert.Equal(t, test.encrypted, hex.EncodeToString(c), "Encrypted result not as expected - test %v", i)
+
+		ivz := make([]byte, e.GetConfounderByteSize())
+		hm := append(ivz, b...)
+		mac, _ := engine.GetIntegrityHash(hm, protocolBaseKey, testUsage, e)
+		assert.Equal(t, test.hash, hex.EncodeToString(mac), "HMAC result not as expected - test %v", i)
+
+		p, err := e.Decrypt(ke, b)
+		//Remove the confounder bytes
+		p = p[e.GetConfounderByteSize():]
+		if err != nil {
+			t.Errorf("Decryption failed for test %v: %v", i+1, err)
+		}
+		assert.Equal(t, test.plain, hex.EncodeToString(p), "Decrypted result not as expected - test %v", i)
+	}
+}

+ 63 - 8
crypto/aes/aes256-cts-hmac-sha1-96.go

@@ -2,7 +2,10 @@ package aes
 
 import (
 	"crypto/aes"
+	"crypto/rand"
 	"crypto/sha1"
+	"errors"
+	"fmt"
 	"github.com/jcmturner/gokrb5/crypto/engine"
 	"github.com/jcmturner/gokrb5/iana/chksumtype"
 	"github.com/jcmturner/gokrb5/iana/etypeID"
@@ -74,8 +77,8 @@ func (e Aes256CtsHmacSha96) GetKeySeedBitLength() int {
 	return e.GetKeyByteSize() * 8
 }
 
-func (e Aes256CtsHmacSha96) GetHash() hash.Hash {
-	return sha1.New()
+func (e Aes256CtsHmacSha96) GetHash() func() hash.Hash {
+	return sha1.New
 }
 
 func (e Aes256CtsHmacSha96) GetMessageBlockByteSize() int {
@@ -106,13 +109,65 @@ func (e Aes256CtsHmacSha96) RandomToKey(b []byte) []byte {
 	return randomToKey(b)
 }
 
-func (e Aes256CtsHmacSha96) Encrypt(key, message []byte) ([]byte, []byte, error) {
+func (e Aes256CtsHmacSha96) EncryptData(key, data []byte) ([]byte, []byte, error) {
 	ivz := make([]byte, aes.BlockSize)
-	return encryptCTS(key, ivz, message, e)
-}
-
-func (e Aes256CtsHmacSha96) Decrypt(key, ciphertext []byte) ([]byte, error) {
-	return decryptCTS(key, ciphertext, e)
+	return encryptCTS(key, ivz, data, e)
+}
+
+func (e Aes256CtsHmacSha96) EncryptMessage(key, message []byte, usage uint32) ([]byte, []byte, error) {
+	//confounder
+	c := make([]byte, e.GetConfounderByteSize())
+	_, err := rand.Read(c)
+	if err != nil {
+		return []byte{}, []byte{}, fmt.Errorf("Could not generate random confounder: %v", err)
+	}
+	plainBytes := append(c, message...)
+
+	// Derive key for encryption from usage
+	var k []byte
+	if usage != 0 {
+		k, err = e.DeriveKey(key, engine.GetUsageKe(usage))
+		if err != nil {
+			return []byte{}, []byte{}, fmt.Errorf("Error deriving key for encryption: %v", err)
+		}
+	}
+
+	// Encrypt the data
+	iv, b, err := e.EncryptData(k, plainBytes)
+	if err != nil {
+		return iv, b, fmt.Errorf("Error encrypting data: %v", err)
+	}
+
+	// Generate and append integrity hash
+	ih, err := engine.GetIntegrityHash(plainBytes, key, usage, e)
+	if err != nil {
+		return iv, b, fmt.Errorf("Error encrypting data: %v", err)
+	}
+	b = append(b, ih...)
+	return iv, b, nil
+}
+
+func (e Aes256CtsHmacSha96) DecryptData(key, data []byte) ([]byte, error) {
+	return decryptCTS(key, data, e)
+}
+
+func (e Aes256CtsHmacSha96) DecryptMessage(key, ciphertext []byte, usage uint32) ([]byte, error) {
+	//Derive the key
+	k, err := e.DeriveKey(key, engine.GetUsageKe(usage))
+	if err != nil {
+		return nil, fmt.Errorf("Error deriving key: %v", err)
+	}
+	// Strip off the checksum from the end
+	b, err := e.DecryptData(k, ciphertext[:len(ciphertext)-e.GetHMACBitLength()/8])
+	if err != nil {
+		return nil, err
+	}
+	//Verify checksum
+	if !e.VerifyIntegrity(key, ciphertext, b, usage) {
+		return nil, errors.New("Integrity verification failed")
+	}
+	//Remove the confounder bytes
+	return b[e.GetConfounderByteSize():], nil
 }
 
 func (e Aes256CtsHmacSha96) DeriveKey(protocolKey, usage []byte) ([]byte, error) {

+ 9 - 56
crypto/crypto.go

@@ -2,12 +2,9 @@
 package crypto
 
 import (
-	"crypto/rand"
 	"encoding/hex"
-	"errors"
 	"fmt"
 	"github.com/jcmturner/gokrb5/crypto/aes"
-	"github.com/jcmturner/gokrb5/crypto/engine"
 	"github.com/jcmturner/gokrb5/crypto/etype"
 	"github.com/jcmturner/gokrb5/iana/chksumtype"
 	"github.com/jcmturner/gokrb5/iana/etypeID"
@@ -23,6 +20,9 @@ func GetEtype(id int) (etype.EType, error) {
 	case etypeID.AES256_CTS_HMAC_SHA1_96:
 		var et aes.Aes256CtsHmacSha96
 		return et, nil
+	//case etypeID.AES128_CTS_HMAC_SHA256_128:
+	//	var et aes.Aes128CtsHmacSha256128
+	//	return et, nil
 	default:
 		return nil, fmt.Errorf("Unknown or unsupported EType: %d", id)
 	}
@@ -115,26 +115,11 @@ func GetEncryptedData(plainBytes []byte, key types.EncryptionKey, usage uint32,
 	if err != nil {
 		return ed, fmt.Errorf("Error getting etype: %v", err)
 	}
-	k := key.KeyValue
-	if usage != 0 {
-		k, err = et.DeriveKey(key.KeyValue, engine.GetUsageKe(uint32(usage)))
-		if err != nil {
-			return ed, fmt.Errorf("Error deriving key: %v", err)
-		}
-	}
-	//confounder
-	c := make([]byte, et.GetConfounderByteSize())
-	_, err = rand.Read(c)
+	_, b, err := et.EncryptMessage(key.KeyValue, plainBytes, usage)
 	if err != nil {
-		return ed, fmt.Errorf("Could not generate random confounder: %v", err)
+		return ed, err
 	}
-	plainBytes = append(c, plainBytes...)
-	_, b, err := et.Encrypt(k, plainBytes)
-	if err != nil {
-		return ed, fmt.Errorf("Error encrypting data: %v", err)
-	}
-	ih, err := engine.GetIntegrityHash(plainBytes, key.KeyValue, usage, et)
-	b = append(b, ih...)
+
 	ed = types.EncryptedData{
 		EType:  key.KeyType,
 		Cipher: b,
@@ -144,46 +129,14 @@ func GetEncryptedData(plainBytes []byte, key types.EncryptionKey, usage uint32,
 }
 
 func DecryptEncPart(ed types.EncryptedData, key types.EncryptionKey, usage uint32) ([]byte, error) {
-	//Derive the key
-	et, err := GetEtype(key.KeyType)
-	k, err := et.DeriveKey(key.KeyValue, engine.GetUsageKe(usage))
-	if err != nil {
-		return nil, fmt.Errorf("Error deriving key: %v", err)
-	}
-	// Strip off the checksum from the end
-	b, err := et.Decrypt(k, ed.Cipher[:len(ed.Cipher)-et.GetHMACBitLength()/8])
-	if err != nil {
-		return nil, fmt.Errorf("Error decrypting: %v", err)
-	}
-	//Verify checksum
-	if !et.VerifyIntegrity(key.KeyValue, ed.Cipher, b, usage) {
-		return nil, errors.New("Error decrypting encrypted part: integrity verification failed")
-	}
-	//Remove the confounder bytes
-	b = b[et.GetConfounderByteSize():]
-	if err != nil {
-		return nil, fmt.Errorf("Error decrypting encrypted part: %v", err)
-	}
-	return b, nil
+	return DecryptMessage(ed.Cipher, key, usage)
 }
 
-func DecryptBytes(ed []byte, key types.EncryptionKey, usage uint32) ([]byte, error) {
-	//Derive the key
+func DecryptMessage(ciphertext []byte, key types.EncryptionKey, usage uint32) ([]byte, error) {
 	et, err := GetEtype(key.KeyType)
-	k, err := et.DeriveKey(key.KeyValue, engine.GetUsageKe(usage))
-	if err != nil {
-		return nil, fmt.Errorf("Error deriving key: %v", err)
-	}
-	// Strip off the checksum from the end
-	b, err := et.Decrypt(k, ed[:len(ed)-et.GetHMACBitLength()/8])
+	b, err := et.DecryptMessage(key.KeyValue, ciphertext, usage)
 	if err != nil {
 		return nil, fmt.Errorf("Error decrypting: %v", err)
 	}
-	//Verify checksum
-	if !et.VerifyIntegrity(key.KeyValue, ed, b, usage) {
-		return nil, errors.New("Error decrypting: integrity verification failed")
-	}
-	//Remove the confounder bytes
-	b = b[et.GetConfounderByteSize():]
 	return b, nil
 }

+ 56 - 32
crypto/des3/des3-cbc-sha1-kd.go

@@ -4,6 +4,7 @@ package des3
 import (
 	"crypto/cipher"
 	"crypto/des"
+	"crypto/rand"
 	"crypto/sha1"
 	"errors"
 	"fmt"
@@ -69,8 +70,8 @@ func (e Des3CbcSha1Kd) GetKeySeedBitLength() int {
 	return 21 * 8
 }
 
-func (e Des3CbcSha1Kd) GetHash() hash.Hash {
-	return sha1.New()
+func (e Des3CbcSha1Kd) GetHash() func() hash.Hash {
+	return sha1.New
 }
 
 func (e Des3CbcSha1Kd) GetMessageBlockByteSize() int {
@@ -88,7 +89,7 @@ func (e Des3CbcSha1Kd) GetConfounderByteSize() int {
 }
 
 func (e Des3CbcSha1Kd) GetHMACBitLength() int {
-	return e.GetHash().Size()
+	return e.GetHash()().Size()
 }
 
 func (e Des3CbcSha1Kd) GetCypherBlockBitLength() int {
@@ -118,58 +119,81 @@ func (e Des3CbcSha1Kd) DeriveKey(protocolKey, usage []byte) ([]byte, error) {
 	return e.RandomToKey(r), nil
 }
 
-func (e Des3CbcSha1Kd) Encrypt(key, message []byte) ([]byte, []byte, error) {
+func (e Des3CbcSha1Kd) EncryptData(key, data []byte) ([]byte, []byte, error) {
 	if len(key) != e.GetKeyByteSize() {
 		return nil, nil, fmt.Errorf("Incorrect keysize: expected: %v actual: %v", e.GetKeySeedBitLength(), len(key))
 
 	}
-	if len(message)%e.GetMessageBlockByteSize() != 0 {
-		message, _ = engine.PKCS7Pad(message, e.GetMessageBlockByteSize())
-	}
+	data, _ = engine.ZeroPad(data, e.GetMessageBlockByteSize())
 
 	block, err := des.NewTripleDESCipher(key)
 	if err != nil {
 		return nil, nil, fmt.Errorf("Error creating cipher: %v", err)
-
 	}
 
 	//RFC 3961: initial cipher state      All bits zero
-	iv := make([]byte, e.GetConfounderByteSize())
-	//_, err = rand.Read(iv) //Not needed as all bits need to be zero
+	ivz := make([]byte, e.GetConfounderByteSize())
 
-	ct := make([]byte, len(message))
-	mode := cipher.NewCBCEncrypter(block, iv)
-	mode.CryptBlocks(ct, message)
-	return ct[:e.GetConfounderByteSize()], ct, nil
+	ct := make([]byte, len(data))
+	mode := cipher.NewCBCEncrypter(block, ivz)
+	mode.CryptBlocks(ct, data)
+	return ivz, ct, nil
 }
 
-func (e Des3CbcSha1Kd) Decrypt(key, ciphertext []byte) (message []byte, err error) {
-	if len(key) != e.GetKeySeedBitLength() {
-		err = fmt.Errorf("Incorrect keysize: expected: %v actual: %v", e.GetKeySeedBitLength(), len(key))
-		return
+func (e Des3CbcSha1Kd) EncryptMessage(key, message []byte, usage uint32) ([]byte, []byte, error) {
+	//confounder
+	c := make([]byte, e.GetConfounderByteSize())
+	_, err := rand.Read(c)
+	if err != nil {
+		return []byte{}, []byte{}, fmt.Errorf("Could not generate random confounder: %v", err)
 	}
+	plainBytes := append(c, message...)
 
-	if len(ciphertext) < des.BlockSize || len(ciphertext)%des.BlockSize != 0 {
-		err = errors.New("Ciphertext is not a multiple of the block size.")
-		return
+	iv, b, err := e.EncryptData(key, plainBytes)
+	if err != nil {
+		return iv, b, fmt.Errorf("Error encrypting data: %v", err)
 	}
 
-	block, err := des.NewTripleDESCipher(key)
+	// Generate and append integrity hash
+	ih, err := engine.GetIntegrityHash(plainBytes, key, usage, e)
 	if err != nil {
-		err = fmt.Errorf("Error creating cipher: %v", err)
-		return
+		return iv, b, fmt.Errorf("Error encrypting data: %v", err)
+	}
+	b = append(b, ih...)
+	return iv, b, nil
+}
+
+func (e Des3CbcSha1Kd) DecryptData(key, data []byte) ([]byte, error) {
+	if len(key) != e.GetKeySeedBitLength() {
+		return []byte{}, fmt.Errorf("Incorrect keysize: expected: %v actual: %v", e.GetKeySeedBitLength(), len(key))
 	}
 
-	iv := ciphertext[:e.GetConfounderByteSize()]
-	ciphertext = ciphertext[e.GetConfounderByteSize():]
+	if len(data) < des.BlockSize || len(data)%des.BlockSize != 0 {
+		return []byte{}, errors.New("Ciphertext is not a multiple of the block size.")
+	}
+	block, err := des.NewTripleDESCipher(key)
+	if err != nil {
+		return []byte{}, fmt.Errorf("Error creating cipher: %v", err)
+	}
+	pt := make([]byte, len(data))
+	ivz := make([]byte, e.GetConfounderByteSize())
+	mode := cipher.NewCBCDecrypter(block, ivz)
+	mode.CryptBlocks(pt, data)
+	return pt, nil
+}
 
-	mode := cipher.NewCBCDecrypter(block, iv)
-	mode.CryptBlocks(message, ciphertext)
-	m, er := engine.PKCS7Unpad(message, e.GetMessageBlockByteSize())
-	if er == nil {
-		message = m
+func (e Des3CbcSha1Kd) DecryptMessage(key, ciphertext []byte, usage uint32) (message []byte, err error) {
+	// Strip off the checksum from the end
+	b, err := e.DecryptData(key, ciphertext[:len(ciphertext)-e.GetHMACBitLength()/8])
+	if err != nil {
+		return nil, fmt.Errorf("Error decrypting: %v", err)
 	}
-	return
+	//Verify checksum
+	if !e.VerifyIntegrity(key, ciphertext, b, usage) {
+		return nil, errors.New("Error decrypting: integrity verification failed")
+	}
+	//Remove the confounder bytes
+	return b[e.GetConfounderByteSize():], nil
 }
 
 func (e Des3CbcSha1Kd) VerifyIntegrity(protocolKey, ct, pt []byte, usage uint32) bool {

+ 3 - 3
crypto/engine/engine.go

@@ -36,12 +36,12 @@ func DeriveRandom(key, usage []byte, n, k int, e etype.EType) ([]byte, error) {
 	K4 = ...
 
 	DR(Key, Constant) = k-truncate(K1 | K2 | K3 | K4 ...)*/
-	_, K, err := e.Encrypt(key, nFoldUsage)
+	_, K, err := e.EncryptData(key, nFoldUsage)
 	if err != nil {
 		return out, err
 	}
 	for i := copy(out, K); i < len(out); {
-		_, K, _ = e.Encrypt(key, K)
+		_, K, _ = e.EncryptData(key, K)
 		i = i + copy(out[i:], K)
 	}
 	return out, nil
@@ -107,7 +107,7 @@ func getHash(pt, key []byte, usage []byte, etype etype.EType) ([]byte, error) {
 	if err != nil {
 		return nil, fmt.Errorf("Unable to derive key for checksum: %v", err)
 	}
-	mac := hmac.New(etype.GetHash, k)
+	mac := hmac.New(etype.GetHash(), k)
 	p := make([]byte, len(pt))
 	copy(p, pt)
 	mac.Write(p)

+ 9 - 7
crypto/etype/etype.go

@@ -14,12 +14,14 @@ type EType interface {
 	RandomToKey(b []byte) []byte                                // random-to-key (bitstring[K])->(protocol-key)
 	GetHMACBitLength() int                                      // HMAC output size, h
 	GetMessageBlockByteSize() int                               // message block size, m
-	Encrypt(key, message []byte) ([]byte, []byte, error)        // E function - encrypt (specific-key, state, octet string)->(state, octet string)
-	Decrypt(key, ciphertext []byte) ([]byte, error)             // D function
-	GetCypherBlockBitLength() int                               // cipher block size, c
-	GetConfounderByteSize() int                                 // This is the same as the cipher block size but in bytes.
-	DeriveKey(protocolKey, usage []byte) ([]byte, error)        // DK key-derivation (protocol-key, integer)->(specific-key)
-	DeriveRandom(protocolKey, usage []byte) ([]byte, error)     // DR pseudo-random (protocol-key, octet-string)->(octet-string)
+	EncryptData(key, data []byte) ([]byte, []byte, error)
+	EncryptMessage(key, message []byte, usage uint32) ([]byte, []byte, error) // E function - encrypt (specific-key, state, octet string)->(state, octet string)
+	DecryptData(key, data []byte) ([]byte, error)
+	DecryptMessage(key, ciphertext []byte, usage uint32) ([]byte, error) // D function
+	GetCypherBlockBitLength() int                                        // cipher block size, c
+	GetConfounderByteSize() int                                          // This is the same as the cipher block size but in bytes.
+	DeriveKey(protocolKey, usage []byte) ([]byte, error)                 // DK key-derivation (protocol-key, integer)->(specific-key)
+	DeriveRandom(protocolKey, usage []byte) ([]byte, error)              // DR pseudo-random (protocol-key, octet-string)->(octet-string)
 	VerifyIntegrity(protocolKey, ct, pt []byte, usage uint32) bool
-	GetHash() hash.Hash
+	GetHash() func() hash.Hash
 }

+ 1 - 1
pac/credentials_info.go

@@ -45,7 +45,7 @@ func (c *PAC_CredentialsInfo) DecryptEncPart(k types.EncryptionKey, e *binary.By
 	if k.KeyType != int(c.EType) {
 		return fmt.Errorf("Key provided is not the correct type. Type needed: %d, type provided: %d", c.EType, k.KeyType)
 	}
-	pt, err := crypto.DecryptBytes(c.PAC_CredentialData_Encrypted, k, keyusage.KERB_NON_KERB_SALT)
+	pt, err := crypto.DecryptMessage(c.PAC_CredentialData_Encrypted, k, keyusage.KERB_NON_KERB_SALT)
 	if err != nil {
 		return err
 	}

+ 25 - 10
service/APExchange_test.go

@@ -96,8 +96,11 @@ func TestValidateAPREQ_KRB_AP_ERR_BADMATCH(t *testing.T) {
 	if ok || err == nil {
 		t.Fatal("Validation of AP_REQ passed when it should not have")
 	}
-	assert.IsType(t, messages.KRBError{}, err, "Error is not a KRBError")
-	assert.Equal(t, errorcode.KRB_AP_ERR_BADMATCH, err.(messages.KRBError).ErrorCode, "Error code not as expected")
+	if _, ok := err.(messages.KRBError); ok {
+		assert.Equal(t, errorcode.KRB_AP_ERR_BADMATCH, err.(messages.KRBError).ErrorCode, "Error code not as expected")
+	} else {
+		t.Fatalf("Error is not a KRBError: %v", err)
+	}
 }
 
 func TestValidateAPREQ_LargeClockSkew(t *testing.T) {
@@ -138,8 +141,11 @@ func TestValidateAPREQ_LargeClockSkew(t *testing.T) {
 	if ok || err == nil {
 		t.Fatal("Validation of AP_REQ passed when it should not have")
 	}
-	assert.IsType(t, messages.KRBError{}, err, "Error is not a KRBError")
-	assert.Equal(t, errorcode.KRB_AP_ERR_SKEW, err.(messages.KRBError).ErrorCode, "Error code not as expected")
+	if _, ok := err.(messages.KRBError); ok {
+		assert.Equal(t, errorcode.KRB_AP_ERR_SKEW, err.(messages.KRBError).ErrorCode, "Error code not as expected")
+	} else {
+		t.Fatalf("Error is not a KRBError: %v", err)
+	}
 }
 
 func TestValidateAPREQ_Replay(t *testing.T) {
@@ -224,8 +230,11 @@ func TestValidateAPREQ_FutureTicket(t *testing.T) {
 	if ok || err == nil {
 		t.Fatal("Validation of AP_REQ passed when it should not have")
 	}
-	assert.IsType(t, messages.KRBError{}, err, "Error is not a KRBError")
-	assert.Equal(t, errorcode.KRB_AP_ERR_TKT_NYV, err.(messages.KRBError).ErrorCode, "Error code not as expected")
+	if _, ok := err.(messages.KRBError); ok {
+		assert.Equal(t, errorcode.KRB_AP_ERR_TKT_NYV, err.(messages.KRBError).ErrorCode, "Error code not as expected")
+	} else {
+		t.Fatalf("Error is not a KRBError: %v", err)
+	}
 }
 
 func TestValidateAPREQ_InvalidTicket(t *testing.T) {
@@ -266,8 +275,11 @@ func TestValidateAPREQ_InvalidTicket(t *testing.T) {
 	if ok || err == nil {
 		t.Fatal("Validation of AP_REQ passed when it should not have")
 	}
-	assert.IsType(t, messages.KRBError{}, err, "Error is not a KRBError")
-	assert.Equal(t, errorcode.KRB_AP_ERR_TKT_NYV, err.(messages.KRBError).ErrorCode, "Error code not as expected")
+	if _, ok := err.(messages.KRBError); ok {
+		assert.Equal(t, errorcode.KRB_AP_ERR_TKT_NYV, err.(messages.KRBError).ErrorCode, "Error code not as expected")
+	} else {
+		t.Fatalf("Error is not a KRBError: %v", err)
+	}
 }
 
 func TestValidateAPREQ_ExpiredTicket(t *testing.T) {
@@ -307,8 +319,11 @@ func TestValidateAPREQ_ExpiredTicket(t *testing.T) {
 	if ok || err == nil {
 		t.Fatal("Validation of AP_REQ passed when it should not have")
 	}
-	assert.IsType(t, messages.KRBError{}, err, "Error is not a KRBError")
-	assert.Equal(t, errorcode.KRB_AP_ERR_TKT_EXPIRED, err.(messages.KRBError).ErrorCode, "Error code not as expected")
+	if _, ok := err.(messages.KRBError); ok {
+		assert.Equal(t, errorcode.KRB_AP_ERR_TKT_EXPIRED, err.(messages.KRBError).ErrorCode, "Error code not as expected")
+	} else {
+		t.Fatalf("Error is not a KRBError: %v", err)
+	}
 }
 
 func newTestAuthenticator(creds credentials.Credentials) types.Authenticator {