rsa.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. package codec
  2. import (
  3. "crypto/rand"
  4. "crypto/rsa"
  5. "crypto/x509"
  6. "encoding/base64"
  7. "encoding/pem"
  8. "errors"
  9. "io/ioutil"
  10. )
  // Sentinel errors returned by the RSA constructors in this package. Callers
// can compare against them with errors.Is.
var (
	// ErrPrivateKey is returned when the private key file does not contain a
	// PEM block.
	ErrPrivateKey = errors.New("private key error")

	// ErrPublicKey is returned when the public key bytes do not contain a PEM
	// block.
	ErrPublicKey = errors.New("failed to parse PEM block containing the public key")

	// ErrNotRsaKey is returned when a parsed key holds a non-RSA key type.
	ErrNotRsaKey = errors.New("key type is not RSA")
)
  // RsaDecrypter decrypts data that was encrypted with the matching RSA key.
type RsaDecrypter interface {
	// Decrypt decrypts raw ciphertext bytes.
	Decrypt(input []byte) ([]byte, error)
	// DecryptBase64 base64-decodes input and decrypts the resulting bytes.
	DecryptBase64(input string) ([]byte, error)
}

// RsaEncrypter encrypts data with an RSA public key.
type RsaEncrypter interface {
	// Encrypt encrypts raw plaintext bytes.
	Encrypt(input []byte) ([]byte, error)
}

// rsaBase carries the chunk size shared by the encrypter and the decrypter:
// bytesLimit is the largest number of input bytes handed to one RSA
// operation.
type rsaBase struct {
	bytesLimit int
}

// rsaDecrypter is the RsaDecrypter implementation backed by a private key.
type rsaDecrypter struct {
	rsaBase
	privateKey *rsa.PrivateKey
}

// rsaEncrypter is the RsaEncrypter implementation backed by a public key.
type rsaEncrypter struct {
	rsaBase
	publicKey *rsa.PublicKey
}
  41. // NewRsaDecrypter returns a RsaDecrypter with the given file.
  42. func NewRsaDecrypter(file string) (RsaDecrypter, error) {
  43. content, err := ioutil.ReadFile(file)
  44. if err != nil {
  45. return nil, err
  46. }
  47. block, _ := pem.Decode(content)
  48. if block == nil {
  49. return nil, ErrPrivateKey
  50. }
  51. privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
  52. if err != nil {
  53. return nil, err
  54. }
  55. return &rsaDecrypter{
  56. rsaBase: rsaBase{
  57. bytesLimit: privateKey.N.BitLen() >> 3,
  58. },
  59. privateKey: privateKey,
  60. }, nil
  61. }
  62. func (r *rsaDecrypter) Decrypt(input []byte) ([]byte, error) {
  63. return r.crypt(input, func(block []byte) ([]byte, error) {
  64. return rsaDecryptBlock(r.privateKey, block)
  65. })
  66. }
  67. func (r *rsaDecrypter) DecryptBase64(input string) ([]byte, error) {
  68. if len(input) == 0 {
  69. return nil, nil
  70. }
  71. base64Decoded, err := base64.StdEncoding.DecodeString(input)
  72. if err != nil {
  73. return nil, err
  74. }
  75. return r.Decrypt(base64Decoded)
  76. }
  77. // NewRsaEncrypter returns a RsaEncrypter with the given key.
  78. func NewRsaEncrypter(key []byte) (RsaEncrypter, error) {
  79. block, _ := pem.Decode(key)
  80. if block == nil {
  81. return nil, ErrPublicKey
  82. }
  83. pub, err := x509.ParsePKIXPublicKey(block.Bytes)
  84. if err != nil {
  85. return nil, err
  86. }
  87. switch pubKey := pub.(type) {
  88. case *rsa.PublicKey:
  89. return &rsaEncrypter{
  90. rsaBase: rsaBase{
  91. // https://www.ietf.org/rfc/rfc2313.txt
  92. // The length of the data D shall not be more than k-11 octets, which is
  93. // positive since the length k of the modulus is at least 12 octets.
  94. bytesLimit: (pubKey.N.BitLen() >> 3) - 11,
  95. },
  96. publicKey: pubKey,
  97. }, nil
  98. default:
  99. return nil, ErrNotRsaKey
  100. }
  101. }
  102. func (r *rsaEncrypter) Encrypt(input []byte) ([]byte, error) {
  103. return r.crypt(input, func(block []byte) ([]byte, error) {
  104. return rsaEncryptBlock(r.publicKey, block)
  105. })
  106. }
  107. func (r *rsaBase) crypt(input []byte, cryptFn func([]byte) ([]byte, error)) ([]byte, error) {
  108. var result []byte
  109. inputLen := len(input)
  110. for i := 0; i*r.bytesLimit < inputLen; i++ {
  111. start := r.bytesLimit * i
  112. var stop int
  113. if r.bytesLimit*(i+1) > inputLen {
  114. stop = inputLen
  115. } else {
  116. stop = r.bytesLimit * (i + 1)
  117. }
  118. bs, err := cryptFn(input[start:stop])
  119. if err != nil {
  120. return nil, err
  121. }
  122. result = append(result, bs...)
  123. }
  124. return result, nil
  125. }
  // rsaDecryptBlock decrypts one PKCS #1 v1.5 ciphertext block with
// privateKey, passing crypto/rand as the randomness source.
//
// NOTE(review): PKCS #1 v1.5 decryption failures can act as a padding oracle
// (Bleichenbacher) if they are surfaced to untrusted callers; consider OAEP
// for new data formats.
func rsaDecryptBlock(privateKey *rsa.PrivateKey, block []byte) ([]byte, error) {
	plaintext, err := rsa.DecryptPKCS1v15(rand.Reader, privateKey, block)
	return plaintext, err
}
  // rsaEncryptBlock encrypts msg as a single PKCS #1 v1.5 block under
// publicKey, drawing padding randomness from crypto/rand. Any error from
// rsa.EncryptPKCS1v15 (for example a message longer than the key permits)
// is returned unchanged.
func rsaEncryptBlock(publicKey *rsa.PublicKey, msg []byte) ([]byte, error) {
	ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, publicKey, msg)
	return ciphertext, err
}