diff --git a/pkg/compress/compress.go b/pkg/compress/compress.go index 07e4d28..6322097 100644 --- a/pkg/compress/compress.go +++ b/pkg/compress/compress.go @@ -3,11 +3,34 @@ package compress import ( "bytes" "compress/gzip" + "fmt" "io" "github.com/ulikunitz/xz" ) +// nopCloser wraps an io.Writer with a no-op Close method. +type nopCloser struct{ io.Writer } + +func (n *nopCloser) Close() error { return nil } + +// NewCompressWriter returns a streaming io.WriteCloser that compresses data +// written to it into the underlying writer w using the specified format. +// Supported formats: "gz" (gzip), "xz", "none" or "" (passthrough). +// Unknown formats return an error. +func NewCompressWriter(w io.Writer, format string) (io.WriteCloser, error) { + switch format { + case "gz": + return gzip.NewWriter(w), nil + case "xz": + return xz.NewWriter(w) + case "none", "": + return &nopCloser{w}, nil + default: + return nil, fmt.Errorf("unsupported compression format: %q", format) + } +} + // Compress compresses data using the specified format. func Compress(data []byte, format string) ([]byte, error) { var buf bytes.Buffer diff --git a/pkg/compress/compress_test.go b/pkg/compress/compress_test.go index 489bff4..3d9d75c 100644 --- a/pkg/compress/compress_test.go +++ b/pkg/compress/compress_test.go @@ -5,6 +5,108 @@ import ( "testing" ) +func TestNewCompressWriter_Gzip_Good(t *testing.T) { + original := []byte("hello, streaming gzip world") + var buf bytes.Buffer + + w, err := NewCompressWriter(&buf, "gz") + if err != nil { + t.Fatalf("NewCompressWriter(gz) error: %v", err) + } + if _, err := w.Write(original); err != nil { + t.Fatalf("Write error: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("Close error: %v", err) + } + + compressed := buf.Bytes() + if bytes.Equal(original, compressed) { + t.Fatal("compressed data should differ from original") + } + + decompressed, err := Decompress(compressed) + if err != nil { + t.Fatalf("Decompress error: %v", err) + } + if !bytes.Equal(original, decompressed) { + t.Errorf("round-trip mismatch: got %q, want %q", decompressed, original) + } +} + +func TestNewCompressWriter_Xz_Good(t *testing.T) { + original := []byte("hello, streaming xz world") + var buf bytes.Buffer + + w, err := NewCompressWriter(&buf, "xz") + if err != nil { + t.Fatalf("NewCompressWriter(xz) error: %v", err) + } + if _, err := w.Write(original); err != nil { + t.Fatalf("Write error: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("Close error: %v", err) + } + + compressed := buf.Bytes() + if bytes.Equal(original, compressed) { + t.Fatal("compressed data should differ from original") + } + + decompressed, err := Decompress(compressed) + if err != nil { + t.Fatalf("Decompress error: %v", err) + } + if !bytes.Equal(original, decompressed) { + t.Errorf("round-trip mismatch: got %q, want %q", decompressed, original) + } +} + +func TestNewCompressWriter_None_Good(t *testing.T) { + original := []byte("hello, passthrough world") + var buf bytes.Buffer + + w, err := NewCompressWriter(&buf, "none") + if err != nil { + t.Fatalf("NewCompressWriter(none) error: %v", err) + } + if _, err := w.Write(original); err != nil { + t.Fatalf("Write error: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("Close error: %v", err) + } + + if !bytes.Equal(original, buf.Bytes()) { + t.Errorf("passthrough mismatch: got %q, want %q", buf.Bytes(), original) + } + + // Also test empty string format + var buf2 bytes.Buffer + w2, err := NewCompressWriter(&buf2, "") + if err != nil { + t.Fatalf("NewCompressWriter('') error: %v", err) + } + if _, err := w2.Write(original); err != nil { + t.Fatalf("Write error: %v", err) + } + if err := w2.Close(); err != nil { + t.Fatalf("Close error: %v", err) + } + if !bytes.Equal(original, buf2.Bytes()) { + t.Errorf("passthrough (empty string) mismatch: got %q, want %q", buf2.Bytes(), original) + } +} + +func TestNewCompressWriter_Bad(t *testing.T) { + var buf bytes.Buffer + _, err := NewCompressWriter(&buf, "invalid-format") + if err == nil { + t.Fatal("expected error for unknown compression format, got nil") + } +} + func TestGzip_Good(t *testing.T) { originalData := []byte("hello, gzip world") compressed, err := Compress(originalData, "gz")