From e7aeb3c8b8f9088ac586b3c004185b51c9f27d94 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 2 Nov 2025 03:06:04 +0000 Subject: [PATCH] Refactor(crypt): Improve RSA safety and flexibility This commit introduces several improvements to the RSA implementation: - Preserves zero-value service safety by lazily initializing the RSA service in `pkg/crypt/crypt.go`. - Enforces a minimum RSA key size of 2048 bits in `pkg/crypt/std/rsa/rsa.go` to prevent the generation of insecure keys. - Exposes the OAEP label parameter in `Encrypt` and `Decrypt` functions, allowing for more advanced use cases. - Adds a test case to verify that `GenerateKeyPair` correctly rejects key sizes below the new minimum. --- pkg/crypt/crypt.go | 18 ++++++++++++++---- pkg/crypt/std/rsa/rsa.go | 11 +++++++---- pkg/crypt/std/rsa/rsa_test.go | 16 ++++++++++------ 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/pkg/crypt/crypt.go b/pkg/crypt/crypt.go index 8079f4e..7bb6d09 100644 --- a/pkg/crypt/crypt.go +++ b/pkg/crypt/crypt.go @@ -137,17 +137,27 @@ func (s *Service) Fletcher64(payload string) uint64 { // --- RSA --- +// ensureRSA initializes the RSA service if it is not already. +func (s *Service) ensureRSA() { + if s.rsa == nil { + s.rsa = rsa.NewService() + } +} + // GenerateRSAKeyPair creates a new RSA key pair. func (s *Service) GenerateRSAKeyPair(bits int) (publicKey, privateKey []byte, err error) { + s.ensureRSA() return s.rsa.GenerateKeyPair(bits) } // EncryptRSA encrypts data with a public key. -func (s *Service) EncryptRSA(publicKey, data []byte) ([]byte, error) { - return s.rsa.Encrypt(publicKey, data) +func (s *Service) EncryptRSA(publicKey, data, label []byte) ([]byte, error) { + s.ensureRSA() + return s.rsa.Encrypt(publicKey, data, label) } // DecryptRSA decrypts data with a private key. -func (s *Service) DecryptRSA(privateKey, ciphertext []byte) ([]byte, error) { - return s.rsa.Decrypt(privateKey, ciphertext) +func (s *Service) DecryptRSA(privateKey, ciphertext, label []byte) ([]byte, error) { + s.ensureRSA() + return s.rsa.Decrypt(privateKey, ciphertext, label) } \ No newline at end of file diff --git a/pkg/crypt/std/rsa/rsa.go b/pkg/crypt/std/rsa/rsa.go index 5c70152..5a19d3a 100644 --- a/pkg/crypt/std/rsa/rsa.go +++ b/pkg/crypt/std/rsa/rsa.go @@ -19,6 +19,9 @@ func NewService() *Service { // GenerateKeyPair creates a new RSA key pair. func (s *Service) GenerateKeyPair(bits int) (publicKey, privateKey []byte, err error) { + if bits < 2048 { + return nil, nil, fmt.Errorf("rsa: key size too small: %d (minimum 2048)", bits) + } privKey, err := rsa.GenerateKey(rand.Reader, bits) if err != nil { return nil, nil, fmt.Errorf("failed to generate private key: %w", err) @@ -43,7 +46,7 @@ func (s *Service) GenerateKeyPair(bits int) (publicKey, privateKey []byte, err e } // Encrypt encrypts data with a public key. -func (s *Service) Encrypt(publicKey, data []byte) ([]byte, error) { +func (s *Service) Encrypt(publicKey, data, label []byte) ([]byte, error) { block, _ := pem.Decode(publicKey) if block == nil { return nil, fmt.Errorf("failed to decode public key") @@ -59,7 +62,7 @@ func (s *Service) Encrypt(publicKey, data []byte) ([]byte, error) { return nil, fmt.Errorf("not an RSA public key") } - ciphertext, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, rsaPub, data, nil) + ciphertext, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, rsaPub, data, label) if err != nil { return nil, fmt.Errorf("failed to encrypt data: %w", err) } @@ -68,7 +71,7 @@ func (s *Service) Encrypt(publicKey, data []byte) ([]byte, error) { } // Decrypt decrypts data with a private key. -func (s *Service) Decrypt(privateKey, ciphertext []byte) ([]byte, error) { +func (s *Service) Decrypt(privateKey, ciphertext, label []byte) ([]byte, error) { block, _ := pem.Decode(privateKey) if block == nil { return nil, fmt.Errorf("failed to decode private key") @@ -79,7 +82,7 @@ func (s *Service) Decrypt(privateKey, ciphertext []byte) ([]byte, error) { return nil, fmt.Errorf("failed to parse private key: %w", err) } - plaintext, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, priv, ciphertext, nil) + plaintext, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, priv, ciphertext, label) if err != nil { return nil, fmt.Errorf("failed to decrypt data: %w", err) } diff --git a/pkg/crypt/std/rsa/rsa_test.go b/pkg/crypt/std/rsa/rsa_test.go index cd08d63..3515df4 100644 --- a/pkg/crypt/std/rsa/rsa_test.go +++ b/pkg/crypt/std/rsa/rsa_test.go @@ -17,9 +17,9 @@ func TestRSA_Good(t *testing.T) { // Encrypt and decrypt a message message := []byte("Hello, World!") - ciphertext, err := s.Encrypt(pubKey, message) + ciphertext, err := s.Encrypt(pubKey, message, nil) assert.NoError(t, err) - plaintext, err := s.Decrypt(privKey, ciphertext) + plaintext, err := s.Decrypt(privKey, ciphertext, nil) assert.NoError(t, err) assert.Equal(t, message, plaintext) } @@ -33,9 +33,13 @@ func TestRSA_Bad(t *testing.T) { _, otherPrivKey, err := s.GenerateKeyPair(2048) assert.NoError(t, err) message := []byte("Hello, World!") - ciphertext, err := s.Encrypt(pubKey, message) + ciphertext, err := s.Encrypt(pubKey, message, nil) assert.NoError(t, err) - _, err = s.Decrypt(otherPrivKey, ciphertext) + _, err = s.Decrypt(otherPrivKey, ciphertext, nil) + assert.Error(t, err) + + // Key size too small + _, _, err = s.GenerateKeyPair(512) assert.Error(t, err) } @@ -43,8 +47,8 @@ func TestRSA_Ugly(t *testing.T) { s := NewService() // Malformed keys and messages - _, err := s.Encrypt([]byte("not-a-key"), []byte("message")) + _, err := s.Encrypt([]byte("not-a-key"), []byte("message"), nil) assert.Error(t, err) - _, err = s.Decrypt([]byte("not-a-key"), []byte("message")) + _, err = s.Decrypt([]byte("not-a-key"), []byte("message"), nil) assert.Error(t, err) }