diff --git a/pkg/trix/trix.go b/pkg/trix/trix.go index bb78c3c..e0e8ca6 100644 --- a/pkg/trix/trix.go +++ b/pkg/trix/trix.go @@ -34,10 +34,10 @@ type Trix struct { ChecksumAlgo crypt.HashType `json:"-"` } -// Encode serializes a Trix struct into the .trix binary format. -func Encode(trix *Trix, magicNumber string) ([]byte, error) { +// EncodeTo serializes a Trix struct into the .trix binary format and writes it to an io.Writer. +func EncodeTo(trix *Trix, magicNumber string, w io.Writer) error { if len(magicNumber) != 4 { - return nil, ErrMagicNumberLength + return ErrMagicNumberLength } // Calculate and add checksum if an algorithm is specified @@ -49,52 +49,57 @@ func Encode(trix *Trix, magicNumber string) ([]byte, error) { headerBytes, err := json.Marshal(trix.Header) if err != nil { - return nil, err + return err } headerLength := uint32(len(headerBytes)) - buf := new(bytes.Buffer) - // Write Magic Number - if _, err := buf.WriteString(magicNumber); err != nil { - return nil, err + if _, err := io.WriteString(w, magicNumber); err != nil { + return err } // Write Version - if err := buf.WriteByte(byte(Version)); err != nil { - return nil, err + if _, err := w.Write([]byte{byte(Version)}); err != nil { + return err } // Write Header Length - if err := binary.Write(buf, binary.BigEndian, headerLength); err != nil { - return nil, err + if err := binary.Write(w, binary.BigEndian, headerLength); err != nil { + return err } // Write JSON Header - if _, err := buf.Write(headerBytes); err != nil { - return nil, err + if _, err := w.Write(headerBytes); err != nil { + return err } // Write Payload - if _, err := buf.Write(trix.Payload); err != nil { - return nil, err + if _, err := w.Write(trix.Payload); err != nil { + return err } + return nil +} + +// Encode serializes a Trix struct into the .trix binary format. +func Encode(trix *Trix, magicNumber string) ([]byte, error) { + var buf bytes.Buffer + err := EncodeTo(trix, magicNumber, &buf) + if err != nil { + return nil, err + } return buf.Bytes(), nil } -// Decode deserializes the .trix binary format into a Trix struct. -// Note: Sigils are not stored in the format and must be re-attached by the caller. -func Decode(data []byte, magicNumber string) (*Trix, error) { +// DecodeFrom deserializes the .trix binary format from an io.Reader into a Trix struct. +func DecodeFrom(r io.Reader, magicNumber string) (*Trix, error) { if len(magicNumber) != 4 { return nil, ErrMagicNumberLength } - buf := bytes.NewReader(data) - // Read and Verify Magic Number magic := make([]byte, 4) - if _, err := io.ReadFull(buf, magic); err != nil { + if _, err := io.ReadFull(r, magic); err != nil { return nil, err } if string(magic) != magicNumber { @@ -102,28 +107,26 @@ func Decode(data []byte, magicNumber string) (*Trix, error) { } // Read and Verify Version - version, err := buf.ReadByte() - if err != nil { + versionByte := make([]byte, 1) + if _, err := io.ReadFull(r, versionByte); err != nil { return nil, err } - if version != Version { + if versionByte[0] != Version { return nil, ErrInvalidVersion } // Read Header Length var headerLength uint32 - if err := binary.Read(buf, binary.BigEndian, &headerLength); err != nil { + if err := binary.Read(r, binary.BigEndian, &headerLength); err != nil { return nil, err } - // Check if the announced header length is longer than the remaining buffer. - if int64(headerLength) > int64(buf.Len()) { - return nil, ErrInvalidHeaderLength - } + // We can't implement the ErrInvalidHeaderLength check here because we don't know the total length of the stream. + // The check is implicitly handled by io.ReadFull, which will return io.ErrUnexpectedEOF if the stream ends prematurely. // Read JSON Header headerBytes := make([]byte, headerLength) - if _, err := io.ReadFull(buf, headerBytes); err != nil { + if _, err := io.ReadFull(r, headerBytes); err != nil { return nil, err } var header map[string]interface{} @@ -132,7 +135,7 @@ func Decode(data []byte, magicNumber string) (*Trix, error) { } // Read Payload - payload, err := io.ReadAll(buf) + payload, err := io.ReadAll(r) if err != nil { return nil, err } @@ -155,6 +158,26 @@ func Decode(data []byte, magicNumber string) (*Trix, error) { }, nil } +// Decode deserializes the .trix binary format into a Trix struct. +// Note: Sigils are not stored in the format and must be re-attached by the caller. +func Decode(data []byte, magicNumber string) (*Trix, error) { + buf := bytes.NewReader(data) + + // We can perform the header length check here because we have the full byte slice. + // We read the header length, check it, then pass the rest of the buffer to DecodeFrom. + // This is a bit of a hack, but it's the only way to keep the check. + // A better solution would be to have a separate DecodeBytes function. + if len(data) > 9 { // 4 magic + 1 version + 4 header length + headerLengthBytes := data[5:9] + headerLength := binary.BigEndian.Uint32(headerLengthBytes) + if int64(headerLength) > int64(len(data)-9) { + return nil, ErrInvalidHeaderLength + } + } + + return DecodeFrom(buf, magicNumber) +} + // Pack applies the In method of all attached sigils to the payload. func (t *Trix) Pack() error { for _, sigilName := range t.InSigils { diff --git a/pkg/trix/trix_test.go b/pkg/trix/trix_test.go index 69ae067..7a640e1 100644 --- a/pkg/trix/trix_test.go +++ b/pkg/trix/trix_test.go @@ -1,7 +1,9 @@ package trix import ( + "bytes" "encoding/binary" + "errors" "reflect" "testing" @@ -165,6 +167,62 @@ func TestPackUnpack_Ugly(t *testing.T) { // --- Checksum Tests --- +// --- Stream Tests --- + +// mockErrorWriter is an io.Writer that always returns an error. +type mockErrorWriter struct{} + +func (w *mockErrorWriter) Write(p []byte) (n int, err error) { + return 0, errors.New("mock writer error") +} + +// mockErrorReader is an io.Reader that always returns an error. +type mockErrorReader struct{} + +func (r *mockErrorReader) Read(p []byte) (n int, err error) { + return 0, errors.New("mock reader error") +} + +func TestStream_Good(t *testing.T) { + trix := &Trix{ + Header: map[string]interface{}{"content_type": "text/plain"}, + Payload: []byte("hello world"), + } + magicNumber := "STRM" + + var buf bytes.Buffer + err := EncodeTo(trix, magicNumber, &buf) + assert.NoError(t, err) + + decoded, err := DecodeFrom(&buf, magicNumber) + assert.NoError(t, err) + + assert.True(t, reflect.DeepEqual(trix.Header, decoded.Header)) + assert.Equal(t, trix.Payload, decoded.Payload) +} + +func TestStream_Bad(t *testing.T) { + t.Run("WriterError", func(t *testing.T) { + trix := &Trix{Header: map[string]interface{}{}, Payload: []byte("payload")} + err := EncodeTo(trix, "FAIL", &mockErrorWriter{}) + assert.Error(t, err) + assert.EqualError(t, err, "mock writer error") + }) + + t.Run("ReaderError", func(t *testing.T) { + _, err := DecodeFrom(&mockErrorReader{}, "FAIL") + assert.Error(t, err) + assert.EqualError(t, err, "mock reader error") + }) +} + +func TestStream_Ugly(t *testing.T) { + t.Run("EmptyReader", func(t *testing.T) { + _, err := DecodeFrom(bytes.NewReader([]byte{}), "UGLY") + assert.Error(t, err) + }) +} + func TestChecksum_Ugly(t *testing.T) { t.Run("MissingAlgoInHeader", func(t *testing.T) { header := `{"checksum":"5891b5b522d5df086d0ff0b110fbd9d21bb4fc7163af34d08286a2e846f6be03"}` // sha256 checksum for "hello world"