crypto.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. package util
  import (
	"crypto/aes"
	"crypto/cipher"
	"encoding/base64"
	"encoding/binary"
	"errors"
	"fmt"
)
  9. //DecryptMsg 消息解密
  10. func DecryptMsg(appID, encryptedMsg, aesKey string) (rawMsgXMLBytes []byte, err error) {
  11. var encryptedMsgBytes, key, getAppIDBytes []byte
  12. encryptedMsgBytes, err = base64.StdEncoding.DecodeString(encryptedMsg)
  13. if err != nil {
  14. return
  15. }
  16. key, err = aesKeyDecode(aesKey)
  17. if err != nil {
  18. return
  19. }
  20. _, rawMsgXMLBytes, getAppIDBytes, err = AESDecryptMsg(encryptedMsgBytes, key)
  21. if err != nil {
  22. err = fmt.Errorf("消息解密失败,%v", err)
  23. return
  24. }
  25. if appID != string(getAppIDBytes) {
  26. err = fmt.Errorf("消息解密校验APPID失败")
  27. return
  28. }
  29. return
  30. }
  // aesKeyDecode converts the 43-character encoded AES key into raw key bytes.
//
// The encoded form is standard base64 with its single trailing '=' removed,
// so one '=' is appended before decoding. An error is returned when:
//   - the input is not exactly 43 characters (the key is then nil);
//   - base64 decoding fails;
//   - the decoded key is not 32 bytes long.
//
// In the last two cases, whatever DecodeString produced is still returned
// alongside the error.
func aesKeyDecode(encodedAESKey string) (key []byte, err error) {
	const (
		encodedLen = 43
		rawLen     = 32
	)
	if len(encodedAESKey) != encodedLen {
		return nil, errors.New("the length of encodedAESKey must be equal to 43")
	}
	decoded, decodeErr := base64.StdEncoding.DecodeString(encodedAESKey + "=")
	switch {
	case decodeErr != nil:
		return decoded, decodeErr
	case len(decoded) != rawLen:
		// For example, input ending in '=' turns into "==" padding and
		// decodes to 31 bytes.
		return decoded, errors.New("encodingAESKey invalid")
	default:
		return decoded, nil
	}
}
  // AESDecryptMsg decrypts ciphertext and splits the plaintext into its parts:
//
//	ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId]
//
// Inputs and format:
//   - aesKey is the AES key (16, 24 or 32 bytes). Its first 16 bytes double as
//     the CBC IV.
//   - The plaintext is PKCS#7-padded to a 32-byte boundary, so ciphertext must
//     be a non-empty multiple of 32 bytes.
//   - msg_len is a big-endian uint32 giving len(rawXMLMsg).
//   - The remaining bytes are appID, which must be non-empty.
//
// The returned slices share one buffer, but each one's capacity is clipped to
// its length, so appending to one cannot overwrite another. On error all three
// slices are nil. An invalid key size is reported as an error rather than
// causing a panic.
func AESDecryptMsg(ciphertext []byte, aesKey []byte) (random, rawXMLMsg, appID []byte, err error) {
	const (
		padBlockSize = 32            // PKCS#7 block size of this scheme (not the 16-byte AES block)
		randomLen    = 16            // leading random bytes
		headerLen    = randomLen + 4 // random + msg_len
	)
	if len(ciphertext) < padBlockSize {
		return nil, nil, nil, fmt.Errorf("the length of ciphertext too short: %d", len(ciphertext))
	}
	if len(ciphertext)%padBlockSize != 0 {
		return nil, nil, nil, fmt.Errorf("ciphertext is not a multiple of the block size, the length is %d", len(ciphertext))
	}

	// aes.NewCipher rejects keys that are not 16, 24 or 32 bytes. Surface that
	// to the caller instead of panicking.
	block, err := aes.NewCipher(aesKey)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("creating AES cipher: %w", err)
	}
	plaintext := make([]byte, len(ciphertext)) // len(plaintext) >= padBlockSize
	cipher.NewCBCDecrypter(block, aesKey[:aes.BlockSize]).CryptBlocks(plaintext, ciphertext)

	// Strip PKCS#7 padding. Only the final byte's value is validated; the
	// other padding bytes are not compared.
	pad := int(plaintext[len(plaintext)-1])
	if pad < 1 || pad > padBlockSize {
		return nil, nil, nil, fmt.Errorf("the amount to pad is incorrect: %d", pad)
	}
	plaintext = plaintext[:len(plaintext)-pad]

	// Split the fields. len(plaintext) == 16 + 4 + len(rawXMLMsg) + len(appID).
	if len(plaintext) <= headerLen {
		return nil, nil, nil, fmt.Errorf("plaintext too short, the length is %d", len(plaintext))
	}
	msgLen := binary.BigEndian.Uint32(plaintext[randomLen:headerLen])
	// Compare in uint64. Computing headerLen+msgLen as int could overflow on
	// 32-bit platforms and defeat this bounds check. At least one byte must
	// remain for appID.
	if uint64(msgLen) >= uint64(len(plaintext)-headerLen) {
		return nil, nil, nil, fmt.Errorf("msg length too large: %d", msgLen)
	}
	appIDOffset := headerLen + int(msgLen)
	end := len(plaintext)

	random = plaintext[:randomLen:randomLen]
	rawXMLMsg = plaintext[headerLen:appIDOffset:appIDOffset]
	appID = plaintext[appIDOffset:end:end]
	return random, rawXMLMsg, appID, nil
}
  // decodeNetworkByteOrder parses an unsigned 32-bit integer from the first four
// bytes of orderBytes, in network (big-endian) byte order. Bytes after the
// fourth are ignored. It panics if len(orderBytes) < 4.
func decodeNetworkByteOrder(orderBytes []byte) (n uint32) {
	return binary.BigEndian.Uint32(orderBytes)
}