diff --git a/examples/main.go b/examples/main.go index 2fa3c14..3e4c082 100644 --- a/examples/main.go +++ b/examples/main.go @@ -182,14 +182,14 @@ func demoRSA() { // 2. Encrypt a message message := []byte("This is a secret message for RSA.") fmt.Printf("\nOriginal message: %s\n", message) - ciphertext, err := cryptService.EncryptRSA(publicKey, message) + ciphertext, err := cryptService.EncryptRSA(publicKey, message, nil) if err != nil { log.Fatalf("Failed to encrypt with RSA: %v", err) } fmt.Printf("Encrypted ciphertext (base64): %s\n", base64.StdEncoding.EncodeToString(ciphertext)) // 3. Decrypt the message - decrypted, err := cryptService.DecryptRSA(privateKey, ciphertext) + decrypted, err := cryptService.DecryptRSA(privateKey, ciphertext, nil) if err != nil { log.Fatalf("Failed to decrypt with RSA: %v", err) } diff --git a/pkg/crypt/crypt.go b/pkg/crypt/crypt.go index 4879c60..7bb6d09 100644 --- a/pkg/crypt/crypt.go +++ b/pkg/crypt/crypt.go @@ -151,13 +151,13 @@ func (s *Service) GenerateRSAKeyPair(bits int) (publicKey, privateKey []byte, er } // EncryptRSA encrypts data with a public key. -func (s *Service) EncryptRSA(publicKey, data []byte) ([]byte, error) { +func (s *Service) EncryptRSA(publicKey, data, label []byte) ([]byte, error) { s.ensureRSA() - return s.rsa.Encrypt(publicKey, data, nil) + return s.rsa.Encrypt(publicKey, data, label) } // DecryptRSA decrypts data with a private key. -func (s *Service) DecryptRSA(privateKey, ciphertext []byte) ([]byte, error) { +func (s *Service) DecryptRSA(privateKey, ciphertext, label []byte) ([]byte, error) { s.ensureRSA() - return s.rsa.Decrypt(privateKey, ciphertext, nil) + return s.rsa.Decrypt(privateKey, ciphertext, label) } \ No newline at end of file diff --git a/tdd/crypt/crypt_test.go b/tdd/crypt/crypt_test.go index 14595eb..a8e2bff 100644 --- a/tdd/crypt/crypt_test.go +++ b/tdd/crypt/crypt_test.go @@ -116,9 +116,10 @@ func TestRSA_Good(t *testing.T) { // Test encryption and decryption message := []byte("secret message") - ciphertext, err := service.EncryptRSA(pubKey, message) + label := []byte("test label") + ciphertext, err := service.EncryptRSA(pubKey, message, label) assert.NoError(t, err) - plaintext, err := service.DecryptRSA(privKey, ciphertext) + plaintext, err := service.DecryptRSA(privKey, ciphertext, label) assert.NoError(t, err) assert.Equal(t, message, plaintext) } @@ -129,32 +130,40 @@ func TestRSA_Bad(t *testing.T) { assert.Error(t, err) // Test decryption with the wrong key - pubKey, _, err := service.GenerateRSAKeyPair(2048) + pubKey, privKey, err := service.GenerateRSAKeyPair(2048) assert.NoError(t, err) _, otherPrivKey, err := service.GenerateRSAKeyPair(2048) assert.NoError(t, err) message := []byte("secret message") - ciphertext, err := service.EncryptRSA(pubKey, message) + ciphertext, err := service.EncryptRSA(pubKey, message, nil) assert.NoError(t, err) - _, err = service.DecryptRSA(otherPrivKey, ciphertext) + _, err = service.DecryptRSA(otherPrivKey, ciphertext, nil) + assert.Error(t, err) + + // Test decryption with wrong label + label1 := []byte("label1") + label2 := []byte("label2") + ciphertext, err = service.EncryptRSA(pubKey, message, label1) + assert.NoError(t, err) + _, err = service.DecryptRSA(privKey, ciphertext, label2) assert.Error(t, err) } func TestRSA_Ugly(t *testing.T) { // Test with malformed keys - _, err := service.EncryptRSA([]byte("not a real key"), []byte("message")) + _, err := service.EncryptRSA([]byte("not a real key"), []byte("message"), nil) assert.Error(t, err) - _, err = service.DecryptRSA([]byte("not a real key"), []byte("message")) + _, err = service.DecryptRSA([]byte("not a real key"), []byte("message"), nil) assert.Error(t, err) // Test with empty message pubKey, privKey, err := service.GenerateRSAKeyPair(2048) assert.NoError(t, err) message := []byte("") - ciphertext, err := service.EncryptRSA(pubKey, message) + ciphertext, err := service.EncryptRSA(pubKey, message, nil) assert.NoError(t, err) - plaintext, err := service.DecryptRSA(privKey, ciphertext) + plaintext, err := service.DecryptRSA(privKey, ciphertext, nil) assert.NoError(t, err) assert.Equal(t, message, plaintext) } diff --git a/tdd/trix/trix_test.go b/tdd/trix/trix_test.go index 9733891..d0a9cec 100644 --- a/tdd/trix/trix_test.go +++ b/tdd/trix/trix_test.go @@ -201,8 +201,7 @@ func TestChecksum_Bad(t *testing.T) { encoded[len(encoded)-1] = 0 // Tamper with the payload _, err = trix.Decode(encoded, "CHCK") - assert.Error(t, err) - assert.Equal(t, trix.ErrChecksumMismatch, err) + assert.ErrorIs(t, err, trix.ErrChecksumMismatch) } func TestChecksum_Ugly(t *testing.T) {