Merge pull request #21 from Snider/refactor-rsa-improvements

Refactor(crypt): Improve RSA safety and flexibility
This commit is contained in:
Snider 2025-11-02 16:00:29 +00:00 committed by GitHub
commit 56f28c1ea5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 31 additions and 14 deletions

View file

@ -137,17 +137,27 @@ func (s *Service) Fletcher64(payload string) uint64 {
// --- RSA ---
// ensureRSA initializes the RSA service if it is not already.
func (s *Service) ensureRSA() {
if s.rsa == nil {
s.rsa = rsa.NewService()
}
}
// GenerateRSAKeyPair creates a new RSA key pair.
func (s *Service) GenerateRSAKeyPair(bits int) (publicKey, privateKey []byte, err error) {
s.ensureRSA()
return s.rsa.GenerateKeyPair(bits)
}
// EncryptRSA encrypts data with a public key.
func (s *Service) EncryptRSA(publicKey, data []byte) ([]byte, error) {
return s.rsa.Encrypt(publicKey, data)
func (s *Service) EncryptRSA(publicKey, data, label []byte) ([]byte, error) {
s.ensureRSA()
return s.rsa.Encrypt(publicKey, data, label)
}
// DecryptRSA decrypts data with a private key.
func (s *Service) DecryptRSA(privateKey, ciphertext []byte) ([]byte, error) {
return s.rsa.Decrypt(privateKey, ciphertext)
func (s *Service) DecryptRSA(privateKey, ciphertext, label []byte) ([]byte, error) {
s.ensureRSA()
return s.rsa.Decrypt(privateKey, ciphertext, label)
}

View file

@ -19,6 +19,9 @@ func NewService() *Service {
// GenerateKeyPair creates a new RSA key pair.
func (s *Service) GenerateKeyPair(bits int) (publicKey, privateKey []byte, err error) {
if bits < 2048 {
return nil, nil, fmt.Errorf("rsa: key size too small: %d (minimum 2048)", bits)
}
privKey, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
@ -43,7 +46,7 @@ func (s *Service) GenerateKeyPair(bits int) (publicKey, privateKey []byte, err e
}
// Encrypt encrypts data with a public key.
func (s *Service) Encrypt(publicKey, data []byte) ([]byte, error) {
func (s *Service) Encrypt(publicKey, data, label []byte) ([]byte, error) {
block, _ := pem.Decode(publicKey)
if block == nil {
return nil, fmt.Errorf("failed to decode public key")
@ -59,7 +62,7 @@ func (s *Service) Encrypt(publicKey, data []byte) ([]byte, error) {
return nil, fmt.Errorf("not an RSA public key")
}
ciphertext, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, rsaPub, data, nil)
ciphertext, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, rsaPub, data, label)
if err != nil {
return nil, fmt.Errorf("failed to encrypt data: %w", err)
}
@ -68,7 +71,7 @@ func (s *Service) Encrypt(publicKey, data []byte) ([]byte, error) {
}
// Decrypt decrypts data with a private key.
func (s *Service) Decrypt(privateKey, ciphertext []byte) ([]byte, error) {
func (s *Service) Decrypt(privateKey, ciphertext, label []byte) ([]byte, error) {
block, _ := pem.Decode(privateKey)
if block == nil {
return nil, fmt.Errorf("failed to decode private key")
@ -79,7 +82,7 @@ func (s *Service) Decrypt(privateKey, ciphertext []byte) ([]byte, error) {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
plaintext, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, priv, ciphertext, nil)
plaintext, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, priv, ciphertext, label)
if err != nil {
return nil, fmt.Errorf("failed to decrypt data: %w", err)
}

View file

@ -17,9 +17,9 @@ func TestRSA_Good(t *testing.T) {
// Encrypt and decrypt a message
message := []byte("Hello, World!")
ciphertext, err := s.Encrypt(pubKey, message)
ciphertext, err := s.Encrypt(pubKey, message, nil)
assert.NoError(t, err)
plaintext, err := s.Decrypt(privKey, ciphertext)
plaintext, err := s.Decrypt(privKey, ciphertext, nil)
assert.NoError(t, err)
assert.Equal(t, message, plaintext)
}
@ -33,9 +33,13 @@ func TestRSA_Bad(t *testing.T) {
_, otherPrivKey, err := s.GenerateKeyPair(2048)
assert.NoError(t, err)
message := []byte("Hello, World!")
ciphertext, err := s.Encrypt(pubKey, message)
ciphertext, err := s.Encrypt(pubKey, message, nil)
assert.NoError(t, err)
_, err = s.Decrypt(otherPrivKey, ciphertext)
_, err = s.Decrypt(otherPrivKey, ciphertext, nil)
assert.Error(t, err)
// Key size too small
_, _, err = s.GenerateKeyPair(512)
assert.Error(t, err)
}
@ -43,8 +47,8 @@ func TestRSA_Ugly(t *testing.T) {
s := NewService()
// Malformed keys and messages
_, err := s.Encrypt([]byte("not-a-key"), []byte("message"))
_, err := s.Encrypt([]byte("not-a-key"), []byte("message"), nil)
assert.Error(t, err)
_, err = s.Decrypt([]byte("not-a-key"), []byte("message"))
_, err = s.Decrypt([]byte("not-a-key"), []byte("message"), nil)
assert.Error(t, err)
}