diff --git a/pkg/enchantrix/sigils.go b/pkg/enchantrix/sigils.go index 60c0391..6fc5cfa 100644 --- a/pkg/enchantrix/sigils.go +++ b/pkg/enchantrix/sigils.go @@ -73,12 +73,18 @@ func (s *Base64Sigil) Out(data []byte) ([]byte, error) { } // GzipSigil is a Sigil that compresses/decompresses data using gzip. -type GzipSigil struct{} +type GzipSigil struct { + writer io.Writer +} // In compresses the data using gzip. func (s *GzipSigil) In(data []byte) ([]byte, error) { var b bytes.Buffer - gz := gzip.NewWriter(&b) + w := s.writer + if w == nil { + w = &b + } + gz := gzip.NewWriter(w) if _, err := gz.Write(data); err != nil { return nil, err } diff --git a/pkg/enchantrix/sigils_test.go b/pkg/enchantrix/sigils_test.go index defb616..c191bd2 100644 --- a/pkg/enchantrix/sigils_test.go +++ b/pkg/enchantrix/sigils_test.go @@ -2,11 +2,32 @@ package enchantrix import ( "encoding/hex" + "errors" "testing" "github.com/stretchr/testify/assert" ) +// mockWriter is a writer that fails on Write +type mockWriter struct{} + +func (m *mockWriter) Write(p []byte) (n int, err error) { + return 0, errors.New("write error") +} + +// failOnSecondWrite is a writer that fails on the second write call. +type failOnSecondWrite struct { + callCount int +} + +func (m *failOnSecondWrite) Write(p []byte) (n int, err error) { + m.callCount++ + if m.callCount > 1 { + return 0, errors.New("second write failed") + } + return len(p), nil +} + func TestReverseSigil(t *testing.T) { s := &ReverseSigil{} data := []byte("hello") @@ -67,6 +88,16 @@ func TestGzipSigil(t *testing.T) { // Bad - invalid gzip data _, err = s.Out([]byte("not gzip")) assert.Error(t, err) + + // Test writer error + s.writer = &mockWriter{} + _, err = s.In(data) + assert.Error(t, err) + + // Test closer error + s.writer = &failOnSecondWrite{} + _, err = s.In(data) + assert.Error(t, err) } func TestJSONSigil(t *testing.T) { @@ -83,6 +114,11 @@ func TestJSONSigil(t *testing.T) { // Bad - invalid json _, err = s.In([]byte("not json")) assert.Error(t, err) + + // Out is a no-op, so it should return the data as-is + outData, err := s.Out(data) + assert.NoError(t, err) + assert.Equal(t, data, outData) } func TestHashSigils_Good(t *testing.T) {