Merge pull request #21 from Snider/refactor-rsa-improvements
Refactor(crypt): Improve RSA safety and flexibility
This commit is contained in:
commit
56f28c1ea5
3 changed files with 31 additions and 14 deletions
|
|
@ -137,17 +137,27 @@ func (s *Service) Fletcher64(payload string) uint64 {
|
|||
|
||||
// --- RSA ---
|
||||
|
||||
// ensureRSA initializes the RSA service if it is not already.
|
||||
func (s *Service) ensureRSA() {
|
||||
if s.rsa == nil {
|
||||
s.rsa = rsa.NewService()
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRSAKeyPair creates a new RSA key pair.
|
||||
func (s *Service) GenerateRSAKeyPair(bits int) (publicKey, privateKey []byte, err error) {
|
||||
s.ensureRSA()
|
||||
return s.rsa.GenerateKeyPair(bits)
|
||||
}
|
||||
|
||||
// EncryptRSA encrypts data with a public key.
|
||||
func (s *Service) EncryptRSA(publicKey, data []byte) ([]byte, error) {
|
||||
return s.rsa.Encrypt(publicKey, data)
|
||||
func (s *Service) EncryptRSA(publicKey, data, label []byte) ([]byte, error) {
|
||||
s.ensureRSA()
|
||||
return s.rsa.Encrypt(publicKey, data, label)
|
||||
}
|
||||
|
||||
// DecryptRSA decrypts data with a private key.
|
||||
func (s *Service) DecryptRSA(privateKey, ciphertext []byte) ([]byte, error) {
|
||||
return s.rsa.Decrypt(privateKey, ciphertext)
|
||||
func (s *Service) DecryptRSA(privateKey, ciphertext, label []byte) ([]byte, error) {
|
||||
s.ensureRSA()
|
||||
return s.rsa.Decrypt(privateKey, ciphertext, label)
|
||||
}
|
||||
|
|
@ -19,6 +19,9 @@ func NewService() *Service {
|
|||
|
||||
// GenerateKeyPair creates a new RSA key pair.
|
||||
func (s *Service) GenerateKeyPair(bits int) (publicKey, privateKey []byte, err error) {
|
||||
if bits < 2048 {
|
||||
return nil, nil, fmt.Errorf("rsa: key size too small: %d (minimum 2048)", bits)
|
||||
}
|
||||
privKey, err := rsa.GenerateKey(rand.Reader, bits)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||
|
|
@ -43,7 +46,7 @@ func (s *Service) GenerateKeyPair(bits int) (publicKey, privateKey []byte, err e
|
|||
}
|
||||
|
||||
// Encrypt encrypts data with a public key.
|
||||
func (s *Service) Encrypt(publicKey, data []byte) ([]byte, error) {
|
||||
func (s *Service) Encrypt(publicKey, data, label []byte) ([]byte, error) {
|
||||
block, _ := pem.Decode(publicKey)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("failed to decode public key")
|
||||
|
|
@ -59,7 +62,7 @@ func (s *Service) Encrypt(publicKey, data []byte) ([]byte, error) {
|
|||
return nil, fmt.Errorf("not an RSA public key")
|
||||
}
|
||||
|
||||
ciphertext, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, rsaPub, data, nil)
|
||||
ciphertext, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, rsaPub, data, label)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt data: %w", err)
|
||||
}
|
||||
|
|
@ -68,7 +71,7 @@ func (s *Service) Encrypt(publicKey, data []byte) ([]byte, error) {
|
|||
}
|
||||
|
||||
// Decrypt decrypts data with a private key.
|
||||
func (s *Service) Decrypt(privateKey, ciphertext []byte) ([]byte, error) {
|
||||
func (s *Service) Decrypt(privateKey, ciphertext, label []byte) ([]byte, error) {
|
||||
block, _ := pem.Decode(privateKey)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("failed to decode private key")
|
||||
|
|
@ -79,7 +82,7 @@ func (s *Service) Decrypt(privateKey, ciphertext []byte) ([]byte, error) {
|
|||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, priv, ciphertext, nil)
|
||||
plaintext, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, priv, ciphertext, label)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt data: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,9 +17,9 @@ func TestRSA_Good(t *testing.T) {
|
|||
|
||||
// Encrypt and decrypt a message
|
||||
message := []byte("Hello, World!")
|
||||
ciphertext, err := s.Encrypt(pubKey, message)
|
||||
ciphertext, err := s.Encrypt(pubKey, message, nil)
|
||||
assert.NoError(t, err)
|
||||
plaintext, err := s.Decrypt(privKey, ciphertext)
|
||||
plaintext, err := s.Decrypt(privKey, ciphertext, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, message, plaintext)
|
||||
}
|
||||
|
|
@ -33,9 +33,13 @@ func TestRSA_Bad(t *testing.T) {
|
|||
_, otherPrivKey, err := s.GenerateKeyPair(2048)
|
||||
assert.NoError(t, err)
|
||||
message := []byte("Hello, World!")
|
||||
ciphertext, err := s.Encrypt(pubKey, message)
|
||||
ciphertext, err := s.Encrypt(pubKey, message, nil)
|
||||
assert.NoError(t, err)
|
||||
_, err = s.Decrypt(otherPrivKey, ciphertext)
|
||||
_, err = s.Decrypt(otherPrivKey, ciphertext, nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Key size too small
|
||||
_, _, err = s.GenerateKeyPair(512)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
|
|
@ -43,8 +47,8 @@ func TestRSA_Ugly(t *testing.T) {
|
|||
s := NewService()
|
||||
|
||||
// Malformed keys and messages
|
||||
_, err := s.Encrypt([]byte("not-a-key"), []byte("message"))
|
||||
_, err := s.Encrypt([]byte("not-a-key"), []byte("message"), nil)
|
||||
assert.Error(t, err)
|
||||
_, err = s.Decrypt([]byte("not-a-key"), []byte("message"))
|
||||
_, err = s.Decrypt([]byte("not-a-key"), []byte("message"), nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue