diff --git a/pkg/trix/trix_test.go b/pkg/trix/trix_test.go index c2c695f..a89c2fd 100644 --- a/pkg/trix/trix_test.go +++ b/pkg/trix/trix_test.go @@ -1,7 +1,9 @@ package trix_test import ( + "bytes" "errors" + "fmt" "io" "reflect" "testing" @@ -11,32 +13,33 @@ import ( "github.com/stretchr/testify/assert" ) -// mockReader is an io.Reader that fails on demand. -type mockReader struct { - readErr error +// failWriter is an io.Writer that fails on the nth write call. +type failWriter struct { + failOnCall int + callCount int } -func (m *mockReader) Read(p []byte) (n int, err error) { - if m.readErr != nil { - return 0, m.readErr - } - // Simulate a successful read by filling the buffer with zeros. - for i := range p { - p[i] = 0 +func (m *failWriter) Write(p []byte) (n int, err error) { + m.callCount++ + if m.callCount == m.failOnCall { + return 0, errors.New("write error") } return len(p), nil } -// mockWriter is an io.Writer that fails on demand. -type mockWriter struct { - writeErr error +// failReader is an io.Reader that fails on the nth read call. +type failReader struct { + failOnCall int + callCount int + reader io.Reader } -func (m *mockWriter) Write(p []byte) (n int, err error) { - if m.writeErr != nil { - return 0, m.writeErr +func (m *failReader) Read(p []byte) (n int, err error) { + m.callCount++ + if m.callCount == m.failOnCall { + return 0, errors.New("read error") } - return len(p), nil + return m.reader.Read(p) } // TestTrixEncodeDecode_Good tests the ideal "happy path" scenario for encoding and decoding. @@ -280,19 +283,57 @@ func FuzzDecode(f *testing.F) { }) } -func TestTrixEncodeDecode_IOErrors(t *testing.T) { - t.Run("EncodeWriteError", func(t *testing.T) { - trixOb := &trix.Trix{Header: map[string]interface{}{}, Payload: []byte("payload")} - _, err := trix.Encode(trixOb, "TRIX", &mockWriter{writeErr: errors.New("write error")}) +func TestEncode_WriteErrors(t *testing.T) { + trixOb := &trix.Trix{Header: map[string]interface{}{}, Payload: []byte("payload")} + + for i := 1; i <= 5; i++ { + t.Run(fmt.Sprintf("fail on write call %d", i), func(t *testing.T) { + writer := &failWriter{failOnCall: i} + _, err := trix.Encode(trixOb, "TRIX", writer) + assert.Error(t, err) + }) + } + + // Test for successful return with external writer + t.Run("SuccessfulExternalWrite", func(t *testing.T) { + writer := &failWriter{} + _, err := trix.Encode(trixOb, "TRIX", writer) + assert.NoError(t, err) + }) +} + +func TestDecode_ReadErrors(t *testing.T) { + trixOb := &trix.Trix{Header: map[string]interface{}{}, Payload: []byte("payload")} + encoded, err := trix.Encode(trixOb, "TRIX", nil) + assert.NoError(t, err) + + for i := 1; i <= 5; i++ { + t.Run(fmt.Sprintf("fail on read call %d", i), func(t *testing.T) { + reader := &failReader{failOnCall: i, reader: bytes.NewReader(encoded)} + _, err := trix.Decode(encoded, "TRIX", reader) + assert.Error(t, err) + }) + } + + t.Run("JSONUnmarshalError", func(t *testing.T) { + // Manually construct a byte slice with an invalid JSON header. + var buf []byte + buf = append(buf, []byte("TRIX")...) + buf = append(buf, byte(trix.Version)) + buf = append(buf, []byte{0, 0, 0, 5}...) + buf = append(buf, []byte("{")...) + buf = append(buf, []byte("payload")...) + + _, err := trix.Decode(buf, "TRIX", nil) assert.Error(t, err) }) - t.Run("DecodeReadError", func(t *testing.T) { - trixOb := &trix.Trix{Header: map[string]interface{}{}, Payload: []byte("payload")} + t.Run("ChecksumMissingAlgo", func(t *testing.T) { + trixOb := &trix.Trix{Header: map[string]interface{}{"checksum": "abc"}, Payload: []byte("payload")} encoded, err := trix.Encode(trixOb, "TRIX", nil) assert.NoError(t, err) - _, err = trix.Decode(encoded, "TRIX", &mockReader{readErr: errors.New("read error")}) + _, err = trix.Decode(encoded, "TRIX", nil) assert.Error(t, err) }) }