소스 검색

refactor of crypto package

Jonathan Turner 9 년 전
부모
커밋
a1eb49d00b

+ 3 - 4
config/krb5conf.go

@@ -7,7 +7,7 @@ import (
 	"errors"
 	"fmt"
 	"github.com/jcmturner/asn1"
-	"github.com/jcmturner/gokrb5/iana/etype"
+	"github.com/jcmturner/gokrb5/iana/etypeID"
 	"io"
 	"os"
 	"os/user"
@@ -17,7 +17,6 @@ import (
 	"time"
 )
 
-
 // Struct representing the KRB5 configuration.
 type Config struct {
 	LibDefaults *LibDefaults
@@ -409,7 +408,7 @@ func (d *DomainRealm) deleteMapping(domain, realm string) {
 func (c *Config) ResolveRealm(domainName string) string {
 	domainName = strings.TrimSuffix(domainName, ".")
 	periods := strings.Count(domainName, ".") + 1
-	for i := 1; i <= periods; i +=1 {
+	for i := 1; i <= periods; i += 1 {
 		z := strings.SplitN(domainName, ".", i)
 		if r, ok := c.DomainRealm[z[len(z)-1]]; ok {
 			return r
@@ -521,7 +520,7 @@ func parseETypes(s []string, w bool) []int {
 				continue
 			}
 		}
-		i := etype.ETypesByName[et]
+		i := etypeID.ETypesByName[et]
 		if i != 0 {
 			eti = append(eti, i)
 		}

+ 0 - 344
crypto/EncryptionEngine.go

@@ -1,344 +0,0 @@
-package crypto
-
-import (
-	"bytes"
-	"crypto/hmac"
-	"crypto/rand"
-	"encoding/binary"
-	"encoding/hex"
-	"errors"
-	"fmt"
-	"github.com/jcmturner/gokrb5/iana/chksumtype"
-	"github.com/jcmturner/gokrb5/iana/etype"
-	"github.com/jcmturner/gokrb5/iana/patype"
-	"github.com/jcmturner/gokrb5/types"
-	"hash"
-)
-
-type EType interface {
-	GetETypeID() int
-	GetHashID() int
-	GetKeyByteSize() int                                        // See "protocol key format" for defined values
-	GetKeySeedBitLength() int                                   // key-generation seed length, k
-	GetDefaultStringToKeyParams() string                        // default string-to-key parameters (s2kparams)
-	StringToKey(string, salt, s2kparams string) ([]byte, error) // string-to-key (UTF-8 string, UTF-8 string, opaque)->(protocol-key)
-	RandomToKey(b []byte) []byte                                // random-to-key (bitstring[K])->(protocol-key)
-	GetHMACBitLength() int                                      // HMAC output size, h
-	GetMessageBlockByteSize() int                               // message block size, m
-	Encrypt(key, message []byte) ([]byte, []byte, error)        // E function - encrypt (specific-key, state, octet string)->(state, octet string)
-	Decrypt(key, ciphertext []byte) ([]byte, error)             // D function
-	GetCypherBlockBitLength() int                               // cipher block size, c
-	GetConfounderByteSize() int                                 // This is the same as the cipher block size but in bytes.
-	DeriveKey(protocolKey, usage []byte) ([]byte, error)        // DK key-derivation (protocol-key, integer)->(specific-key)
-	DeriveRandom(protocolKey, usage []byte) ([]byte, error)     // DR pseudo-random (protocol-key, octet-string)->(octet-string)
-	VerifyIntegrity(protocolKey, ct, pt []byte, usage uint32) bool
-	GetHash() hash.Hash
-}
-
-func GetEtype(id int) (EType, error) {
-	switch id {
-	case etype.AES128_CTS_HMAC_SHA1_96:
-		var et Aes128CtsHmacSha96
-		return et, nil
-	case etype.AES256_CTS_HMAC_SHA1_96:
-		var et Aes256CtsHmacSha96
-		return et, nil
-	default:
-		return nil, fmt.Errorf("Unknown or unsupported EType: %d", id)
-	}
-}
-
-func GetChksumEtype(id int) (EType, error) {
-	switch id {
-	case chksumtype.HMAC_SHA1_96_AES128:
-		var et Aes128CtsHmacSha96
-		return et, nil
-	case chksumtype.HMAC_SHA1_96_AES256:
-		var et Aes256CtsHmacSha96
-		return et, nil
-	default:
-		return nil, fmt.Errorf("Unknown or unsupported checksum type: %d", id)
-	}
-}
-
-// RFC3961: DR(Key, Constant) = k-truncate(E(Key, Constant, initial-cipher-state))
-// key - base key or protocol key. Likely to be a key from a keytab file
-// usage - a constant
-// n - block size in bits (not bytes) - note if you use something like aes.BlockSize this is in bytes.
-// k - key length / key seed length in bits. Eg. for AES256 this value is 256
-// encrypt - the encryption function to use
-func deriveRandom(key, usage []byte, n, k int, e EType) ([]byte, error) {
-	//Ensure the usage constant is at least the size of the cypher block size. Pass it through the nfold algorithm that will "stretch" it if needs be.
-	nFoldUsage := Nfold(usage, n)
-	//k-truncate implemented by creating a byte array the size of k (k is in bits hence /8)
-	out := make([]byte, k/8)
-
-	/*If the output	of E is shorter than k bits, it is fed back into the encryption as many times as necessary.
-	The construct is as follows (where | indicates concatentation):
-
-	K1 = E(Key, n-fold(Constant), initial-cipher-state)
-	K2 = E(Key, K1, initial-cipher-state)
-	K3 = E(Key, K2, initial-cipher-state)
-	K4 = ...
-
-	DR(Key, Constant) = k-truncate(K1 | K2 | K3 | K4 ...)*/
-	_, K, err := e.Encrypt(key, nFoldUsage)
-	if err != nil {
-		return out, err
-	}
-	for i := copy(out, K); i < len(out); {
-		_, K, _ = e.Encrypt(key, K)
-		i = i + copy(out[i:], K)
-	}
-	return out, nil
-}
-
-func zeroPad(b []byte, m int) ([]byte, error) {
-	if m <= 0 {
-		return nil, errors.New("Invalid message block size when padding")
-	}
-	if b == nil || len(b) == 0 {
-		return nil, errors.New("Data not valid to pad: Zero size")
-	}
-	if l := len(b) % m; l != 0 {
-		n := m - l
-		z := make([]byte, n)
-		b = append(b, z...)
-	}
-	return b, nil
-}
-
-func pkcs7Pad(b []byte, m int) ([]byte, error) {
-	if m <= 0 {
-		return nil, errors.New("Invalid message block size when padding")
-	}
-	if b == nil || len(b) == 0 {
-		return nil, errors.New("Data not valid to pad: Zero size")
-	}
-	n := m - (len(b) % m)
-	pb := make([]byte, len(b)+n)
-	copy(pb, b)
-	copy(pb[len(b):], bytes.Repeat([]byte{byte(n)}, n))
-	return pb, nil
-}
-
-func pkcs7Unpad(b []byte, m int) ([]byte, error) {
-	if m <= 0 {
-		return nil, errors.New("Invalid message block size when unpadding")
-	}
-	if b == nil || len(b) == 0 {
-		return nil, errors.New("Padded data not valid: Zero size")
-	}
-	if len(b)%m != 0 {
-		return nil, errors.New("Padded data not valid: Not multiple of message block size")
-	}
-	c := b[len(b)-1]
-	n := int(c)
-	if n == 0 || n > len(b) {
-		return nil, errors.New("Padded data not valid: Data may not have been padded")
-	}
-	for i := 0; i < n; i++ {
-		if b[len(b)-n+i] != c {
-			return nil, errors.New("Padded data not valid")
-		}
-	}
-	return b[:len(b)-n], nil
-}
-
-func DecryptEncPart(key []byte, pe types.EncryptedData, etype EType, usage uint32) ([]byte, error) {
-	//Derive the key
-	k, err := etype.DeriveKey(key, GetUsageKe(usage))
-	if err != nil {
-		return nil, fmt.Errorf("Error deriving key: %v", err)
-	}
-	// Strip off the checksum from the end
-	b, err := etype.Decrypt(k, pe.Cipher[:len(pe.Cipher)-etype.GetHMACBitLength()/8])
-	if err != nil {
-		return nil, fmt.Errorf("Error decrypting: %v", err)
-	}
-	//Verify checksum
-	if !etype.VerifyIntegrity(key, pe.Cipher, b, usage) {
-		return nil, errors.New("Error decrypting encrypted part: integrity verification failed")
-	}
-	//Remove the confounder bytes
-	b = b[etype.GetConfounderByteSize():]
-	if err != nil {
-		return nil, fmt.Errorf("Error decrypting encrypted part: %v", err)
-	}
-	return b, nil
-}
-
-func GetKeyFromPassword(passwd string, cn types.PrincipalName, realm string, etypeId int, pas types.PADataSequence) (types.EncryptionKey, EType, error) {
-	var key types.EncryptionKey
-	etype, err := GetEtype(etypeId)
-	if err != nil {
-		return key, etype, fmt.Errorf("Error getting encryption type: %v", err)
-	}
-	sk2p := etype.GetDefaultStringToKeyParams()
-	var salt string
-	var paID int
-	for _, pa := range pas {
-		switch pa.PADataType {
-		case patype.PA_PW_SALT:
-			if paID > pa.PADataType {
-				continue
-			}
-			salt = string(pa.PADataValue)
-		case patype.PA_ETYPE_INFO:
-			if paID > pa.PADataType {
-				continue
-			}
-			var et types.ETypeInfo
-			err := et.Unmarshal(pa.PADataValue)
-			if err != nil {
-				return key, etype, fmt.Errorf("Error unmashalling PA Data to PA-ETYPE-INFO2: %v", err)
-			}
-			if etypeId != et[0].EType {
-				etype, err = GetEtype(et[0].EType)
-				if err != nil {
-					return key, etype, fmt.Errorf("Error getting encryption type: %v", err)
-				}
-			}
-			salt = string(et[0].Salt)
-		case patype.PA_ETYPE_INFO2:
-			if paID > pa.PADataType {
-				continue
-			}
-			var et2 types.ETypeInfo2
-			err := et2.Unmarshal(pa.PADataValue)
-			if err != nil {
-				return key, etype, fmt.Errorf("Error unmashalling PA Data to PA-ETYPE-INFO2: %v", err)
-			}
-			if etypeId != et2[0].EType {
-				etype, err = GetEtype(et2[0].EType)
-				if err != nil {
-					return key, etype, fmt.Errorf("Error getting encryption type: %v", err)
-				}
-			}
-			if len(et2[0].S2KParams) == 4 {
-				sk2p = hex.EncodeToString(et2[0].S2KParams)
-			}
-			salt = et2[0].Salt
-		}
-	}
-	if salt == "" {
-		salt = cn.GetSalt(realm)
-	}
-	k, err := etype.StringToKey(passwd, salt, sk2p)
-	if err != nil {
-		return key, etype, fmt.Errorf("Error deriving key from string: %+v", err)
-	}
-	key = types.EncryptionKey{
-		KeyType:  etypeId,
-		KeyValue: k,
-	}
-	return key, etype, nil
-}
-
-func getHash(pt, key []byte, usage []byte, etype EType) ([]byte, error) {
-	k, err := etype.DeriveKey(key, usage)
-	if err != nil {
-		return nil, fmt.Errorf("Unable to derive key for checksum: %v", err)
-	}
-	mac := hmac.New(etype.GetHash, k)
-	p := make([]byte, len(pt))
-	copy(p, pt)
-	mac.Write(p)
-	return mac.Sum(nil)[:etype.GetHMACBitLength()/8], nil
-}
-
-func GetChecksumHash(pt, key []byte, usage uint32, etype EType) ([]byte, error) {
-	return getHash(pt, key, GetUsageKc(usage), etype)
-}
-
-func GetIntegrityHash(pt, key []byte, usage uint32, etype EType) ([]byte, error) {
-	return getHash(pt, key, GetUsageKi(usage), etype)
-}
-
-func VerifyIntegrity(key, ct, pt []byte, usage uint32, etype EType) bool {
-	//The ciphertext output is the concatenation of the output of the basic
-	//encryption function E and a (possibly truncated) HMAC using the
-	//specified hash function H, both applied to the plaintext with a
-	//random confounder prefix and sufficient padding to bring it to a
-	//multiple of the message block size.  When the HMAC is computed, the
-	//key is used in the protocol key form.
-	h := make([]byte, etype.GetHMACBitLength()/8)
-	copy(h, ct[len(ct)-etype.GetHMACBitLength()/8:])
-	expectedMAC, _ := GetIntegrityHash(pt, key, usage, etype)
-	return hmac.Equal(h, expectedMAC)
-}
-
-func VerifyChecksum(key, chksum, msg []byte, usage uint32, etype EType) bool {
-	//The ciphertext output is the concatenation of the output of the basic
-	//encryption function E and a (possibly truncated) HMAC using the
-	//specified hash function H, both applied to the plaintext with a
-	//random confounder prefix and sufficient padding to bring it to a
-	//multiple of the message block size.  When the HMAC is computed, the
-	//key is used in the protocol key form.
-	expectedMAC, _ := GetChecksumHash(msg, key, usage, etype)
-	return hmac.Equal(chksum, expectedMAC)
-}
-
-/*
-Key Usage Numbers
-RFC 3961: The "well-known constant" used for the DK function is the key usage number, expressed as four octets in big-endian order, followed by one octet indicated below.
-Kc = DK(base-key, usage | 0x99);
-Ke = DK(base-key, usage | 0xAA);
-Ki = DK(base-key, usage | 0x55);
-*/
-
-// un - usage number
-func GetUsageKc(un uint32) []byte {
-	return getUsage(un, 0x99)
-}
-
-// un - usage number
-func GetUsageKe(un uint32) []byte {
-	return getUsage(un, 0xAA)
-}
-
-// un - usage number
-func GetUsageKi(un uint32) []byte {
-	return getUsage(un, 0x55)
-}
-
-func getUsage(un uint32, o byte) []byte {
-	var buf bytes.Buffer
-	binary.Write(&buf, binary.BigEndian, un)
-	return append(buf.Bytes(), o)
-}
-
-// Pass a usage value of zero to use the key provided directly rather than deriving one
-func GetEncryptedData(pt []byte, key types.EncryptionKey, usage int, kvno int) (types.EncryptedData, error) {
-	var ed types.EncryptedData
-	etype, err := GetEtype(key.KeyType)
-	if err != nil {
-		return ed, fmt.Errorf("Error getting etype: %v", err)
-	}
-	k := key.KeyValue
-	if usage != 0 {
-		k, err = etype.DeriveKey(key.KeyValue, GetUsageKe(uint32(usage)))
-		if err != nil {
-			return ed, fmt.Errorf("Error deriving key: %v", err)
-		}
-	}
-	//confounder
-	c := make([]byte, etype.GetConfounderByteSize())
-	_, err = rand.Read(c)
-	if err != nil {
-		return ed, fmt.Errorf("Could not generate random confounder: %v", err)
-	}
-	pt = append(c, pt...)
-	_, b, err := etype.Encrypt(k, pt)
-	if err != nil {
-		return ed, fmt.Errorf("Error encrypting data: %v", err)
-	}
-	ih, err := GetIntegrityHash(pt, key.KeyValue, uint32(usage), etype)
-	b = append(b, ih...)
-	ed = types.EncryptedData{
-		EType:  key.KeyType,
-		Cipher: b,
-		KVNO:   kvno,
-	}
-	return ed, nil
-}

+ 20 - 18
crypto/aes-cts-hmac-sha1-96.go → crypto/aes/aes.go

@@ -1,4 +1,4 @@
-package crypto
+package aes
 
 import (
 	"bytes"
@@ -9,6 +9,8 @@ import (
 	"encoding/hex"
 	"errors"
 	"fmt"
+	"github.com/jcmturner/gokrb5/crypto/engine"
+	"github.com/jcmturner/gokrb5/crypto/etype"
 	"golang.org/x/crypto/pbkdf2"
 )
 
@@ -16,7 +18,7 @@ const (
 	s2kParamsZero = 4294967296
 )
 
-func AESStringToKey(secret, salt, s2kparams string, e EType) ([]byte, error) {
+func stringToKey(secret, salt, s2kparams string, e etype.EType) ([]byte, error) {
 	//process s2kparams string
 	//The parameter string is four octets indicating an unsigned
 	//number in big-endian order.  This is the number of iterations to be
@@ -36,39 +38,39 @@ func AESStringToKey(secret, salt, s2kparams string, e EType) ([]byte, error) {
 		return nil, errors.New("Invalid s2kparams, cannot convert to big endian int32")
 	}
 	if i == 0 {
-		return AESStringToKeyIter(secret, salt, s2kParamsZero, e)
+		return stringToKeyIter(secret, salt, s2kParamsZero, e)
 	}
-	return AESStringToKeyIter(secret, salt, int(i), e)
+	return stringToKeyIter(secret, salt, int(i), e)
 }
 
-func AESStringToPBKDF2(secret, salt string, iterations int, e EType) []byte {
+func stringToPBKDF2(secret, salt string, iterations int, e etype.EType) []byte {
 	return pbkdf2.Key([]byte(secret), []byte(salt), iterations, e.GetKeyByteSize(), sha1.New)
 }
 
-func AESStringToKeyIter(secret, salt string, iterations int, e EType) ([]byte, error) {
-	tkey := AESRandomToKey(AESStringToPBKDF2(secret, salt, iterations, e))
-	key, err := AESDeriveKey(tkey, []byte("kerberos"), e)
+func stringToKeyIter(secret, salt string, iterations int, e etype.EType) ([]byte, error) {
+	tkey := randomToKey(stringToPBKDF2(secret, salt, iterations, e))
+	key, err := deriveKey(tkey, []byte("kerberos"), e)
 	return key, err
 }
 
-func AESRandomToKey(b []byte) []byte {
+func randomToKey(b []byte) []byte {
 	return b
 }
 
-func AESDeriveRandom(protocolKey, usage []byte, e EType) ([]byte, error) {
-	r, err := deriveRandom(protocolKey, usage, e.GetCypherBlockBitLength(), e.GetKeySeedBitLength(), e)
+func deriveRandom(protocolKey, usage []byte, e etype.EType) ([]byte, error) {
+	r, err := engine.DeriveRandom(protocolKey, usage, e.GetCypherBlockBitLength(), e.GetKeySeedBitLength(), e)
 	return r, err
 }
 
-func AESDeriveKey(protocolKey, usage []byte, e EType) ([]byte, error) {
-	r, err := AESDeriveRandom(protocolKey, usage, e)
+func deriveKey(protocolKey, usage []byte, e etype.EType) ([]byte, error) {
+	r, err := deriveRandom(protocolKey, usage, e)
 	if err != nil {
 		return nil, err
 	}
-	return AESRandomToKey(r), nil
+	return randomToKey(r), nil
 }
 
-func AESCTSEncrypt(key, iv, message []byte, e EType) ([]byte, []byte, error) {
+func encryptCTS(key, iv, message []byte, e etype.EType) ([]byte, []byte, error) {
 	if len(key) != e.GetKeyByteSize() {
 		return nil, nil, fmt.Errorf("Incorrect keysize: expected: %v actual: %v", e.GetKeyByteSize(), len(key))
 	}
@@ -93,7 +95,7 @@ func AESCTSEncrypt(key, iv, message []byte, e EType) ([]byte, []byte, error) {
 	subsequent encryption is the next-to-last block of the encryption
 	output; this is the encrypted form of the last plaintext block.*/
 	if l <= aes.BlockSize {
-		m, _ = zeroPad(m, aes.BlockSize)
+		m, _ = engine.ZeroPad(m, aes.BlockSize)
 		mode.CryptBlocks(m, m)
 		return m, m, nil
 	}
@@ -103,7 +105,7 @@ func AESCTSEncrypt(key, iv, message []byte, e EType) ([]byte, []byte, error) {
 		rb, _ := swapLastTwoBlocks(m, aes.BlockSize)
 		return iv, rb, nil
 	}
-	m, _ = zeroPad(m, aes.BlockSize)
+	m, _ = engine.ZeroPad(m, aes.BlockSize)
 	rb, pb, lb, err := tailBlocks(m, aes.BlockSize)
 	var ct []byte
 	if rb != nil {
@@ -125,7 +127,7 @@ func AESCTSEncrypt(key, iv, message []byte, e EType) ([]byte, []byte, error) {
 	//TODO do we need to add the hash to the end?
 }
 
-func AESCTSDecrypt(key, ciphertext []byte, e EType) ([]byte, error) {
+func decryptCTS(key, ciphertext []byte, e etype.EType) ([]byte, error) {
 	// Copy the cipher text as golang slices even when passed by value to this method can result in the backing arrays of the calling code value being updated.
 	ct := make([]byte, len(ciphertext))
 	copy(ct, ciphertext)

+ 12 - 11
crypto/aes128-cts-hmac-sha1-96.go → crypto/aes/aes128-cts-hmac-sha1-96.go

@@ -1,11 +1,12 @@
-package crypto
+package aes
 
 import (
 	"crypto/aes"
 	"crypto/sha1"
-	"hash"
+	"github.com/jcmturner/gokrb5/crypto/engine"
 	"github.com/jcmturner/gokrb5/iana/chksumtype"
-	"github.com/jcmturner/gokrb5/iana/etype"
+	"github.com/jcmturner/gokrb5/iana/etypeID"
+	"hash"
 )
 
 // RFC 3962
@@ -58,7 +59,7 @@ type Aes128CtsHmacSha96 struct {
 }
 
 func (e Aes128CtsHmacSha96) GetETypeID() int {
-	return etype.AES128_CTS_HMAC_SHA1_96
+	return etypeID.AES128_CTS_HMAC_SHA1_96
 }
 
 func (e Aes128CtsHmacSha96) GetHashID() int {
@@ -98,30 +99,30 @@ func (e Aes128CtsHmacSha96) GetCypherBlockBitLength() int {
 }
 
 func (e Aes128CtsHmacSha96) StringToKey(secret string, salt string, s2kparams string) ([]byte, error) {
-	return AESStringToKey(secret, salt, s2kparams, e)
+	return stringToKey(secret, salt, s2kparams, e)
 }
 
 func (e Aes128CtsHmacSha96) RandomToKey(b []byte) []byte {
-	return AESRandomToKey(b)
+	return randomToKey(b)
 }
 
 func (e Aes128CtsHmacSha96) Encrypt(key, message []byte) ([]byte, []byte, error) {
 	ivz := make([]byte, aes.BlockSize)
-	return AESCTSEncrypt(key, ivz, message, e)
+	return encryptCTS(key, ivz, message, e)
 }
 
 func (e Aes128CtsHmacSha96) Decrypt(key, ciphertext []byte) ([]byte, error) {
-	return AESCTSDecrypt(key, ciphertext, e)
+	return decryptCTS(key, ciphertext, e)
 }
 
 func (e Aes128CtsHmacSha96) DeriveKey(protocolKey, usage []byte) ([]byte, error) {
-	return AESDeriveKey(protocolKey, usage, e)
+	return deriveKey(protocolKey, usage, e)
 }
 
 func (e Aes128CtsHmacSha96) DeriveRandom(protocolKey, usage []byte) ([]byte, error) {
-	return AESDeriveRandom(protocolKey, usage, e)
+	return deriveRandom(protocolKey, usage, e)
 }
 
 func (e Aes128CtsHmacSha96) VerifyIntegrity(protocolKey, ct, pt []byte, usage uint32) bool {
-	return VerifyIntegrity(protocolKey, ct, pt, usage, e)
+	return engine.VerifyIntegrity(protocolKey, ct, pt, usage, e)
 }

+ 12 - 11
crypto/aes256-cts-hmac-sha1-96.go → crypto/aes/aes256-cts-hmac-sha1-96.go

@@ -1,11 +1,12 @@
-package crypto
+package aes
 
 import (
 	"crypto/aes"
 	"crypto/sha1"
-	"hash"
-	"github.com/jcmturner/gokrb5/iana/etype"
+	"github.com/jcmturner/gokrb5/crypto/engine"
 	"github.com/jcmturner/gokrb5/iana/chksumtype"
+	"github.com/jcmturner/gokrb5/iana/etypeID"
+	"hash"
 )
 
 // RFC 3962
@@ -58,7 +59,7 @@ type Aes256CtsHmacSha96 struct {
 }
 
 func (e Aes256CtsHmacSha96) GetETypeID() int {
-	return etype.AES256_CTS_HMAC_SHA1_96
+	return etypeID.AES256_CTS_HMAC_SHA1_96
 }
 
 func (e Aes256CtsHmacSha96) GetHashID() int {
@@ -98,30 +99,30 @@ func (e Aes256CtsHmacSha96) GetCypherBlockBitLength() int {
 }
 
 func (e Aes256CtsHmacSha96) StringToKey(secret string, salt string, s2kparams string) ([]byte, error) {
-	return AESStringToKey(secret, salt, s2kparams, e)
+	return stringToKey(secret, salt, s2kparams, e)
 }
 
 func (e Aes256CtsHmacSha96) RandomToKey(b []byte) []byte {
-	return AESRandomToKey(b)
+	return randomToKey(b)
 }
 
 func (e Aes256CtsHmacSha96) Encrypt(key, message []byte) ([]byte, []byte, error) {
 	ivz := make([]byte, aes.BlockSize)
-	return AESCTSEncrypt(key, ivz, message, e)
+	return encryptCTS(key, ivz, message, e)
 }
 
 func (e Aes256CtsHmacSha96) Decrypt(key, ciphertext []byte) ([]byte, error) {
-	return AESCTSDecrypt(key, ciphertext, e)
+	return decryptCTS(key, ciphertext, e)
 }
 
 func (e Aes256CtsHmacSha96) DeriveKey(protocolKey, usage []byte) ([]byte, error) {
-	return AESDeriveKey(protocolKey, usage, e)
+	return deriveKey(protocolKey, usage, e)
 }
 
 func (e Aes256CtsHmacSha96) DeriveRandom(protocolKey, usage []byte) ([]byte, error) {
-	return AESDeriveRandom(protocolKey, usage, e)
+	return deriveRandom(protocolKey, usage, e)
 }
 
 func (e Aes256CtsHmacSha96) VerifyIntegrity(protocolKey, ct, pt []byte, usage uint32) bool {
-	return VerifyIntegrity(protocolKey, ct, pt, usage, e)
+	return engine.VerifyIntegrity(protocolKey, ct, pt, usage, e)
 }

+ 5 - 5
crypto/aes-cts-hmac-sha1-96_test.go → crypto/aes/aes_test.go

@@ -1,4 +1,4 @@
-package crypto
+package aes
 
 import (
 	"encoding/hex"
@@ -25,7 +25,7 @@ func TestAesCtsHmacSha196_Encrypt_Decrypt(t *testing.T) {
 	var e Aes128CtsHmacSha96
 	for i, test := range tests {
 		m, _ := hex.DecodeString(test.plain)
-		niv, c, err := AESCTSEncrypt(key, iv, m, e)
+		niv, c, err := encryptCTS(key, iv, m, e)
 		if err != nil {
 			t.Errorf("Encryption failed for test %v: %v", i+1, err)
 		}
@@ -35,7 +35,7 @@ func TestAesCtsHmacSha196_Encrypt_Decrypt(t *testing.T) {
 	//t.Log("AES CTS Encryption tests finished")
 	for i, test := range tests {
 		b, _ := hex.DecodeString(test.cipher)
-		p, err := AESCTSDecrypt(key, b, e)
+		p, err := decryptCTS(key, b, e)
 		if err != nil {
 			t.Errorf("Decryption failed for test %v: %v", i+1, err)
 		}
@@ -68,8 +68,8 @@ func TestAes256CtsHmacSha196_StringToKey(t *testing.T) {
 	var e Aes256CtsHmacSha96
 	for i, test := range tests {
 
-		assert.Equal(t, test.pbkdf2, hex.EncodeToString(AESStringToPBKDF2(test.phrase, test.salt, test.iterations, e)), "PBKDF2 not as expected")
-		k, err := AESStringToKeyIter(test.phrase, test.salt, test.iterations, e)
+		assert.Equal(t, test.pbkdf2, hex.EncodeToString(stringToPBKDF2(test.phrase, test.salt, test.iterations, e)), "PBKDF2 not as expected")
+		k, err := stringToKeyIter(test.phrase, test.salt, test.iterations, e)
 		if err != nil {
 			t.Errorf("Error in processing string to key for test %d: %v", i, err)
 		}

+ 167 - 0
crypto/crypto.go

@@ -0,0 +1,167 @@
+package crypto
+
+import (
+	"crypto/rand"
+	"encoding/hex"
+	"errors"
+	"fmt"
+	"github.com/jcmturner/gokrb5/crypto/aes"
+	"github.com/jcmturner/gokrb5/crypto/engine"
+	"github.com/jcmturner/gokrb5/crypto/etype"
+	"github.com/jcmturner/gokrb5/iana/chksumtype"
+	"github.com/jcmturner/gokrb5/iana/etypeID"
+	"github.com/jcmturner/gokrb5/iana/patype"
+	"github.com/jcmturner/gokrb5/types"
+)
+
+func GetEtype(id int) (etype.EType, error) {
+	switch id {
+	case etypeID.AES128_CTS_HMAC_SHA1_96:
+		var et aes.Aes128CtsHmacSha96
+		return et, nil
+	case etypeID.AES256_CTS_HMAC_SHA1_96:
+		var et aes.Aes256CtsHmacSha96
+		return et, nil
+	default:
+		return nil, fmt.Errorf("Unknown or unsupported EType: %d", id)
+	}
+}
+
+func GetChksumEtype(id int) (etype.EType, error) {
+	switch id {
+	case chksumtype.HMAC_SHA1_96_AES128:
+		var et aes.Aes128CtsHmacSha96
+		return et, nil
+	case chksumtype.HMAC_SHA1_96_AES256:
+		var et aes.Aes256CtsHmacSha96
+		return et, nil
+	default:
+		return nil, fmt.Errorf("Unknown or unsupported checksum type: %d", id)
+	}
+}
+
+func GetKeyFromPassword(passwd string, cn types.PrincipalName, realm string, etypeId int, pas types.PADataSequence) (types.EncryptionKey, etype.EType, error) {
+	var key types.EncryptionKey
+	etype, err := GetEtype(etypeId)
+	if err != nil {
+		return key, etype, fmt.Errorf("Error getting encryption type: %v", err)
+	}
+	sk2p := etype.GetDefaultStringToKeyParams()
+	var salt string
+	var paID int
+	for _, pa := range pas {
+		switch pa.PADataType {
+		case patype.PA_PW_SALT:
+			if paID > pa.PADataType {
+				continue
+			}
+			salt = string(pa.PADataValue)
+		case patype.PA_ETYPE_INFO:
+			if paID > pa.PADataType {
+				continue
+			}
+			var et types.ETypeInfo
+			err := et.Unmarshal(pa.PADataValue)
+			if err != nil {
+				return key, etype, fmt.Errorf("Error unmashalling PA Data to PA-ETYPE-INFO2: %v", err)
+			}
+			if etypeId != et[0].EType {
+				etype, err = GetEtype(et[0].EType)
+				if err != nil {
+					return key, etype, fmt.Errorf("Error getting encryption type: %v", err)
+				}
+			}
+			salt = string(et[0].Salt)
+		case patype.PA_ETYPE_INFO2:
+			if paID > pa.PADataType {
+				continue
+			}
+			var et2 types.ETypeInfo2
+			err := et2.Unmarshal(pa.PADataValue)
+			if err != nil {
+				return key, etype, fmt.Errorf("Error unmashalling PA Data to PA-ETYPE-INFO2: %v", err)
+			}
+			if etypeId != et2[0].EType {
+				etype, err = GetEtype(et2[0].EType)
+				if err != nil {
+					return key, etype, fmt.Errorf("Error getting encryption type: %v", err)
+				}
+			}
+			if len(et2[0].S2KParams) == 4 {
+				sk2p = hex.EncodeToString(et2[0].S2KParams)
+			}
+			salt = et2[0].Salt
+		}
+	}
+	if salt == "" {
+		salt = cn.GetSalt(realm)
+	}
+	k, err := etype.StringToKey(passwd, salt, sk2p)
+	if err != nil {
+		return key, etype, fmt.Errorf("Error deriving key from string: %+v", err)
+	}
+	key = types.EncryptionKey{
+		KeyType:  etypeId,
+		KeyValue: k,
+	}
+	return key, etype, nil
+}
+
+// Pass a usage value of zero to use the key provided directly rather than deriving one
+func GetEncryptedData(pt []byte, key types.EncryptionKey, usage int, kvno int) (types.EncryptedData, error) {
+	var ed types.EncryptedData
+	etype, err := GetEtype(key.KeyType)
+	if err != nil {
+		return ed, fmt.Errorf("Error getting etype: %v", err)
+	}
+	k := key.KeyValue
+	if usage != 0 {
+		k, err = etype.DeriveKey(key.KeyValue, engine.GetUsageKe(uint32(usage)))
+		if err != nil {
+			return ed, fmt.Errorf("Error deriving key: %v", err)
+		}
+	}
+	//confounder
+	c := make([]byte, etype.GetConfounderByteSize())
+	_, err = rand.Read(c)
+	if err != nil {
+		return ed, fmt.Errorf("Could not generate random confounder: %v", err)
+	}
+	pt = append(c, pt...)
+	_, b, err := etype.Encrypt(k, pt)
+	if err != nil {
+		return ed, fmt.Errorf("Error encrypting data: %v", err)
+	}
+	ih, err := engine.GetIntegrityHash(pt, key.KeyValue, uint32(usage), etype)
+	b = append(b, ih...)
+	ed = types.EncryptedData{
+		EType:  key.KeyType,
+		Cipher: b,
+		KVNO:   kvno,
+	}
+	return ed, nil
+}
+
+func DecryptEncPart(pe types.EncryptedData, key types.EncryptionKey, usage uint32) ([]byte, error) {
+	//Derive the key
+	etype, err := GetEtype(key.KeyType)
+	k, err := etype.DeriveKey(key.KeyValue, engine.GetUsageKe(usage))
+	if err != nil {
+		return nil, fmt.Errorf("Error deriving key: %v", err)
+	}
+	// Strip off the checksum from the end
+	b, err := etype.Decrypt(k, pe.Cipher[:len(pe.Cipher)-etype.GetHMACBitLength()/8])
+	if err != nil {
+		return nil, fmt.Errorf("Error decrypting: %v", err)
+	}
+	//Verify checksum
+	if !etype.VerifyIntegrity(key.KeyValue, pe.Cipher, b, usage) {
+		return nil, errors.New("Error decrypting encrypted part: integrity verification failed")
+	}
+	//Remove the confounder bytes
+	b = b[etype.GetConfounderByteSize():]
+	if err != nil {
+		return nil, fmt.Errorf("Error decrypting encrypted part: %v", err)
+	}
+	return b, nil
+}

+ 9 - 8
crypto/des3-cbc-sha1-kd.go → crypto/des3/des3-cbc-sha1-kd.go

@@ -1,4 +1,4 @@
-package crypto
+package des3
 
 import (
 	"crypto/cipher"
@@ -6,9 +6,10 @@ import (
 	"crypto/sha1"
 	"errors"
 	"fmt"
-	"hash"
+	"github.com/jcmturner/gokrb5/crypto/engine"
 	"github.com/jcmturner/gokrb5/iana/chksumtype"
-	"github.com/jcmturner/gokrb5/iana/etype"
+	"github.com/jcmturner/gokrb5/iana/etypeID"
+	"hash"
 )
 
 //RFC: 3961 Section 6.3
@@ -52,7 +53,7 @@ type Des3CbcSha1Kd struct {
 }
 
 func (e Des3CbcSha1Kd) GetETypeID() int {
-	return etype.DES3_CBC_SHA1_KD
+	return etypeID.DES3_CBC_SHA1_KD
 }
 
 func (e Des3CbcSha1Kd) GetHashID() int {
@@ -104,7 +105,7 @@ func (e Des3CbcSha1Kd) RandomToKey(b []byte) (protocolKey []byte) {
 }
 
 func (e Des3CbcSha1Kd) DeriveRandom(protocolKey, usage []byte) ([]byte, error) {
-	r, err := deriveRandom(protocolKey, usage, e.GetCypherBlockBitLength(), e.GetKeySeedBitLength(), e)
+	r, err := engine.DeriveRandom(protocolKey, usage, e.GetCypherBlockBitLength(), e.GetKeySeedBitLength(), e)
 	return r, err
 }
 
@@ -122,7 +123,7 @@ func (e Des3CbcSha1Kd) Encrypt(key, message []byte) ([]byte, []byte, error) {
 
 	}
 	if len(message)%e.GetMessageBlockByteSize() != 0 {
-		message, _ = pkcs7Pad(message, e.GetMessageBlockByteSize())
+		message, _ = engine.PKCS7Pad(message, e.GetMessageBlockByteSize())
 	}
 
 	block, err := des.NewTripleDESCipher(key)
@@ -163,7 +164,7 @@ func (e Des3CbcSha1Kd) Decrypt(key, ciphertext []byte) (message []byte, err erro
 
 	mode := cipher.NewCBCDecrypter(block, iv)
 	mode.CryptBlocks(message, ciphertext)
-	m, er := pkcs7Unpad(message, e.GetMessageBlockByteSize())
+	m, er := engine.PKCS7Unpad(message, e.GetMessageBlockByteSize())
 	if er == nil {
 		message = m
 	}
@@ -171,5 +172,5 @@ func (e Des3CbcSha1Kd) Decrypt(key, ciphertext []byte) (message []byte, err erro
 }
 
 func (e Des3CbcSha1Kd) VerifyIntegrity(protocolKey, ct, pt []byte, usage uint32) bool {
-	return VerifyIntegrity(protocolKey, ct, pt, usage, e)
+	return engine.VerifyIntegrity(protocolKey, ct, pt, usage, e)
 }

+ 1 - 1
crypto/des3-cbc-sha1-kd_test.go → crypto/des3/des3-cbc-sha1-kd_test.go

@@ -1,4 +1,4 @@
-package crypto
+package des3
 
 import (
 	"encoding/hex"

+ 167 - 0
crypto/engine/engine.go

@@ -0,0 +1,167 @@
+package engine
+
+import (
+	"bytes"
+	"crypto/hmac"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"github.com/jcmturner/gokrb5/crypto/etype"
+)
+
+// RFC3961: DR(Key, Constant) = k-truncate(E(Key, Constant, initial-cipher-state))
+// key - base key or protocol key. Likely to be a key from a keytab file
+// usage - a constant
+// n - block size in bits (not bytes) - note if you use something like aes.BlockSize this is in bytes.
+// k - key length / key seed length in bits. Eg. for AES256 this value is 256
+// encrypt - the encryption function to use
+func DeriveRandom(key, usage []byte, n, k int, e etype.EType) ([]byte, error) {
+	//Ensure the usage constant is at least the size of the cypher block size. Pass it through the nfold algorithm that will "stretch" it if needs be.
+	nFoldUsage := Nfold(usage, n)
+	//k-truncate implemented by creating a byte array the size of k (k is in bits hence /8)
+	out := make([]byte, k/8)
+
+	/*If the output	of E is shorter than k bits, it is fed back into the encryption as many times as necessary.
+	The construct is as follows (where | indicates concatentation):
+
+	K1 = E(Key, n-fold(Constant), initial-cipher-state)
+	K2 = E(Key, K1, initial-cipher-state)
+	K3 = E(Key, K2, initial-cipher-state)
+	K4 = ...
+
+	DR(Key, Constant) = k-truncate(K1 | K2 | K3 | K4 ...)*/
+	_, K, err := e.Encrypt(key, nFoldUsage)
+	if err != nil {
+		return out, err
+	}
+	for i := copy(out, K); i < len(out); {
+		_, K, _ = e.Encrypt(key, K)
+		i = i + copy(out[i:], K)
+	}
+	return out, nil
+}
+
+func ZeroPad(b []byte, m int) ([]byte, error) {
+	if m <= 0 {
+		return nil, errors.New("Invalid message block size when padding")
+	}
+	if b == nil || len(b) == 0 {
+		return nil, errors.New("Data not valid to pad: Zero size")
+	}
+	if l := len(b) % m; l != 0 {
+		n := m - l
+		z := make([]byte, n)
+		b = append(b, z...)
+	}
+	return b, nil
+}
+
+func PKCS7Pad(b []byte, m int) ([]byte, error) {
+	if m <= 0 {
+		return nil, errors.New("Invalid message block size when padding")
+	}
+	if b == nil || len(b) == 0 {
+		return nil, errors.New("Data not valid to pad: Zero size")
+	}
+	n := m - (len(b) % m)
+	pb := make([]byte, len(b)+n)
+	copy(pb, b)
+	copy(pb[len(b):], bytes.Repeat([]byte{byte(n)}, n))
+	return pb, nil
+}
+
+func PKCS7Unpad(b []byte, m int) ([]byte, error) {
+	if m <= 0 {
+		return nil, errors.New("Invalid message block size when unpadding")
+	}
+	if b == nil || len(b) == 0 {
+		return nil, errors.New("Padded data not valid: Zero size")
+	}
+	if len(b)%m != 0 {
+		return nil, errors.New("Padded data not valid: Not multiple of message block size")
+	}
+	c := b[len(b)-1]
+	n := int(c)
+	if n == 0 || n > len(b) {
+		return nil, errors.New("Padded data not valid: Data may not have been padded")
+	}
+	for i := 0; i < n; i++ {
+		if b[len(b)-n+i] != c {
+			return nil, errors.New("Padded data not valid")
+		}
+	}
+	return b[:len(b)-n], nil
+}
+
+func getHash(pt, key []byte, usage []byte, etype etype.EType) ([]byte, error) {
+	k, err := etype.DeriveKey(key, usage)
+	if err != nil {
+		return nil, fmt.Errorf("Unable to derive key for checksum: %v", err)
+	}
+	mac := hmac.New(etype.GetHash, k)
+	p := make([]byte, len(pt))
+	copy(p, pt)
+	mac.Write(p)
+	return mac.Sum(nil)[:etype.GetHMACBitLength()/8], nil
+}
+
+func GetChecksumHash(pt, key []byte, usage uint32, etype etype.EType) ([]byte, error) {
+	return getHash(pt, key, GetUsageKc(usage), etype)
+}
+
+func GetIntegrityHash(pt, key []byte, usage uint32, etype etype.EType) ([]byte, error) {
+	return getHash(pt, key, GetUsageKi(usage), etype)
+}
+
+func VerifyIntegrity(key, ct, pt []byte, usage uint32, etype etype.EType) bool {
+	//The ciphertext output is the concatenation of the output of the basic
+	//encryption function E and a (possibly truncated) HMAC using the
+	//specified hash function H, both applied to the plaintext with a
+	//random confounder prefix and sufficient padding to bring it to a
+	//multiple of the message block size.  When the HMAC is computed, the
+	//key is used in the protocol key form.
+	h := make([]byte, etype.GetHMACBitLength()/8)
+	copy(h, ct[len(ct)-etype.GetHMACBitLength()/8:])
+	expectedMAC, _ := GetIntegrityHash(pt, key, usage, etype)
+	return hmac.Equal(h, expectedMAC)
+}
+
+func VerifyChecksum(key, chksum, msg []byte, usage uint32, etype etype.EType) bool {
+	//The ciphertext output is the concatenation of the output of the basic
+	//encryption function E and a (possibly truncated) HMAC using the
+	//specified hash function H, both applied to the plaintext with a
+	//random confounder prefix and sufficient padding to bring it to a
+	//multiple of the message block size.  When the HMAC is computed, the
+	//key is used in the protocol key form.
+	expectedMAC, _ := GetChecksumHash(msg, key, usage, etype)
+	return hmac.Equal(chksum, expectedMAC)
+}
+
+/*
+Key Usage Numbers
+RFC 3961: The "well-known constant" used for the DK function is the key usage number, expressed as four octets in big-endian order, followed by one octet indicated below.
+Kc = DK(base-key, usage | 0x99);
+Ke = DK(base-key, usage | 0xAA);
+Ki = DK(base-key, usage | 0x55);
+*/
+
+// un - usage number
+func GetUsageKc(un uint32) []byte {
+	return getUsage(un, 0x99)
+}
+
+// un - usage number
+func GetUsageKe(un uint32) []byte {
+	return getUsage(un, 0xAA)
+}
+
+// un - usage number
+func GetUsageKi(un uint32) []byte {
+	return getUsage(un, 0x55)
+}
+
+func getUsage(un uint32, o byte) []byte {
+	var buf bytes.Buffer
+	binary.Write(&buf, binary.BigEndian, un)
+	return append(buf.Bytes(), o)
+}

+ 1 - 1
crypto/nfold.go → crypto/engine/nfold.go

@@ -1,4 +1,4 @@
-package crypto
+package engine
 
 /*
 Implementation of the n-fold algorithm as defined in RFC 3961.

+ 1 - 1
crypto/nfold_test.go → crypto/engine/nfold_test.go

@@ -1,4 +1,4 @@
-package crypto
+package engine
 
 import (
 	"encoding/hex"

+ 23 - 0
crypto/etype/etype.go

@@ -0,0 +1,23 @@
+package etype
+
+import "hash"
+
+type EType interface {
+	GetETypeID() int
+	GetHashID() int
+	GetKeyByteSize() int                                        // See "protocol key format" for defined values
+	GetKeySeedBitLength() int                                   // key-generation seed length, k
+	GetDefaultStringToKeyParams() string                        // default string-to-key parameters (s2kparams)
+	StringToKey(string, salt, s2kparams string) ([]byte, error) // string-to-key (UTF-8 string, UTF-8 string, opaque)->(protocol-key)
+	RandomToKey(b []byte) []byte                                // random-to-key (bitstring[K])->(protocol-key)
+	GetHMACBitLength() int                                      // HMAC output size, h
+	GetMessageBlockByteSize() int                               // message block size, m
+	Encrypt(key, message []byte) ([]byte, []byte, error)        // E function - encrypt (specific-key, state, octet string)->(state, octet string)
+	Decrypt(key, ciphertext []byte) ([]byte, error)             // D function
+	GetCypherBlockBitLength() int                               // cipher block size, c
+	GetConfounderByteSize() int                                 // This is the same as the cipher block size but in bytes.
+	DeriveKey(protocolKey, usage []byte) ([]byte, error)        // DK key-derivation (protocol-key, integer)->(specific-key)
+	DeriveRandom(protocolKey, usage []byte) ([]byte, error)     // DR pseudo-random (protocol-key, octet-string)->(octet-string)
+	VerifyIntegrity(protocolKey, ct, pt []byte, usage uint32) bool
+	GetHash() hash.Hash
+}

+ 1 - 1
iana/constants.go

@@ -1,4 +1,4 @@
-// Assigned numbers
+// Assigned numbers.
 package iana
 
 const (

+ 1 - 1
iana/etype/constants.go → iana/etypeID/constants.go

@@ -1,4 +1,4 @@
-package etype
+package etypeID
 
 const (
 	//RESERVED : 0

+ 3 - 9
messages/KDCRep.go

@@ -137,11 +137,9 @@ func (e *EncKDCRepPart) Unmarshal(b []byte) error {
 }
 
 func (k *ASRep) DecryptEncPart(c *credentials.Credentials) error {
-	var etype crypto.EType
 	var key types.EncryptionKey
 	var err error
 	if c.HasKeytab() {
-		etype, err = crypto.GetEtype(k.EncPart.EType)
 		if err != nil {
 			return fmt.Errorf("Error getting encryption type: %v", err)
 		}
@@ -151,7 +149,7 @@ func (k *ASRep) DecryptEncPart(c *credentials.Credentials) error {
 		}
 	}
 	if c.HasPassword() {
-		key, etype, err = crypto.GetKeyFromPassword(c.Password, k.CName, k.CRealm, k.EncPart.EType, k.PAData)
+		key, _, err = crypto.GetKeyFromPassword(c.Password, k.CName, k.CRealm, k.EncPart.EType, k.PAData)
 		if err != nil {
 			return fmt.Errorf("Could not derive key from password: %v", err)
 		}
@@ -159,7 +157,7 @@ func (k *ASRep) DecryptEncPart(c *credentials.Credentials) error {
 	if !c.HasKeytab() && !c.HasPassword() {
 		return errors.New("No secret available in credentials to preform decryption")
 	}
-	b, err := crypto.DecryptEncPart(key.KeyValue, k.EncPart, etype, keyusage.AS_REP_ENCPART)
+	b, err := crypto.DecryptEncPart(k.EncPart, key, keyusage.AS_REP_ENCPART)
 	if err != nil {
 		return fmt.Errorf("Error decrypting KDC_REP EncPart: %v", err)
 	}
@@ -232,11 +230,7 @@ func (k *ASRep) IsValid(cfg *config.Config, asReq ASReq) (bool, error) {
 }
 
 func (k *TGSRep) DecryptEncPart(key types.EncryptionKey) error {
-	etype, err := crypto.GetEtype(key.KeyType)
-	if err != nil {
-		return fmt.Errorf("Could not get etype: %v", err)
-	}
-	b, err := crypto.DecryptEncPart(key.KeyValue, k.EncPart, etype, keyusage.TGS_REP_ENCPART_SESSION_KEY)
+	b, err := crypto.DecryptEncPart(k.EncPart, key, keyusage.TGS_REP_ENCPART_SESSION_KEY)
 	if err != nil {
 		return fmt.Errorf("Error decrypting KDC_REP EncPart: %v", err)
 	}

+ 5 - 5
messages/KDCRep_test.go

@@ -4,7 +4,7 @@ import (
 	"encoding/hex"
 	"fmt"
 	"github.com/jcmturner/gokrb5/credentials"
-	"github.com/jcmturner/gokrb5/iana/etype"
+	"github.com/jcmturner/gokrb5/iana/etypeID"
 	"github.com/jcmturner/gokrb5/iana/msgtype"
 	"github.com/jcmturner/gokrb5/keytab"
 	"github.com/jcmturner/gokrb5/testdata"
@@ -239,9 +239,9 @@ func TestUnmarshalASRepDecodeAndDecrypt(t *testing.T) {
 	assert.Equal(t, 2, asRep.Ticket.SName.NameType, "Ticket service nametype not as expected")
 	assert.Equal(t, "krbtgt", asRep.Ticket.SName.NameString[0], "Ticket service name string not as expected")
 	assert.Equal(t, test_realm, asRep.Ticket.SName.NameString[1], "Ticket service name string not as expected")
-	assert.Equal(t, etype.ETypesByName["aes256-cts-hmac-sha1-96"], asRep.Ticket.EncPart.EType, "Etype of ticket encrypted part not as expected")
+	assert.Equal(t, etypeID.ETypesByName["aes256-cts-hmac-sha1-96"], asRep.Ticket.EncPart.EType, "Etype of ticket encrypted part not as expected")
 	assert.Equal(t, 1, asRep.Ticket.EncPart.KVNO, "Ticket encrypted part KVNO not as expected")
-	assert.Equal(t, etype.ETypesByName["aes256-cts-hmac-sha1-96"], asRep.EncPart.EType, "Etype of encrypted part not as expected")
+	assert.Equal(t, etypeID.ETypesByName["aes256-cts-hmac-sha1-96"], asRep.EncPart.EType, "Etype of encrypted part not as expected")
 	assert.Equal(t, 0, asRep.EncPart.KVNO, "Encrypted part KVNO not as expected")
 	//t.Log("Finished testing unecrypted parts of AS REP")
 	ktb, _ := hex.DecodeString(testuser1_etype18_keytab)
@@ -293,9 +293,9 @@ func TestUnmarshalASRepDecodeAndDecrypt_withPassword(t *testing.T) {
 	assert.Equal(t, 2, asRep.Ticket.SName.NameType, "Ticket service nametype not as expected")
 	assert.Equal(t, "krbtgt", asRep.Ticket.SName.NameString[0], "Ticket service name string not as expected")
 	assert.Equal(t, test_realm, asRep.Ticket.SName.NameString[1], "Ticket service name string not as expected")
-	assert.Equal(t, etype.AES256_CTS_HMAC_SHA1_96, asRep.Ticket.EncPart.EType, "Etype of ticket encrypted part not as expected")
+	assert.Equal(t, etypeID.AES256_CTS_HMAC_SHA1_96, asRep.Ticket.EncPart.EType, "Etype of ticket encrypted part not as expected")
 	assert.Equal(t, 1, asRep.Ticket.EncPart.KVNO, "Ticket encrypted part KVNO not as expected")
-	assert.Equal(t, etype.AES256_CTS_HMAC_SHA1_96, asRep.EncPart.EType, "Etype of encrypted part not as expected")
+	assert.Equal(t, etypeID.AES256_CTS_HMAC_SHA1_96, asRep.EncPart.EType, "Etype of encrypted part not as expected")
 	assert.Equal(t, 0, asRep.EncPart.KVNO, "Encrypted part KVNO not as expected")
 	cred := credentials.NewCredentials(test_user, test_realm)
 	err = asRep.DecryptEncPart(cred.WithPassword(test_user_password))

+ 2 - 1
messages/KDCReq.go

@@ -9,6 +9,7 @@ import (
 	"github.com/jcmturner/gokrb5/asn1tools"
 	"github.com/jcmturner/gokrb5/config"
 	"github.com/jcmturner/gokrb5/crypto"
+	"github.com/jcmturner/gokrb5/crypto/engine"
 	"github.com/jcmturner/gokrb5/iana"
 	"github.com/jcmturner/gokrb5/iana/asnAppTag"
 	"github.com/jcmturner/gokrb5/iana/keyusage"
@@ -171,7 +172,7 @@ func NewTGSReq(username string, c *config.Config, tkt types.Ticket, sessionKey t
 	if err != nil {
 		return a, fmt.Errorf("Error getting etype to encrypt authenticator: %v", err)
 	}
-	cb, err := crypto.GetChecksumHash(b, sessionKey.KeyValue, keyusage.TGS_REQ_PA_TGS_REQ_AP_REQ_AUTHENTICATOR_CHKSUM, etype)
+	cb, err := engine.GetChecksumHash(b, sessionKey.KeyValue, keyusage.TGS_REQ_PA_TGS_REQ_AP_REQ_AUTHENTICATOR_CHKSUM, etype)
 	auth.Cksum = types.Checksum{
 		CksumType: etype.GetHashID(),
 		Checksum:  cb,

+ 2 - 6
messages/KRBCred.go

@@ -71,12 +71,8 @@ func (k *KRBCred) Unmarshal(b []byte) error {
 	return nil
 }
 
-func (k *KRBCred) DecryptEncPart(key []byte) error {
-	etype, err := crypto.GetEtype(k.EncPart.EType)
-	if err != nil {
-		return fmt.Errorf("Keytab error: %v", err)
-	}
-	b, err := crypto.DecryptEncPart(key, k.EncPart, etype, keyusage.KRB_CRED_ENCPART)
+func (k *KRBCred) DecryptEncPart(key types.EncryptionKey) error {
+	b, err := crypto.DecryptEncPart(k.EncPart, key, keyusage.KRB_CRED_ENCPART)
 	if err != nil {
 		return fmt.Errorf("Error decrypting KDC_REP EncPart: %v", err)
 	}