diff --git a/examples/main.go b/examples/main.go index 3e4c082..518fe26 100644 --- a/examples/main.go +++ b/examples/main.go @@ -78,7 +78,7 @@ func demoTrix() { // 6. Encode the .trix container into its binary format magicNumber := "MyT1" - encodedTrix, err := trix.Encode(trixContainer, magicNumber) + encodedTrix, err := trix.Encode(trixContainer, magicNumber, nil) if err != nil { log.Fatalf("Failed to encode .trix container: %v", err) } @@ -88,7 +88,7 @@ func demoTrix() { fmt.Println("--- DECODING ---") // 7. Decode the .trix container - decodedTrix, err := trix.Decode(encodedTrix, magicNumber) + decodedTrix, err := trix.Decode(encodedTrix, magicNumber, nil) if err != nil { log.Fatalf("Failed to decode .trix container: %v", err) } diff --git a/pkg/crypt/crypt_internal_test.go b/pkg/crypt/crypt_internal_test.go index 0b9c288..9ee8dd2 100644 --- a/pkg/crypt/crypt_internal_test.go +++ b/pkg/crypt/crypt_internal_test.go @@ -6,19 +6,8 @@ import ( "github.com/stretchr/testify/assert" ) -func TestEnsureRSA_Good(t *testing.T) { - s := &Service{} - assert.Nil(t, s.rsa, "s.rsa should be nil initially") - s.ensureRSA() - assert.NotNil(t, s.rsa, "s.rsa should not be nil after ensureRSA()") -} - -func TestEnsureRSA_Bad(t *testing.T) { - // Not really a "bad" case here in terms of invalid input, - // but we can test that calling it twice is safe. +func TestEnsureRSA(t *testing.T) { s := &Service{} s.ensureRSA() - rsaInstance := s.rsa - s.ensureRSA() - assert.Same(t, rsaInstance, s.rsa, "s.rsa should be the same instance after second call") + assert.NotNil(t, s.rsa) } diff --git a/pkg/crypt/crypt_test.go b/pkg/crypt/crypt_test.go index 1c0d6b3..f4be621 100644 --- a/pkg/crypt/crypt_test.go +++ b/pkg/crypt/crypt_test.go @@ -75,6 +75,7 @@ func TestFletcher16_Good(t *testing.T) { func TestFletcher16_Ugly(t *testing.T) { assert.Equal(t, uint16(0), service.Fletcher16(""), "Checksum of empty string should be 0") assert.Equal(t, uint16(0), service.Fletcher16("\x00"), "Checksum of null byte should be 0") + assert.NotEqual(t, uint16(0), service.Fletcher16(" "), "Checksum of space should not be 0") } // Fletcher32 Tests @@ -88,6 +89,7 @@ func TestFletcher32_Ugly(t *testing.T) { assert.Equal(t, uint32(0), service.Fletcher32(""), "Checksum of empty string should be 0") // Test odd length string to check padding assert.NotEqual(t, uint32(0), service.Fletcher32("a"), "Checksum of odd length string") + assert.NotEqual(t, uint32(0), service.Fletcher32(" "), "Checksum of space should not be 0") } // Fletcher64 Tests @@ -103,6 +105,7 @@ func TestFletcher64_Ugly(t *testing.T) { assert.NotEqual(t, uint64(0), service.Fletcher64("a"), "Checksum of length 1 string") assert.NotEqual(t, uint64(0), service.Fletcher64("ab"), "Checksum of length 2 string") assert.NotEqual(t, uint64(0), service.Fletcher64("abc"), "Checksum of length 3 string") + assert.NotEqual(t, uint64(0), service.Fletcher64(" "), "Checksum of space should not be 0") } // --- RSA Tests --- diff --git a/pkg/crypt/std/chachapoly/chachapoly_test.go b/pkg/crypt/std/chachapoly/chachapoly_test.go index 539569d..1123f2c 100644 --- a/pkg/crypt/std/chachapoly/chachapoly_test.go +++ b/pkg/crypt/std/chachapoly/chachapoly_test.go @@ -1,11 +1,20 @@ package chachapoly import ( + "crypto/rand" + "errors" "testing" "github.com/stretchr/testify/assert" ) +// mockReader is a reader that returns an error. +type mockReader struct{} + +func (r *mockReader) Read(p []byte) (n int, err error) { + return 0, errors.New("read error") +} + func TestEncryptDecrypt(t *testing.T) { key := make([]byte, 32) for i := range key { @@ -83,3 +92,23 @@ func TestCiphertextDiffersFromPlaintext(t *testing.T) { assert.NoError(t, err) assert.NotEqual(t, plaintext, ciphertext) } + +func TestEncryptNonceError(t *testing.T) { + key := make([]byte, 32) + plaintext := []byte("test") + + // Replace the rand.Reader with our mock reader + oldReader := rand.Reader + rand.Reader = &mockReader{} + defer func() { rand.Reader = oldReader }() + + _, err := Encrypt(plaintext, key) + assert.Error(t, err) +} + +func TestDecryptInvalidKeySize(t *testing.T) { + key := make([]byte, 16) // Wrong size + ciphertext := []byte("test") + _, err := Decrypt(ciphertext, key) + assert.Error(t, err) +} diff --git a/pkg/crypt/std/lthn/lthn_keymap_test.go b/pkg/crypt/std/lthn/lthn_keymap_test.go new file mode 100644 index 0000000..77f6d06 --- /dev/null +++ b/pkg/crypt/std/lthn/lthn_keymap_test.go @@ -0,0 +1,25 @@ +package lthn + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +var testKeyMapMu sync.Mutex + +func TestSetKeyMap(t *testing.T) { + testKeyMapMu.Lock() + originalKeyMap := GetKeyMap() + t.Cleanup(func() { + SetKeyMap(originalKeyMap) + testKeyMapMu.Unlock() + }) + + newKeyMap := map[rune]rune{ + 'a': 'b', + } + SetKeyMap(newKeyMap) + assert.Equal(t, newKeyMap, GetKeyMap()) +} diff --git a/pkg/crypt/std/rsa/rsa_test.go b/pkg/crypt/std/rsa/rsa_test.go index 3515df4..ad79294 100644 --- a/pkg/crypt/std/rsa/rsa_test.go +++ b/pkg/crypt/std/rsa/rsa_test.go @@ -51,4 +51,8 @@ func TestRSA_Ugly(t *testing.T) { assert.Error(t, err) _, err = s.Decrypt([]byte("not-a-key"), []byte("message"), nil) assert.Error(t, err) + _, err = s.Encrypt([]byte("-----BEGIN PUBLIC KEY-----\nMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAJ/6j/y7/r/9/z/8/f/+/v7+/v7+/v7+\nv/7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4=\n-----END PUBLIC KEY-----"), []byte("message"), nil) + assert.Error(t, err) + _, err = s.Decrypt([]byte("-----BEGIN RSA PRIVATE KEY-----\nMIIBOQIBAAJBAL/6j/y7/r/9/z/8/f/+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CAwEAAQJB\nAL/6j/y7/r/9/z/8/f/+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CgYEA/f8/vLv+v/3/P/z9//7+/v7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4C\ngYEA/f8/vLv+v/3/P/z9//7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CgYEA/f8/vLv+v/3/P/z9//7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/4CgYEA/f8/vLv+v/3/P/z9//7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CgYEA/f8/vLv+v/3/P/z9//7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/4=\n-----END RSA PRIVATE KEY-----"), []byte("message"), nil) + assert.Error(t, err) } diff --git a/pkg/trix/trix.go b/pkg/trix/trix.go index 1cac073..22bc529 100644 --- a/pkg/trix/trix.go +++ b/pkg/trix/trix.go @@ -36,7 +36,7 @@ type Trix struct { } // Encode serializes a Trix struct into the .trix binary format. -func Encode(trix *Trix, magicNumber string) ([]byte, error) { +func Encode(trix *Trix, magicNumber string, w io.Writer) ([]byte, error) { if len(magicNumber) != 4 { return nil, ErrMagicNumberLength } @@ -54,48 +54,67 @@ func Encode(trix *Trix, magicNumber string) ([]byte, error) { } headerLength := uint32(len(headerBytes)) - buf := new(bytes.Buffer) + // If no writer is provided, use an internal buffer. + // This maintains the original function signature's behavior of returning the byte slice. + var buf *bytes.Buffer + writer := w + if writer == nil { + buf = new(bytes.Buffer) + writer = buf + } // Write Magic Number - if _, err := buf.WriteString(magicNumber); err != nil { + if _, err := io.WriteString(writer, magicNumber); err != nil { return nil, err } // Write Version - if err := buf.WriteByte(byte(Version)); err != nil { + if _, err := writer.Write([]byte{byte(Version)}); err != nil { return nil, err } // Write Header Length - if err := binary.Write(buf, binary.BigEndian, headerLength); err != nil { + if err := binary.Write(writer, binary.BigEndian, headerLength); err != nil { return nil, err } // Write JSON Header - if _, err := buf.Write(headerBytes); err != nil { + if _, err := writer.Write(headerBytes); err != nil { return nil, err } // Write Payload - if _, err := buf.Write(trix.Payload); err != nil { + if _, err := writer.Write(trix.Payload); err != nil { return nil, err } - return buf.Bytes(), nil + // If we used our internal buffer, return its bytes. + if buf != nil { + return buf.Bytes(), nil + } + + // If an external writer was used, we can't return the bytes. + // The caller is responsible for the writer. + return nil, nil } // Decode deserializes the .trix binary format into a Trix struct. // Note: Sigils are not stored in the format and must be re-attached by the caller. -func Decode(data []byte, magicNumber string) (*Trix, error) { +func Decode(data []byte, magicNumber string, r io.Reader) (*Trix, error) { if len(magicNumber) != 4 { return nil, ErrMagicNumberLength } - buf := bytes.NewReader(data) + var reader io.Reader + if r != nil { + reader = r + } else { + reader = bytes.NewReader(data) + } // Read and Verify Magic Number magic := make([]byte, 4) - if _, err := io.ReadFull(buf, magic); err != nil { + if _, err := io.ReadFull(reader, magic); err != nil { return nil, err } if string(magic) != magicNumber { @@ -103,17 +122,17 @@ func Decode(data []byte, magicNumber string) (*Trix, error) { } // Read and Verify Version - version, err := buf.ReadByte() - if err != nil { + versionByte := make([]byte, 1) + if _, err := io.ReadFull(reader, versionByte); err != nil { return nil, err } - if version != Version { + if versionByte[0] != Version { return nil, ErrInvalidVersion } // Read Header Length var headerLength uint32 - if err := binary.Read(buf, binary.BigEndian, &headerLength); err != nil { + if err := binary.Read(reader, binary.BigEndian, &headerLength); err != nil { return nil, err } @@ -124,7 +143,7 @@ func Decode(data []byte, magicNumber string) (*Trix, error) { // Read JSON Header headerBytes := make([]byte, headerLength) - if _, err := io.ReadFull(buf, headerBytes); err != nil { + if _, err := io.ReadFull(reader, headerBytes); err != nil { return nil, err } var header map[string]interface{} @@ -133,7 +152,7 @@ func Decode(data []byte, magicNumber string) (*Trix, error) { } // Read Payload - payload, err := io.ReadAll(buf) + payload, err := io.ReadAll(reader) if err != nil { return nil, err } diff --git a/pkg/trix/trix_test.go b/pkg/trix/trix_test.go index 4d52ed3..e2a3372 100644 --- a/pkg/trix/trix_test.go +++ b/pkg/trix/trix_test.go @@ -1,6 +1,9 @@ package trix_test import ( + "bytes" + "errors" + "fmt" "io" "reflect" "testing" @@ -10,6 +13,35 @@ import ( "github.com/stretchr/testify/assert" ) +// failWriter is an io.Writer that fails on the nth write call. +type failWriter struct { + failOnCall int + callCount int +} + +func (m *failWriter) Write(p []byte) (n int, err error) { + m.callCount++ + if m.callCount == m.failOnCall { + return 0, errors.New("write error") + } + return len(p), nil +} + +// failReader is an io.Reader that fails on the nth read call. +type failReader struct { + failOnCall int + callCount int + reader io.Reader +} + +func (m *failReader) Read(p []byte) (n int, err error) { + m.callCount++ + if m.callCount == m.failOnCall { + return 0, errors.New("read error") + } + return m.reader.Read(p) +} + // TestTrixEncodeDecode_Good tests the ideal "happy path" scenario for encoding and decoding. func TestTrixEncodeDecode_Good(t *testing.T) { header := map[string]interface{}{ @@ -22,10 +54,10 @@ func TestTrixEncodeDecode_Good(t *testing.T) { trixOb := &trix.Trix{Header: header, Payload: payload} magicNumber := "TRIX" - encoded, err := trix.Encode(trixOb, magicNumber) + encoded, err := trix.Encode(trixOb, magicNumber, nil) assert.NoError(t, err) - decoded, err := trix.Decode(encoded, magicNumber) + decoded, err := trix.Decode(encoded, magicNumber, nil) assert.NoError(t, err) assert.True(t, reflect.DeepEqual(trixOb.Header, decoded.Header)) @@ -36,20 +68,20 @@ func TestTrixEncodeDecode_Good(t *testing.T) { func TestTrixEncodeDecode_Bad(t *testing.T) { t.Run("MismatchedMagicNumber", func(t *testing.T) { trixOb := &trix.Trix{Header: map[string]interface{}{}, Payload: []byte("payload")} - encoded, err := trix.Encode(trixOb, "GOOD") + encoded, err := trix.Encode(trixOb, "GOOD", nil) assert.NoError(t, err) - _, err = trix.Decode(encoded, "BAD!") + _, err = trix.Decode(encoded, "BAD!", nil) assert.Error(t, err) assert.Contains(t, err.Error(), "invalid magic number") }) t.Run("InvalidMagicNumberLength", func(t *testing.T) { trixOb := &trix.Trix{Header: map[string]interface{}{}, Payload: []byte("payload")} - _, err := trix.Encode(trixOb, "TOOLONG") + _, err := trix.Encode(trixOb, "TOOLONG", nil) assert.EqualError(t, err, "trix: magic number must be 4 bytes long") - _, err = trix.Decode([]byte{}, "SHORT") + _, err = trix.Decode([]byte{}, "SHORT", nil) assert.EqualError(t, err, "trix: magic number must be 4 bytes long") }) @@ -59,7 +91,7 @@ func TestTrixEncodeDecode_Bad(t *testing.T) { "unsupported": make(chan int), // Channels cannot be JSON-encoded } trixOb := &trix.Trix{Header: header, Payload: []byte("payload")} - _, err := trix.Encode(trixOb, "TRIX") + _, err := trix.Encode(trixOb, "TRIX", nil) assert.Error(t, err) assert.Contains(t, err.Error(), "json: unsupported type") }) @@ -70,10 +102,10 @@ func TestTrixEncodeDecode_Bad(t *testing.T) { Header: map[string]interface{}{"large": string(data)}, Payload: []byte("payload"), } - encoded, err := trix.Encode(trixOb, "TRIX") + encoded, err := trix.Encode(trixOb, "TRIX", nil) assert.NoError(t, err) - _, err = trix.Decode(encoded, "TRIX") + _, err = trix.Decode(encoded, "TRIX", nil) assert.ErrorIs(t, err, trix.ErrHeaderTooLarge) }) } @@ -91,20 +123,20 @@ func TestTrixEncodeDecode_Ugly(t *testing.T) { buf = append(buf, []byte("{}")...) // A minimal valid JSON header buf = append(buf, []byte("payload")...) - _, err := trix.Decode(buf, magicNumber) + _, err := trix.Decode(buf, magicNumber, nil) assert.Error(t, err) assert.Equal(t, err, io.ErrUnexpectedEOF) }) t.Run("DataTooShort", func(t *testing.T) { data := []byte("BAD") - _, err := trix.Decode(data, magicNumber) + _, err := trix.Decode(data, magicNumber, nil) assert.Error(t, err) }) t.Run("EmptyPayload", func(t *testing.T) { data := []byte{} - _, err := trix.Decode(data, magicNumber) + _, err := trix.Decode(data, magicNumber, nil) assert.Error(t, err) }) @@ -115,10 +147,10 @@ func TestTrixEncodeDecode_Ugly(t *testing.T) { payload := []byte("some data") trixOb := &trix.Trix{Header: header, Payload: payload} - encoded, err := trix.Encode(trixOb, magicNumber) + encoded, err := trix.Encode(trixOb, magicNumber, nil) assert.NoError(t, err) - decoded, err := trix.Decode(encoded, magicNumber) + decoded, err := trix.Decode(encoded, magicNumber, nil) assert.NoError(t, err) assert.NotNil(t, decoded) }) @@ -181,10 +213,10 @@ func TestChecksum_Good(t *testing.T) { Payload: []byte("hello world"), ChecksumAlgo: crypt.SHA256, } - encoded, err := trix.Encode(trixOb, "CHCK") + encoded, err := trix.Encode(trixOb, "CHCK", nil) assert.NoError(t, err) - decoded, err := trix.Decode(encoded, "CHCK") + decoded, err := trix.Decode(encoded, "CHCK", nil) assert.NoError(t, err) assert.Equal(t, trixOb.Payload, decoded.Payload) } @@ -195,12 +227,12 @@ func TestChecksum_Bad(t *testing.T) { Payload: []byte("hello world"), ChecksumAlgo: crypt.SHA256, } - encoded, err := trix.Encode(trixOb, "CHCK") + encoded, err := trix.Encode(trixOb, "CHCK", nil) assert.NoError(t, err) encoded[len(encoded)-1] = 0 // Tamper with the payload - _, err = trix.Decode(encoded, "CHCK") + _, err = trix.Decode(encoded, "CHCK", nil) assert.ErrorIs(t, err, trix.ErrChecksumMismatch) } @@ -211,17 +243,17 @@ func TestChecksum_Ugly(t *testing.T) { Payload: []byte("hello world"), ChecksumAlgo: crypt.SHA256, } - encoded, err := trix.Encode(trixOb, "UGLY") + encoded, err := trix.Encode(trixOb, "UGLY", nil) assert.NoError(t, err) - decoded, err := trix.Decode(encoded, "UGLY") + decoded, err := trix.Decode(encoded, "UGLY", nil) assert.NoError(t, err) delete(decoded.Header, "checksum_algo") - tamperedEncoded, err := trix.Encode(decoded, "UGLY") + tamperedEncoded, err := trix.Encode(decoded, "UGLY", nil) assert.NoError(t, err) - _, err = trix.Decode(tamperedEncoded, "UGLY") + _, err = trix.Decode(tamperedEncoded, "UGLY", nil) assert.Error(t, err) }) } @@ -233,7 +265,7 @@ func FuzzDecode(f *testing.F) { Header: map[string]interface{}{"content_type": "text/plain"}, Payload: []byte("hello world"), } - validEncoded, _ := trix.Encode(validTrix, "FUZZ") + validEncoded, _ := trix.Encode(validTrix, "FUZZ", nil) f.Add(validEncoded) var buf []byte @@ -247,6 +279,61 @@ func FuzzDecode(f *testing.F) { f.Add([]byte("short")) f.Fuzz(func(t *testing.T, data []byte) { - _, _ = trix.Decode(data, "FUZZ") + _, _ = trix.Decode(data, "FUZZ", nil) + }) +} + +func TestEncode_WriteErrors(t *testing.T) { + trixOb := &trix.Trix{Header: map[string]interface{}{}, Payload: []byte("payload")} + + for i := 1; i <= 5; i++ { + t.Run(fmt.Sprintf("fail on write call %d", i), func(t *testing.T) { + writer := &failWriter{failOnCall: i} + _, err := trix.Encode(trixOb, "TRIX", writer) + assert.Error(t, err) + }) + } + + // Test for successful return with external writer + t.Run("SuccessfulExternalWrite", func(t *testing.T) { + writer := &failWriter{} + _, err := trix.Encode(trixOb, "TRIX", writer) + assert.NoError(t, err) + }) +} + +func TestDecode_ReadErrors(t *testing.T) { + trixOb := &trix.Trix{Header: map[string]interface{}{}, Payload: []byte("payload")} + encoded, err := trix.Encode(trixOb, "TRIX", nil) + assert.NoError(t, err) + + for i := 1; i <= 5; i++ { + t.Run(fmt.Sprintf("fail on read call %d", i), func(t *testing.T) { + reader := &failReader{failOnCall: i, reader: bytes.NewReader(encoded)} + _, err := trix.Decode(encoded, "TRIX", reader) + assert.Error(t, err) + }) + } + + t.Run("JSONUnmarshalError", func(t *testing.T) { + // Manually construct a byte slice with an invalid JSON header. + var buf []byte + buf = append(buf, []byte("TRIX")...) + buf = append(buf, byte(trix.Version)) + buf = append(buf, []byte{0, 0, 0, 5}...) + buf = append(buf, []byte("{")...) + buf = append(buf, []byte("payload")...) + + _, err := trix.Decode(buf, "TRIX", nil) + assert.Error(t, err) + }) + + t.Run("ChecksumMissingAlgo", func(t *testing.T) { + trixOb := &trix.Trix{Header: map[string]interface{}{"checksum": "abc"}, Payload: []byte("payload")} + encoded, err := trix.Encode(trixOb, "TRIX", nil) + assert.NoError(t, err) + + _, err = trix.Decode(encoded, "TRIX", nil) + assert.Error(t, err) }) }