123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
- package codec
- import (
- "bytes"
- "crypto/aes"
- "crypto/cipher"
- "encoding/base64"
- "errors"
- "github.com/tal-tech/go-zero/core/logx"
- )
var (
	// ErrPaddingSize indicates bad padding size: the data handed to the
	// unpadding step does not end in an acceptable padding length.
	ErrPaddingSize = errors.New("padding size error")
)
// ecb is the state shared by the ECB encrypter and decrypter: the wrapped
// block cipher plus its block size, recorded once at construction.
type ecb struct {
	b         cipher.Block
	blockSize int
}

// newECB wraps b and records b.BlockSize() in the returned value.
func newECB(b cipher.Block) *ecb {
	e := new(ecb)
	e.b = b
	e.blockSize = b.BlockSize()
	return e
}
- type ecbEncrypter ecb
- // NewECBEncrypter returns an ECB encrypter.
- func NewECBEncrypter(b cipher.Block) cipher.BlockMode {
- return (*ecbEncrypter)(newECB(b))
- }
- func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
- // why we don't return error is because cipher.BlockMode doesn't allow this
- func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
- if len(src)%x.blockSize != 0 {
- logx.Error("crypto/cipher: input not full blocks")
- return
- }
- if len(dst) < len(src) {
- logx.Error("crypto/cipher: output smaller than input")
- return
- }
- for len(src) > 0 {
- x.b.Encrypt(dst, src[:x.blockSize])
- src = src[x.blockSize:]
- dst = dst[x.blockSize:]
- }
- }
- type ecbDecrypter ecb
- // NewECBDecrypter returns an ECB decrypter.
- func NewECBDecrypter(b cipher.Block) cipher.BlockMode {
- return (*ecbDecrypter)(newECB(b))
- }
- func (x *ecbDecrypter) BlockSize() int {
- return x.blockSize
- }
- // why we don't return error is because cipher.BlockMode doesn't allow this
- func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
- if len(src)%x.blockSize != 0 {
- logx.Error("crypto/cipher: input not full blocks")
- return
- }
- if len(dst) < len(src) {
- logx.Error("crypto/cipher: output smaller than input")
- return
- }
- for len(src) > 0 {
- x.b.Decrypt(dst, src[:x.blockSize])
- src = src[x.blockSize:]
- dst = dst[x.blockSize:]
- }
- }
- // EcbDecrypt decrypts src with the given key.
- func EcbDecrypt(key, src []byte) ([]byte, error) {
- block, err := aes.NewCipher(key)
- if err != nil {
- logx.Errorf("Decrypt key error: % x", key)
- return nil, err
- }
- decrypter := NewECBDecrypter(block)
- decrypted := make([]byte, len(src))
- decrypter.CryptBlocks(decrypted, src)
- return pkcs5Unpadding(decrypted, decrypter.BlockSize())
- }
- // EcbDecryptBase64 decrypts base64 encoded src with the given base64 encoded key.
- // The returned string is also base64 encoded.
- func EcbDecryptBase64(key, src string) (string, error) {
- keyBytes, err := getKeyBytes(key)
- if err != nil {
- return "", err
- }
- encryptedBytes, err := base64.StdEncoding.DecodeString(src)
- if err != nil {
- return "", err
- }
- decryptedBytes, err := EcbDecrypt(keyBytes, encryptedBytes)
- if err != nil {
- return "", err
- }
- return base64.StdEncoding.EncodeToString(decryptedBytes), nil
- }
- // EcbEncrypt encrypts src with the given key.
- func EcbEncrypt(key, src []byte) ([]byte, error) {
- block, err := aes.NewCipher(key)
- if err != nil {
- logx.Errorf("Encrypt key error: % x", key)
- return nil, err
- }
- padded := pkcs5Padding(src, block.BlockSize())
- crypted := make([]byte, len(padded))
- encrypter := NewECBEncrypter(block)
- encrypter.CryptBlocks(crypted, padded)
- return crypted, nil
- }
- // EcbEncryptBase64 encrypts base64 encoded src with the given base64 encoded key.
- // The returned string is also base64 encoded.
- func EcbEncryptBase64(key, src string) (string, error) {
- keyBytes, err := getKeyBytes(key)
- if err != nil {
- return "", err
- }
- srcBytes, err := base64.StdEncoding.DecodeString(src)
- if err != nil {
- return "", err
- }
- encryptedBytes, err := EcbEncrypt(keyBytes, srcBytes)
- if err != nil {
- return "", err
- }
- return base64.StdEncoding.EncodeToString(encryptedBytes), nil
- }
// getKeyBytes turns key into raw key bytes. Keys of at most 32 characters are
// used verbatim; longer keys are treated as standard base64 and decoded. On a
// decoding error it returns a nil slice and the error.
//
// NOTE(review): a base64-encoded 16- or 24-byte key is only 24 or 32
// characters long, so it is NOT decoded here and is used as raw bytes.
// Callers with such keys must pass the raw key instead.
func getKeyBytes(key string) ([]byte, error) {
	if len(key) > 32 {
		decoded, err := base64.StdEncoding.DecodeString(key)
		if err != nil {
			return nil, err
		}
		return decoded, nil
	}

	return []byte(key), nil
}
// pkcs5Padding returns src followed by PKCS#5/PKCS#7 padding up to the next
// multiple of blockSize: n copies of the byte n, with 1 <= n <= blockSize.
// A whole extra block is added when src is already aligned.
//
// The result is always a freshly allocated slice, so spare capacity in the
// caller's src is never written to.
func pkcs5Padding(src []byte, blockSize int) []byte {
	padding := blockSize - len(src)%blockSize

	// Copy into a new buffer instead of appending to src directly: append
	// would write the padding into src's backing array whenever it has spare
	// capacity, clobbering memory the caller still owns.
	padded := make([]byte, len(src), len(src)+padding)
	copy(padded, src)

	return append(padded, bytes.Repeat([]byte{byte(padding)}, padding)...)
}
- func pkcs5Unpadding(src []byte, blockSize int) ([]byte, error) {
- length := len(src)
- unpadding := int(src[length-1])
- if unpadding >= length || unpadding > blockSize {
- return nil, ErrPaddingSize
- }
- return src[:length-unpadding], nil
- }
|