diff --git a/cmd/collect_local.go b/cmd/collect_local.go index 4e88a95..998b15a 100644 --- a/cmd/collect_local.go +++ b/cmd/collect_local.go @@ -1,11 +1,15 @@ package cmd import ( + "archive/tar" + "bytes" "fmt" + "io" "io/fs" "os" "path/filepath" "strings" + "sync" "github.com/Snider/Borg/pkg/compress" "github.com/Snider/Borg/pkg/datanode" @@ -100,6 +104,21 @@ func CollectLocal(directory string, outputFile string, format string, compressio return "", fmt.Errorf("not a directory: %s", absDir) } + // Use streaming pipeline for STIM v2 format + if format == "stim" { + if outputFile == "" { + baseName := filepath.Base(absDir) + if baseName == "." || baseName == "/" { + baseName = "local" + } + outputFile = baseName + ".stim" + } + if err := CollectLocalStreaming(absDir, outputFile, compression, password); err != nil { + return "", err + } + return outputFile, nil + } + // Load gitignore patterns if enabled var gitignorePatterns []string if respectGitignore { @@ -331,3 +350,230 @@ func matchesExclude(path string, excludes []string) bool { } return false } + +// CollectLocalStreaming collects files from a local directory using a streaming +// pipeline: walk -> tar -> compress -> encrypt -> file. +// The encryption runs in a goroutine, consuming from an io.Pipe that the +// tar/compress writes feed into synchronously. +func CollectLocalStreaming(dir, output, compression, password string) error { + // Resolve to absolute path + absDir, err := filepath.Abs(dir) + if err != nil { + return fmt.Errorf("error resolving directory path: %w", err) + } + + // Validate directory exists + info, err := os.Stat(absDir) + if err != nil { + return fmt.Errorf("error accessing directory: %w", err) + } + if !info.IsDir() { + return fmt.Errorf("not a directory: %s", absDir) + } + + // Create output file + outFile, err := os.Create(output) + if err != nil { + return fmt.Errorf("error creating output file: %w", err) + } + + // cleanup removes partial output on error + cleanup := func() { + outFile.Close() + os.Remove(output) + } + + // Build streaming pipeline: + // tar.Writer -> compressWriter -> pipeWriter -> pipeReader -> StreamEncrypt -> outFile + pr, pw := io.Pipe() + + // Start encryption goroutine + var encErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + encErr = tim.StreamEncrypt(pr, outFile, password) + }() + + // Create compression writer wrapping the pipe writer + compWriter, err := compress.NewCompressWriter(pw, compression) + if err != nil { + pw.Close() + wg.Wait() + cleanup() + return fmt.Errorf("error creating compression writer: %w", err) + } + + // Create tar writer wrapping the compression writer + tw := tar.NewWriter(compWriter) + + // Walk directory and write tar entries + walkErr := filepath.WalkDir(absDir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + // Get relative path + relPath, err := filepath.Rel(absDir, path) + if err != nil { + return err + } + + // Skip root + if relPath == "." { + return nil + } + + // Normalize to forward slashes for tar + relPath = filepath.ToSlash(relPath) + + // Check if entry is a symlink using Lstat + linfo, err := os.Lstat(path) + if err != nil { + return err + } + isSymlink := linfo.Mode()&fs.ModeSymlink != 0 + + if isSymlink { + // Read symlink target + linkTarget, err := os.Readlink(path) + if err != nil { + return err + } + + // Resolve to check if target exists + absTarget := linkTarget + if !filepath.IsAbs(absTarget) { + absTarget = filepath.Join(filepath.Dir(path), linkTarget) + } + _, statErr := os.Stat(absTarget) + if statErr != nil { + // Broken symlink - skip silently + return nil + } + + // Write valid symlink as tar entry + hdr := &tar.Header{ + Typeflag: tar.TypeSymlink, + Name: relPath, + Linkname: linkTarget, + Mode: 0777, + } + return tw.WriteHeader(hdr) + } + + if d.IsDir() { + // Write directory header + hdr := &tar.Header{ + Typeflag: tar.TypeDir, + Name: relPath + "/", + Mode: 0755, + } + return tw.WriteHeader(hdr) + } + + // Regular file: write header + content + finfo, err := d.Info() + if err != nil { + return err + } + + hdr := &tar.Header{ + Name: relPath, + Mode: 0644, + Size: finfo.Size(), + } + if err := tw.WriteHeader(hdr); err != nil { + return err + } + + f, err := os.Open(path) + if err != nil { + return fmt.Errorf("error opening %s: %w", relPath, err) + } + defer f.Close() + + if _, err := io.Copy(tw, f); err != nil { + return fmt.Errorf("error streaming %s: %w", relPath, err) + } + + return nil + }) + + // Close pipeline layers in order: tar -> compress -> pipe + // We must close even on error to unblock the encryption goroutine. + twCloseErr := tw.Close() + compCloseErr := compWriter.Close() + + if walkErr != nil { + pw.CloseWithError(walkErr) + wg.Wait() + cleanup() + return fmt.Errorf("error walking directory: %w", walkErr) + } + + if twCloseErr != nil { + pw.CloseWithError(twCloseErr) + wg.Wait() + cleanup() + return fmt.Errorf("error closing tar writer: %w", twCloseErr) + } + + if compCloseErr != nil { + pw.CloseWithError(compCloseErr) + wg.Wait() + cleanup() + return fmt.Errorf("error closing compression writer: %w", compCloseErr) + } + + // Signal EOF to encryption goroutine + pw.Close() + + // Wait for encryption to finish + wg.Wait() + + if encErr != nil { + cleanup() + return fmt.Errorf("error encrypting data: %w", encErr) + } + + // Close output file + if err := outFile.Close(); err != nil { + os.Remove(output) + return fmt.Errorf("error closing output file: %w", err) + } + + return nil +} + +// DecryptStimV2 decrypts a STIM v2 file back into a DataNode. +// It opens the file, runs StreamDecrypt, decompresses the result, +// and parses the tar archive into a DataNode. +func DecryptStimV2(path, password string) (*datanode.DataNode, error) { + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("error opening file: %w", err) + } + defer f.Close() + + // Decrypt + var decrypted bytes.Buffer + if err := tim.StreamDecrypt(f, &decrypted, password); err != nil { + return nil, fmt.Errorf("error decrypting: %w", err) + } + + // Decompress + decompressed, err := compress.Decompress(decrypted.Bytes()) + if err != nil { + return nil, fmt.Errorf("error decompressing: %w", err) + } + + // Parse tar into DataNode + dn, err := datanode.FromTar(decompressed) + if err != nil { + return nil, fmt.Errorf("error parsing tar: %w", err) + } + + return dn, nil +} diff --git a/cmd/collect_local_test.go b/cmd/collect_local_test.go new file mode 100644 index 0000000..28bd638 --- /dev/null +++ b/cmd/collect_local_test.go @@ -0,0 +1,161 @@ +package cmd + +import ( + "os" + "path/filepath" + "testing" +) + +func TestCollectLocalStreaming_Good(t *testing.T) { + // Create a temp directory with some test files + srcDir := t.TempDir() + outDir := t.TempDir() + + // Create files in subdirectories + subDir := filepath.Join(srcDir, "subdir") + if err := os.MkdirAll(subDir, 0755); err != nil { + t.Fatalf("failed to create subdir: %v", err) + } + + files := map[string]string{ + "hello.txt": "hello world", + "subdir/nested.go": "package main\n", + } + for name, content := range files { + path := filepath.Join(srcDir, name) + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("failed to write %s: %v", name, err) + } + } + + output := filepath.Join(outDir, "test.stim") + err := CollectLocalStreaming(srcDir, output, "gz", "test-password") + if err != nil { + t.Fatalf("CollectLocalStreaming() error = %v", err) + } + + // Verify file exists and is non-empty + info, err := os.Stat(output) + if err != nil { + t.Fatalf("output file does not exist: %v", err) + } + if info.Size() == 0 { + t.Fatal("output file is empty") + } +} + +func TestCollectLocalStreaming_Decrypt_Good(t *testing.T) { + // Create a temp directory with known files + srcDir := t.TempDir() + outDir := t.TempDir() + + subDir := filepath.Join(srcDir, "pkg") + if err := os.MkdirAll(subDir, 0755); err != nil { + t.Fatalf("failed to create subdir: %v", err) + } + + expectedFiles := map[string]string{ + "README.md": "# Test Project\n", + "pkg/main.go": "package main\n\nfunc main() {}\n", + } + for name, content := range expectedFiles { + path := filepath.Join(srcDir, name) + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("failed to write %s: %v", name, err) + } + } + + password := "decrypt-test-pw" + output := filepath.Join(outDir, "roundtrip.stim") + + // Collect + err := CollectLocalStreaming(srcDir, output, "gz", password) + if err != nil { + t.Fatalf("CollectLocalStreaming() error = %v", err) + } + + // Decrypt + dn, err := DecryptStimV2(output, password) + if err != nil { + t.Fatalf("DecryptStimV2() error = %v", err) + } + + // Verify each expected file exists in the DataNode + for name, wantContent := range expectedFiles { + f, err := dn.Open(name) + if err != nil { + t.Errorf("file %q not found in DataNode: %v", name, err) + continue + } + buf := make([]byte, 4096) + n, _ := f.Read(buf) + f.Close() + got := string(buf[:n]) + if got != wantContent { + t.Errorf("file %q content mismatch:\n got: %q\n want: %q", name, got, wantContent) + } + } +} + +func TestCollectLocalStreaming_BrokenSymlink_Good(t *testing.T) { + srcDir := t.TempDir() + outDir := t.TempDir() + + // Create a regular file + if err := os.WriteFile(filepath.Join(srcDir, "real.txt"), []byte("I exist"), 0644); err != nil { + t.Fatalf("failed to write real.txt: %v", err) + } + + // Create a broken symlink pointing to a nonexistent target + brokenLink := filepath.Join(srcDir, "broken-link") + if err := os.Symlink("/nonexistent/target/file", brokenLink); err != nil { + t.Fatalf("failed to create broken symlink: %v", err) + } + + output := filepath.Join(outDir, "symlink.stim") + err := CollectLocalStreaming(srcDir, output, "none", "sym-password") + if err != nil { + t.Fatalf("CollectLocalStreaming() should skip broken symlinks, got error = %v", err) + } + + // Verify output exists and is non-empty + info, err := os.Stat(output) + if err != nil { + t.Fatalf("output file does not exist: %v", err) + } + if info.Size() == 0 { + t.Fatal("output file is empty") + } + + // Decrypt and verify the broken symlink was skipped + dn, err := DecryptStimV2(output, "sym-password") + if err != nil { + t.Fatalf("DecryptStimV2() error = %v", err) + } + + // real.txt should be present + if _, err := dn.Stat("real.txt"); err != nil { + t.Error("expected real.txt in DataNode but it's missing") + } + + // broken-link should NOT be present + exists, _ := dn.Exists("broken-link") + if exists { + t.Error("broken symlink should have been skipped but was found in DataNode") + } +} + +func TestCollectLocalStreaming_Bad(t *testing.T) { + outDir := t.TempDir() + output := filepath.Join(outDir, "should-not-exist.stim") + + err := CollectLocalStreaming("/nonexistent/path/that/does/not/exist", output, "none", "password") + if err == nil { + t.Fatal("expected error for nonexistent directory, got nil") + } + + // Verify no partial output file was left behind + if _, statErr := os.Stat(output); statErr == nil { + t.Error("partial output file should have been cleaned up") + } +}