diff --git a/examples/main.go b/examples/main.go index 5adae8c..0a31fb2 100644 --- a/examples/main.go +++ b/examples/main.go @@ -41,6 +41,7 @@ func main() { trixContainer := &trix.Trix{ Header: header, Payload: actualCiphertext, + Sigils: []trix.Sigil{&trix.ReverseSigil{}}, } // 4. Encode the .trix container into its binary format @@ -58,6 +59,13 @@ func main() { log.Fatalf("Failed to decode .trix container: %v", err) } + // Manually apply the Out method of the sigil to restore the original payload. + restoredPayload, err := trixContainer.Sigils[0].Out(decodedTrix.Payload) + if err != nil { + log.Fatalf("Failed to apply sigil: %v", err) + } + decodedTrix.Payload = restoredPayload + // 6. Reassemble the ciphertext (nonce + payload) and decrypt retrievedNonceStr, ok := decodedTrix.Header["nonce"].(string) if !ok { diff --git a/pkg/crypt/crypt.go b/pkg/crypt/crypt.go index a9e57a9..26d96b1 100644 --- a/pkg/crypt/crypt.go +++ b/pkg/crypt/crypt.go @@ -61,10 +61,14 @@ func (s *Service) Hash(lib HashType, payload string) string { // Luhn validates a number using the Luhn algorithm. func (s *Service) Luhn(payload string) bool { payload = strings.ReplaceAll(payload, " ", "") + if len(payload) <= 1 { + return false + } + sum := 0 - isSecond := false - for i := len(payload) - 1; i >= 0; i-- { - digit, err := strconv.Atoi(string(payload[i])) + isSecond := len(payload)%2 == 0 + for _, r := range payload { + digit, err := strconv.Atoi(string(r)) if err != nil { return false // Contains non-digit } diff --git a/pkg/crypt/crypt_test.go b/pkg/crypt/crypt_test.go index 6d282a9..bead19f 100644 --- a/pkg/crypt/crypt_test.go +++ b/pkg/crypt/crypt_test.go @@ -1,53 +1,109 @@ package crypt import ( - "fmt" + "strings" "testing" "github.com/stretchr/testify/assert" ) -func TestHash(t *testing.T) { - service := NewService() +var service = NewService() + +// --- Hashing Tests --- + +func TestHash_Good(t *testing.T) { payload := "hello" - hash := service.Hash(LTHN, payload) - assert.NotEmpty(t, hash) + // Test all supported hash types + for _, hashType := range []HashType{LTHN, SHA512, SHA256, SHA1, MD5} { + hash := service.Hash(hashType, payload) + assert.NotEmpty(t, hash, "Hash should not be empty for type %s", hashType) + } } -func TestLuhn(t *testing.T) { - service := NewService() +func TestHash_Bad(t *testing.T) { + // Using an unsupported hash type should default to SHA256 + hash := service.Hash("unsupported", "hello") + expectedHash := service.Hash(SHA256, "hello") + assert.Equal(t, expectedHash, hash) +} + +func TestHash_Ugly(t *testing.T) { + // Test with potentially problematic inputs + testCases := []string{ + "", // Empty string + " ", // Whitespace + "\x00\x01\x02\x03\x04", // Null bytes + strings.Repeat("a", 1024*1024), // Large payload (1MB) + "こんにちは", // Unicode characters + } + + for _, tc := range testCases { + for _, hashType := range []HashType{LTHN, SHA512, SHA256, SHA1, MD5} { + hash := service.Hash(hashType, tc) + assert.NotEmpty(t, hash, "Hash for ugly input should not be empty for type %s", hashType) + } + } +} + +// --- Checksum Tests --- + +// Luhn Tests +func TestLuhn_Good(t *testing.T) { assert.True(t, service.Luhn("79927398713")) - assert.False(t, service.Luhn("79927398714")) } -func TestFletcher16(t *testing.T) { - service := NewService() +func TestLuhn_Bad(t *testing.T) { + assert.False(t, service.Luhn("79927398714"), "Should fail for incorrect checksum") + assert.False(t, service.Luhn("7992739871a"), "Should fail for non-numeric input") +} + +func TestLuhn_Ugly(t *testing.T) { + assert.False(t, service.Luhn(""), "Should be false for empty string") + assert.False(t, service.Luhn(" 1 2 3 "), "Should handle spaces but result in false") +} + +// Fletcher16 Tests +func TestFletcher16_Good(t *testing.T) { assert.Equal(t, uint16(0xC8F0), service.Fletcher16("abcde")) assert.Equal(t, uint16(0x2057), service.Fletcher16("abcdef")) assert.Equal(t, uint16(0x0627), service.Fletcher16("abcdefgh")) } -func TestFletcher32(t *testing.T) { - service := NewService() - expected := uint32(0xF04FC729) - actual := service.Fletcher32("abcde") - fmt.Printf("Fletcher32('abcde'): expected: %x, actual: %x\n", expected, actual) - assert.Equal(t, expected, actual) - - expected = uint32(0x56502D2A) - actual = service.Fletcher32("abcdef") - fmt.Printf("Fletcher32('abcdef'): expected: %x, actual: %x\n", expected, actual) - assert.Equal(t, expected, actual) - - expected = uint32(0xEBE19591) - actual = service.Fletcher32("abcdefgh") - fmt.Printf("Fletcher32('abcdefgh'): expected: %x, actual: %x\n", expected, actual) - assert.Equal(t, expected, actual) +func TestFletcher16_Bad(t *testing.T) { + // No obviously "bad" inputs that don't fall into "ugly" + // For Fletcher, any string is a valid input. } -func TestFletcher64(t *testing.T) { - service := NewService() +func TestFletcher16_Ugly(t *testing.T) { + assert.Equal(t, uint16(0), service.Fletcher16(""), "Checksum of empty string should be 0") +} + +// Fletcher32 Tests +func TestFletcher32_Good(t *testing.T) { + assert.Equal(t, uint32(0xF04FC729), service.Fletcher32("abcde")) + assert.Equal(t, uint32(0x56502D2A), service.Fletcher32("abcdef")) + assert.Equal(t, uint32(0xEBE19591), service.Fletcher32("abcdefgh")) +} + +func TestFletcher32_Bad(t *testing.T) { + // Any string is a valid input. +} + +func TestFletcher32_Ugly(t *testing.T) { + assert.Equal(t, uint32(0), service.Fletcher32(""), "Checksum of empty string should be 0") +} + +// Fletcher64 Tests +func TestFletcher64_Good(t *testing.T) { assert.Equal(t, uint64(0xc8c6c527646362c6), service.Fletcher64("abcde")) assert.Equal(t, uint64(0xc8c72b276463c8c6), service.Fletcher64("abcdef")) assert.Equal(t, uint64(0x312e2b28cccac8c6), service.Fletcher64("abcdefgh")) } + +func TestFletcher64_Bad(t *testing.T) { + // Any string is a valid input. +} + +func TestFletcher64_Ugly(t *testing.T) { + assert.Equal(t, uint64(0), service.Fletcher64(""), "Checksum of empty string should be 0") +} diff --git a/pkg/trix/trix.go b/pkg/trix/trix.go index 8f1c9e7..bab5edd 100644 --- a/pkg/trix/trix.go +++ b/pkg/trix/trix.go @@ -17,12 +17,20 @@ var ( ErrInvalidMagicNumber = errors.New("trix: invalid magic number") ErrInvalidVersion = errors.New("trix: invalid version") ErrMagicNumberLength = errors.New("trix: magic number must be 4 bytes long") + ErrNilSigil = errors.New("trix: sigil cannot be nil") ) +// Sigil defines the interface for a data transformer. +type Sigil interface { + In(data []byte) ([]byte, error) + Out(data []byte) ([]byte, error) +} + // Trix represents the structure of a .trix file. type Trix struct { Header map[string]interface{} Payload []byte + Sigils []Sigil `json:"-"` // Ignore Sigils during JSON marshaling } // Encode serializes a Trix struct into the .trix binary format. @@ -31,6 +39,19 @@ func Encode(trix *Trix, magicNumber string) ([]byte, error) { return nil, ErrMagicNumberLength } + // Apply sigils to the payload before encoding + payload := trix.Payload + for _, sigil := range trix.Sigils { + if sigil == nil { + return nil, ErrNilSigil + } + var err error + payload, err = sigil.In(payload) + if err != nil { + return nil, err + } + } + headerBytes, err := json.Marshal(trix.Header) if err != nil { return nil, err @@ -60,7 +81,7 @@ func Encode(trix *Trix, magicNumber string) ([]byte, error) { } // Write Payload - if _, err := buf.Write(trix.Payload); err != nil { + if _, err := buf.Write(payload); err != nil { return nil, err } @@ -68,6 +89,7 @@ func Encode(trix *Trix, magicNumber string) ([]byte, error) { } // Decode deserializes the .trix binary format into a Trix struct. +// Note: Sigils are not stored in the format and must be re-attached by the caller. func Decode(data []byte, magicNumber string) (*Trix, error) { if len(magicNumber) != 4 { return nil, ErrMagicNumberLength @@ -120,3 +142,21 @@ func Decode(data []byte, magicNumber string) (*Trix, error) { Payload: payload, }, nil } + +// ReverseSigil is an example Sigil that reverses the bytes of the payload. +type ReverseSigil struct{} + +// In reverses the bytes of the data. +func (s *ReverseSigil) In(data []byte) ([]byte, error) { + reversed := make([]byte, len(data)) + for i, j := 0, len(data)-1; i < len(data); i, j = i+1, j-1 { + reversed[i] = data[j] + } + return reversed, nil +} + +// Out reverses the bytes of the data. +func (s *ReverseSigil) Out(data []byte) ([]byte, error) { + // Reversing the bytes again restores the original data. + return s.In(data) +} diff --git a/pkg/trix/trix_test.go b/pkg/trix/trix_test.go index 13531e7..2cf4912 100644 --- a/pkg/trix/trix_test.go +++ b/pkg/trix/trix_test.go @@ -1,13 +1,16 @@ package trix import ( + "errors" + "io" "reflect" "testing" "github.com/stretchr/testify/assert" ) -func TestEncodeDecode(t *testing.T) { +// TestTrixEncodeDecode_Good tests the ideal "happy path" scenario for encoding and decoding. +func TestTrixEncodeDecode_Good(t *testing.T) { header := map[string]interface{}{ "content_type": "application/octet-stream", "encryption_algorithm": "chacha20poly1035", @@ -15,13 +18,9 @@ func TestEncodeDecode(t *testing.T) { "created_at": "2025-10-30T12:00:00Z", } payload := []byte("This is a secret message.") - - trix := &Trix{ - Header: header, - Payload: payload, - } - + trix := &Trix{Header: header, Payload: payload} magicNumber := "TRIX" + encoded, err := Encode(trix, magicNumber) assert.NoError(t, err) @@ -32,61 +31,145 @@ func TestEncodeDecode(t *testing.T) { assert.Equal(t, trix.Payload, decoded.Payload) } -func TestEncodeDecode_InvalidMagicNumber(t *testing.T) { - header := map[string]interface{}{ - "content_type": "application/octet-stream", - } - payload := []byte("This is a secret message.") +// TestTrixEncodeDecode_Bad tests expected failure scenarios with well-formed but invalid inputs. +func TestTrixEncodeDecode_Bad(t *testing.T) { + t.Run("MismatchedMagicNumber", func(t *testing.T) { + trix := &Trix{Header: map[string]interface{}{}, Payload: []byte("payload")} + encoded, err := Encode(trix, "GOOD") + assert.NoError(t, err) + _, err = Decode(encoded, "BAD!") + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid magic number") + }) + + t.Run("InvalidMagicNumberLength", func(t *testing.T) { + trix := &Trix{Header: map[string]interface{}{}, Payload: []byte("payload")} + _, err := Encode(trix, "TOOLONG") + assert.EqualError(t, err, "trix: magic number must be 4 bytes long") + + _, err = Decode([]byte{}, "SHORT") + assert.EqualError(t, err, "trix: magic number must be 4 bytes long") + }) + + t.Run("MalformedHeaderJSON", func(t *testing.T) { + // Create a Trix struct with a header that cannot be marshaled to JSON. + header := map[string]interface{}{ + "unsupported": make(chan int), // Channels cannot be JSON-encoded + } + trix := &Trix{Header: header, Payload: []byte("payload")} + _, err := Encode(trix, "TRIX") + assert.Error(t, err) + assert.Contains(t, err.Error(), "json: unsupported type") + }) +} + +// TestTrixEncodeDecode_Ugly tests malicious or malformed inputs designed to cause crashes or panics. +func TestTrixEncodeDecode_Ugly(t *testing.T) { + magicNumber := "UGLY" + + t.Run("CorruptedHeaderLength", func(t *testing.T) { + // Manually construct a byte slice where the header length is larger than the actual data. + var buf []byte + buf = append(buf, []byte(magicNumber)...) // Magic Number + buf = append(buf, byte(Version)) // Version + // Header length of 1000, but the header is only 2 bytes long. + buf = append(buf, []byte{0, 0, 3, 232}...) // BigEndian representation of 1000 + buf = append(buf, []byte("{}")...) // A minimal valid JSON header + buf = append(buf, []byte("payload")...) + + _, err := Decode(buf, magicNumber) + assert.Error(t, err) + assert.Equal(t, err, io.ErrUnexpectedEOF) + }) + + t.Run("DataTooShort", func(t *testing.T) { + // Data is too short to contain even the magic number. + data := []byte("BAD") + _, err := Decode(data, magicNumber) + assert.Error(t, err) + }) + + t.Run("EmptyPayload", func(t *testing.T) { + data := []byte{} + _, err := Decode(data, magicNumber) + assert.Error(t, err) + }) + + t.Run("FuzzedJSON", func(t *testing.T) { + // A header that is technically valid but contains unexpected types. + header := map[string]interface{}{ + "payload": map[string]interface{}{"nested": 123}, + } + payload := []byte("some data") + trix := &Trix{Header: header, Payload: payload} + + encoded, err := Encode(trix, magicNumber) + assert.NoError(t, err) + + decoded, err := Decode(encoded, magicNumber) + assert.NoError(t, err) + assert.NotNil(t, decoded) + }) +} + +// --- Sigil Tests --- + +// FailingSigil is a helper for testing sigils that intentionally fail. +type FailingSigil struct { + err error +} + +func (s *FailingSigil) In(data []byte) ([]byte, error) { + return nil, s.err +} +func (s *FailingSigil) Out(data []byte) ([]byte, error) { + return nil, s.err +} + +func TestSigilPipeline_Good(t *testing.T) { + originalPayload := []byte("hello world") trix := &Trix{ - Header: header, - Payload: payload, + Header: map[string]interface{}{}, + Payload: originalPayload, + Sigils: []Sigil{&ReverseSigil{}}, } - magicNumber := "TRIX" - wrongMagicNumber := "XXXX" - encoded, err := Encode(trix, magicNumber) + encoded, err := Encode(trix, "SIGL") assert.NoError(t, err) - _, err = Decode(encoded, wrongMagicNumber) - assert.Error(t, err) - assert.EqualError(t, err, "trix: invalid magic number: expected XXXX, got TRIX") -} - -func TestEncode_InvalidMagicNumberLength(t *testing.T) { - header := map[string]interface{}{ - "content_type": "application/octet-stream", - } - payload := []byte("This is a secret message.") - - trix := &Trix{ - Header: header, - Payload: payload, - } - - magicNumber := "TOOLONG" - _, err := Encode(trix, magicNumber) - assert.Error(t, err) - assert.EqualError(t, err, "trix: magic number must be 4 bytes long") -} - -func TestDecode_InvalidMagicNumberLength(t *testing.T) { - header := map[string]interface{}{ - "content_type": "application/octet-stream", - } - payload := []byte("This is a secret message.") - - trix := &Trix{ - Header: header, - Payload: payload, - } - - magicNumber := "TRIX" - encoded, err := Encode(trix, magicNumber) + decoded, err := Decode(encoded, "SIGL") assert.NoError(t, err) - invalidMagicNumber := "SHORT" - _, err = Decode(encoded, invalidMagicNumber) - assert.Error(t, err) - assert.EqualError(t, err, "trix: magic number must be 4 bytes long") + // Manually apply the Out method to restore the original payload. + restoredPayload, err := trix.Sigils[0].Out(decoded.Payload) + assert.NoError(t, err) + assert.Equal(t, originalPayload, restoredPayload) +} + +func TestSigilPipeline_Bad(t *testing.T) { + expectedErr := errors.New("sigil failed") + trix := &Trix{ + Header: map[string]interface{}{}, + Payload: []byte("some data"), + Sigils: []Sigil{&ReverseSigil{}, &FailingSigil{err: expectedErr}}, + } + + _, err := Encode(trix, "FAIL") + assert.Error(t, err) + assert.Equal(t, expectedErr, err) +} + +func TestSigilPipeline_Ugly(t *testing.T) { + t.Run("NilSigil", func(t *testing.T) { + trix := &Trix{ + Header: map[string]interface{}{}, + Payload: []byte("some data"), + Sigils: []Sigil{nil}, + } + + _, err := Encode(trix, "UGLY") + assert.Error(t, err) + assert.Equal(t, ErrNilSigil, err) + }) }