aesecb.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. package codec
  2. import (
  3. "bytes"
  4. "crypto/aes"
  5. "crypto/cipher"
  6. "encoding/base64"
  7. "errors"
  8. "github.com/tal-tech/go-zero/core/logx"
  9. )
  var (
  	// ErrPaddingSize reports that decrypted data does not end in a valid
  	// PKCS#5 padding sequence for the cipher's block size.
  	ErrPaddingSize = errors.New("padding size error")
  )
  // ecb is the state shared by the ECB encrypter and decrypter: the wrapped
  // block cipher and the block size it reported at construction time.
  type ecb struct {
  	b         cipher.Block
  	blockSize int
  }

  // newECB wraps b and caches the value returned by b.BlockSize().
  func newECB(b cipher.Block) *ecb {
  	e := new(ecb)
  	e.b = b
  	e.blockSize = b.BlockSize()
  	return e
  }
  21. type ecbEncrypter ecb
  22. func NewECBEncrypter(b cipher.Block) cipher.BlockMode {
  23. return (*ecbEncrypter)(newECB(b))
  24. }
  25. func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
  26. // why we don't return error is because cipher.BlockMode doesn't allow this
  27. func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
  28. if len(src)%x.blockSize != 0 {
  29. logx.Error("crypto/cipher: input not full blocks")
  30. return
  31. }
  32. if len(dst) < len(src) {
  33. logx.Error("crypto/cipher: output smaller than input")
  34. return
  35. }
  36. for len(src) > 0 {
  37. x.b.Encrypt(dst, src[:x.blockSize])
  38. src = src[x.blockSize:]
  39. dst = dst[x.blockSize:]
  40. }
  41. }
  42. type ecbDecrypter ecb
  43. func NewECBDecrypter(b cipher.Block) cipher.BlockMode {
  44. return (*ecbDecrypter)(newECB(b))
  45. }
  46. func (x *ecbDecrypter) BlockSize() int {
  47. return x.blockSize
  48. }
  49. // why we don't return error is because cipher.BlockMode doesn't allow this
  50. func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
  51. if len(src)%x.blockSize != 0 {
  52. logx.Error("crypto/cipher: input not full blocks")
  53. return
  54. }
  55. if len(dst) < len(src) {
  56. logx.Error("crypto/cipher: output smaller than input")
  57. return
  58. }
  59. for len(src) > 0 {
  60. x.b.Decrypt(dst, src[:x.blockSize])
  61. src = src[x.blockSize:]
  62. dst = dst[x.blockSize:]
  63. }
  64. }
  65. func EcbDecrypt(key, src []byte) ([]byte, error) {
  66. block, err := aes.NewCipher(key)
  67. if err != nil {
  68. logx.Errorf("Decrypt key error: % x", key)
  69. return nil, err
  70. }
  71. decrypter := NewECBDecrypter(block)
  72. decrypted := make([]byte, len(src))
  73. decrypter.CryptBlocks(decrypted, src)
  74. return pkcs5Unpadding(decrypted, decrypter.BlockSize())
  75. }
  76. func EcbDecryptBase64(key, src string) (string, error) {
  77. keyBytes, err := getKeyBytes(key)
  78. if err != nil {
  79. return "", err
  80. }
  81. encryptedBytes, err := base64.StdEncoding.DecodeString(src)
  82. if err != nil {
  83. return "", err
  84. }
  85. decryptedBytes, err := EcbDecrypt(keyBytes, encryptedBytes)
  86. if err != nil {
  87. return "", err
  88. }
  89. return base64.StdEncoding.EncodeToString(decryptedBytes), nil
  90. }
  91. func EcbEncrypt(key, src []byte) ([]byte, error) {
  92. block, err := aes.NewCipher(key)
  93. if err != nil {
  94. logx.Errorf("Encrypt key error: % x", key)
  95. return nil, err
  96. }
  97. padded := pkcs5Padding(src, block.BlockSize())
  98. crypted := make([]byte, len(padded))
  99. encrypter := NewECBEncrypter(block)
  100. encrypter.CryptBlocks(crypted, padded)
  101. return crypted, nil
  102. }
  103. func EcbEncryptBase64(key, src string) (string, error) {
  104. keyBytes, err := getKeyBytes(key)
  105. if err != nil {
  106. return "", err
  107. }
  108. srcBytes, err := base64.StdEncoding.DecodeString(src)
  109. if err != nil {
  110. return "", err
  111. }
  112. encryptedBytes, err := EcbEncrypt(keyBytes, srcBytes)
  113. if err != nil {
  114. return "", err
  115. }
  116. return base64.StdEncoding.EncodeToString(encryptedBytes), nil
  117. }
  // getKeyBytes turns a textual key into raw key bytes. A key of at most 32
  // bytes (the largest AES key size) is used verbatim. A longer key is assumed
  // to be standard-base64-encoded and is decoded.
  func getKeyBytes(key string) ([]byte, error) {
  	if len(key) <= 32 {
  		return []byte(key), nil
  	}

  	keyBytes, err := base64.StdEncoding.DecodeString(key)
  	if err != nil {
  		// DecodeString may return partially decoded bytes along with the
  		// error; return nil so callers never see a half-decoded key.
  		return nil, err
  	}
  	return keyBytes, nil
  }
  // pkcs5Padding returns a copy of src extended with PKCS#5/PKCS#7 padding up
  // to the next multiple of blockSize. Between 1 and blockSize bytes are
  // always appended, each equal to the number of bytes added, so input that
  // is already block-aligned gains a full block of padding.
  //
  // The result never shares memory with src. Appending to src directly could
  // overwrite data the caller keeps in src's spare capacity.
  //
  // blockSize must be in 1..255 for the padding byte to be meaningful
  // (PKCS#7). In this file it is the AES block size.
  func pkcs5Padding(src []byte, blockSize int) []byte {
  	padding := blockSize - len(src)%blockSize
  	padded := make([]byte, len(src), len(src)+padding)
  	copy(padded, src)
  	return append(padded, bytes.Repeat([]byte{byte(padding)}, padding)...)
  }
  133. func pkcs5Unpadding(src []byte, blockSize int) ([]byte, error) {
  134. length := len(src)
  135. unpadding := int(src[length-1])
  136. if unpadding >= length || unpadding > blockSize {
  137. return nil, ErrPaddingSize
  138. }
  139. return src[:length-unpadding], nil
  140. }