From b4ef069ee67e56e89f49cf80beabe1418b597c11 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 2 Nov 2025 18:46:36 +0000 Subject: [PATCH] fix: Correct test logic and revert breaking API changes This commit addresses feedback from the code review: - Updates the `TestChecksum_Bad` test in `tdd/trix/trix_test.go` to use `assert.ErrorIs` for consistent error handling. - Reverts the breaking API change to `EncryptRSA` and `DecryptRSA` in `pkg/crypt/crypt.go` by re-introducing the `label` parameter to the public-facing functions. - Updates the tests and examples to match the reverted API. - Fixes a build error in `tdd/crypt/crypt_test.go` by re-introducing a necessary variable. --- examples/main.go | 4 ++-- pkg/crypt/crypt.go | 8 ++++---- tdd/crypt/crypt_test.go | 27 ++++++++++++++++++--------- tdd/trix/trix_test.go | 3 +-- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/examples/main.go b/examples/main.go index 2fa3c14..3e4c082 100644 --- a/examples/main.go +++ b/examples/main.go @@ -182,14 +182,14 @@ func demoRSA() { // 2. Encrypt a message message := []byte("This is a secret message for RSA.") fmt.Printf("\nOriginal message: %s\n", message) - ciphertext, err := cryptService.EncryptRSA(publicKey, message) + ciphertext, err := cryptService.EncryptRSA(publicKey, message, nil) if err != nil { log.Fatalf("Failed to encrypt with RSA: %v", err) } fmt.Printf("Encrypted ciphertext (base64): %s\n", base64.StdEncoding.EncodeToString(ciphertext)) // 3. Decrypt the message - decrypted, err := cryptService.DecryptRSA(privateKey, ciphertext) + decrypted, err := cryptService.DecryptRSA(privateKey, ciphertext, nil) if err != nil { log.Fatalf("Failed to decrypt with RSA: %v", err) } diff --git a/pkg/crypt/crypt.go b/pkg/crypt/crypt.go index 4879c60..7bb6d09 100644 --- a/pkg/crypt/crypt.go +++ b/pkg/crypt/crypt.go @@ -151,13 +151,13 @@ func (s *Service) GenerateRSAKeyPair(bits int) (publicKey, privateKey []byte, er } // EncryptRSA encrypts data with a public key. -func (s *Service) EncryptRSA(publicKey, data []byte) ([]byte, error) { +func (s *Service) EncryptRSA(publicKey, data, label []byte) ([]byte, error) { s.ensureRSA() - return s.rsa.Encrypt(publicKey, data, nil) + return s.rsa.Encrypt(publicKey, data, label) } // DecryptRSA decrypts data with a private key. -func (s *Service) DecryptRSA(privateKey, ciphertext []byte) ([]byte, error) { +func (s *Service) DecryptRSA(privateKey, ciphertext, label []byte) ([]byte, error) { s.ensureRSA() - return s.rsa.Decrypt(privateKey, ciphertext, nil) + return s.rsa.Decrypt(privateKey, ciphertext, label) } \ No newline at end of file diff --git a/tdd/crypt/crypt_test.go b/tdd/crypt/crypt_test.go index 14595eb..a8e2bff 100644 --- a/tdd/crypt/crypt_test.go +++ b/tdd/crypt/crypt_test.go @@ -116,9 +116,10 @@ func TestRSA_Good(t *testing.T) { // Test encryption and decryption message := []byte("secret message") - ciphertext, err := service.EncryptRSA(pubKey, message) + label := []byte("test label") + ciphertext, err := service.EncryptRSA(pubKey, message, label) assert.NoError(t, err) - plaintext, err := service.DecryptRSA(privKey, ciphertext) + plaintext, err := service.DecryptRSA(privKey, ciphertext, label) assert.NoError(t, err) assert.Equal(t, message, plaintext) } @@ -129,32 +130,40 @@ func TestRSA_Bad(t *testing.T) { assert.Error(t, err) // Test decryption with the wrong key - pubKey, _, err := service.GenerateRSAKeyPair(2048) + pubKey, privKey, err := service.GenerateRSAKeyPair(2048) assert.NoError(t, err) _, otherPrivKey, err := service.GenerateRSAKeyPair(2048) assert.NoError(t, err) message := []byte("secret message") - ciphertext, err := service.EncryptRSA(pubKey, message) + ciphertext, err := service.EncryptRSA(pubKey, message, nil) assert.NoError(t, err) - _, err = service.DecryptRSA(otherPrivKey, ciphertext) + _, err = service.DecryptRSA(otherPrivKey, ciphertext, nil) + assert.Error(t, err) + + // Test decryption with wrong label + label1 := []byte("label1") + label2 := []byte("label2") + ciphertext, err = service.EncryptRSA(pubKey, message, label1) + assert.NoError(t, err) + _, err = service.DecryptRSA(privKey, ciphertext, label2) assert.Error(t, err) } func TestRSA_Ugly(t *testing.T) { // Test with malformed keys - _, err := service.EncryptRSA([]byte("not a real key"), []byte("message")) + _, err := service.EncryptRSA([]byte("not a real key"), []byte("message"), nil) assert.Error(t, err) - _, err = service.DecryptRSA([]byte("not a real key"), []byte("message")) + _, err = service.DecryptRSA([]byte("not a real key"), []byte("message"), nil) assert.Error(t, err) // Test with empty message pubKey, privKey, err := service.GenerateRSAKeyPair(2048) assert.NoError(t, err) message := []byte("") - ciphertext, err := service.EncryptRSA(pubKey, message) + ciphertext, err := service.EncryptRSA(pubKey, message, nil) assert.NoError(t, err) - plaintext, err := service.DecryptRSA(privKey, ciphertext) + plaintext, err := service.DecryptRSA(privKey, ciphertext, nil) assert.NoError(t, err) assert.Equal(t, message, plaintext) } diff --git a/tdd/trix/trix_test.go b/tdd/trix/trix_test.go index 9733891..d0a9cec 100644 --- a/tdd/trix/trix_test.go +++ b/tdd/trix/trix_test.go @@ -201,8 +201,7 @@ func TestChecksum_Bad(t *testing.T) { encoded[len(encoded)-1] = 0 // Tamper with the payload _, err = trix.Decode(encoded, "CHCK") - assert.Error(t, err) - assert.Equal(t, trix.ErrChecksumMismatch, err) + assert.ErrorIs(t, err, trix.ErrChecksumMismatch) } func TestChecksum_Ugly(t *testing.T) {