diff --git a/examples/main.go b/examples/main.go index 3133163..5adae8c 100644 --- a/examples/main.go +++ b/examples/main.go @@ -44,7 +44,8 @@ func main() { } // 4. Encode the .trix container into its binary format - encodedTrix, err := trix.Encode(trixContainer) + magicNumber := "MyT1" // My Trix 1 + encodedTrix, err := trix.Encode(trixContainer, magicNumber) if err != nil { log.Fatalf("Failed to encode .trix container: %v", err) } @@ -52,7 +53,7 @@ func main() { fmt.Println("Successfully created .trix container.") // 5. Decode the .trix container to retrieve the encrypted data - decodedTrix, err := trix.Decode(encodedTrix) + decodedTrix, err := trix.Decode(encodedTrix, magicNumber) if err != nil { log.Fatalf("Failed to decode .trix container: %v", err) } diff --git a/pkg/trix/trix.go b/pkg/trix/trix.go index 575b55e..8f1c9e7 100644 --- a/pkg/trix/trix.go +++ b/pkg/trix/trix.go @@ -5,17 +5,18 @@ import ( "encoding/binary" "encoding/json" "errors" + "fmt" "io" ) const ( - MagicNumber = "TRIX" - Version = 2 + Version = 2 ) var ( ErrInvalidMagicNumber = errors.New("trix: invalid magic number") ErrInvalidVersion = errors.New("trix: invalid version") + ErrMagicNumberLength = errors.New("trix: magic number must be 4 bytes long") ) // Trix represents the structure of a .trix file. @@ -25,7 +26,11 @@ type Trix struct { } // Encode serializes a Trix struct into the .trix binary format. -func Encode(trix *Trix) ([]byte, error) { +func Encode(trix *Trix, magicNumber string) ([]byte, error) { + if len(magicNumber) != 4 { + return nil, ErrMagicNumberLength + } + headerBytes, err := json.Marshal(trix.Header) if err != nil { return nil, err @@ -35,7 +40,7 @@ func Encode(trix *Trix) ([]byte, error) { buf := new(bytes.Buffer) // Write Magic Number - if _, err := buf.WriteString(MagicNumber); err != nil { + if _, err := buf.WriteString(magicNumber); err != nil { return nil, err } @@ -63,7 +68,11 @@ func Encode(trix *Trix) ([]byte, error) { } // Decode deserializes the .trix binary format into a Trix struct. -func Decode(data []byte) (*Trix, error) { +func Decode(data []byte, magicNumber string) (*Trix, error) { + if len(magicNumber) != 4 { + return nil, ErrMagicNumberLength + } + buf := bytes.NewReader(data) // Read and Verify Magic Number @@ -71,8 +80,8 @@ func Decode(data []byte) (*Trix, error) { if _, err := io.ReadFull(buf, magic); err != nil { return nil, err } - if string(magic) != MagicNumber { - return nil, ErrInvalidMagicNumber + if string(magic) != magicNumber { + return nil, fmt.Errorf("%w: expected %s, got %s", ErrInvalidMagicNumber, magicNumber, string(magic)) } // Read and Verify Version diff --git a/pkg/trix/trix_test.go b/pkg/trix/trix_test.go index d46695f..13531e7 100644 --- a/pkg/trix/trix_test.go +++ b/pkg/trix/trix_test.go @@ -21,12 +21,72 @@ func TestEncodeDecode(t *testing.T) { Payload: payload, } - encoded, err := Encode(trix) + magicNumber := "TRIX" + encoded, err := Encode(trix, magicNumber) assert.NoError(t, err) - decoded, err := Decode(encoded) + decoded, err := Decode(encoded, magicNumber) assert.NoError(t, err) assert.True(t, reflect.DeepEqual(trix.Header, decoded.Header)) assert.Equal(t, trix.Payload, decoded.Payload) } + +func TestEncodeDecode_InvalidMagicNumber(t *testing.T) { + header := map[string]interface{}{ + "content_type": "application/octet-stream", + } + payload := []byte("This is a secret message.") + + trix := &Trix{ + Header: header, + Payload: payload, + } + + magicNumber := "TRIX" + wrongMagicNumber := "XXXX" + encoded, err := Encode(trix, magicNumber) + assert.NoError(t, err) + + _, err = Decode(encoded, wrongMagicNumber) + assert.Error(t, err) + assert.EqualError(t, err, "trix: invalid magic number: expected XXXX, got TRIX") +} + +func TestEncode_InvalidMagicNumberLength(t *testing.T) { + header := map[string]interface{}{ + "content_type": "application/octet-stream", + } + payload := []byte("This is a secret message.") + + trix := &Trix{ + Header: header, + Payload: payload, + } + + magicNumber := "TOOLONG" + _, err := Encode(trix, magicNumber) + assert.Error(t, err) + assert.EqualError(t, err, "trix: magic number must be 4 bytes long") +} + +func TestDecode_InvalidMagicNumberLength(t *testing.T) { + header := map[string]interface{}{ + "content_type": "application/octet-stream", + } + payload := []byte("This is a secret message.") + + trix := &Trix{ + Header: header, + Payload: payload, + } + + magicNumber := "TRIX" + encoded, err := Encode(trix, magicNumber) + assert.NoError(t, err) + + invalidMagicNumber := "SHORT" + _, err = Decode(encoded, invalidMagicNumber) + assert.Error(t, err) + assert.EqualError(t, err, "trix: magic number must be 4 bytes long") +}