rsa.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package codec
  2. import (
  3. "crypto/rand"
  4. "crypto/rsa"
  5. "crypto/x509"
  6. "encoding/base64"
  7. "encoding/pem"
  8. "errors"
  9. "io/ioutil"
  10. )
  11. var (
  12. ErrPrivateKey = errors.New("private key error")
  13. ErrPublicKey = errors.New("failed to parse PEM block containing the public key")
  14. ErrNotRsaKey = errors.New("key type is not RSA")
  15. )
  16. type (
  17. RsaDecrypter interface {
  18. Decrypt(input []byte) ([]byte, error)
  19. DecryptBase64(input string) ([]byte, error)
  20. }
  21. RsaEncrypter interface {
  22. Encrypt(input []byte) ([]byte, error)
  23. }
  24. rsaBase struct {
  25. bytesLimit int
  26. }
  27. rsaDecrypter struct {
  28. rsaBase
  29. privateKey *rsa.PrivateKey
  30. }
  31. rsaEncrypter struct {
  32. rsaBase
  33. publicKey *rsa.PublicKey
  34. }
  35. )
  36. func NewRsaDecrypter(file string) (RsaDecrypter, error) {
  37. content, err := ioutil.ReadFile(file)
  38. if err != nil {
  39. return nil, err
  40. }
  41. block, _ := pem.Decode(content)
  42. if block == nil {
  43. return nil, ErrPrivateKey
  44. }
  45. privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
  46. if err != nil {
  47. return nil, err
  48. }
  49. return &rsaDecrypter{
  50. rsaBase: rsaBase{
  51. bytesLimit: privateKey.N.BitLen() >> 3,
  52. },
  53. privateKey: privateKey,
  54. }, nil
  55. }
  56. func (r *rsaDecrypter) Decrypt(input []byte) ([]byte, error) {
  57. return r.crypt(input, func(block []byte) ([]byte, error) {
  58. return rsaDecryptBlock(r.privateKey, block)
  59. })
  60. }
  61. func (r *rsaDecrypter) DecryptBase64(input string) ([]byte, error) {
  62. if len(input) == 0 {
  63. return nil, nil
  64. }
  65. base64Decoded, err := base64.StdEncoding.DecodeString(input)
  66. if err != nil {
  67. return nil, err
  68. }
  69. return r.Decrypt(base64Decoded)
  70. }
  71. func NewRsaEncrypter(key []byte) (RsaEncrypter, error) {
  72. block, _ := pem.Decode(key)
  73. if block == nil {
  74. return nil, ErrPublicKey
  75. }
  76. pub, err := x509.ParsePKIXPublicKey(block.Bytes)
  77. if err != nil {
  78. return nil, err
  79. }
  80. switch pubKey := pub.(type) {
  81. case *rsa.PublicKey:
  82. return &rsaEncrypter{
  83. rsaBase: rsaBase{
  84. // https://www.ietf.org/rfc/rfc2313.txt
  85. // The length of the data D shall not be more than k-11 octets, which is
  86. // positive since the length k of the modulus is at least 12 octets.
  87. bytesLimit: (pubKey.N.BitLen() >> 3) - 11,
  88. },
  89. publicKey: pubKey,
  90. }, nil
  91. default:
  92. return nil, ErrNotRsaKey
  93. }
  94. }
  95. func (r *rsaEncrypter) Encrypt(input []byte) ([]byte, error) {
  96. return r.crypt(input, func(block []byte) ([]byte, error) {
  97. return rsaEncryptBlock(r.publicKey, block)
  98. })
  99. }
  100. func (r *rsaBase) crypt(input []byte, cryptFn func([]byte) ([]byte, error)) ([]byte, error) {
  101. var result []byte
  102. inputLen := len(input)
  103. for i := 0; i*r.bytesLimit < inputLen; i++ {
  104. start := r.bytesLimit * i
  105. var stop int
  106. if r.bytesLimit*(i+1) > inputLen {
  107. stop = inputLen
  108. } else {
  109. stop = r.bytesLimit * (i + 1)
  110. }
  111. bs, err := cryptFn(input[start:stop])
  112. if err != nil {
  113. return nil, err
  114. }
  115. result = append(result, bs...)
  116. }
  117. return result, nil
  118. }
  119. func rsaDecryptBlock(privateKey *rsa.PrivateKey, block []byte) ([]byte, error) {
  120. return rsa.DecryptPKCS1v15(rand.Reader, privateKey, block)
  121. }
  122. func rsaEncryptBlock(publicKey *rsa.PublicKey, msg []byte) ([]byte, error) {
  123. return rsa.EncryptPKCS1v15(rand.Reader, publicKey, msg)
  124. }