feat(collect): add streaming pipeline for STIM v2 output
Add CollectLocalStreaming that uses a streaming pipeline (walk -> tar -> compress -> encrypt -> file) via io.Pipe, avoiding buffering the entire dataset in memory. Add DecryptStimV2 for round-trip decryption back to DataNode. Wire streaming path into existing CollectLocal when format is "stim". Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
0b2ae3a0ba
commit
99284b472d
2 changed files with 407 additions and 0 deletions
|
|
@ -1,11 +1,15 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Snider/Borg/pkg/compress"
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
|
|
@ -100,6 +104,21 @@ func CollectLocal(directory string, outputFile string, format string, compressio
|
|||
return "", fmt.Errorf("not a directory: %s", absDir)
|
||||
}
|
||||
|
||||
// Use streaming pipeline for STIM v2 format
|
||||
if format == "stim" {
|
||||
if outputFile == "" {
|
||||
baseName := filepath.Base(absDir)
|
||||
if baseName == "." || baseName == "/" {
|
||||
baseName = "local"
|
||||
}
|
||||
outputFile = baseName + ".stim"
|
||||
}
|
||||
if err := CollectLocalStreaming(absDir, outputFile, compression, password); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return outputFile, nil
|
||||
}
|
||||
|
||||
// Load gitignore patterns if enabled
|
||||
var gitignorePatterns []string
|
||||
if respectGitignore {
|
||||
|
|
@ -331,3 +350,230 @@ func matchesExclude(path string, excludes []string) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CollectLocalStreaming collects files from a local directory using a streaming
|
||||
// pipeline: walk -> tar -> compress -> encrypt -> file.
|
||||
// The encryption runs in a goroutine, consuming from an io.Pipe that the
|
||||
// tar/compress writes feed into synchronously.
|
||||
func CollectLocalStreaming(dir, output, compression, password string) error {
|
||||
// Resolve to absolute path
|
||||
absDir, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error resolving directory path: %w", err)
|
||||
}
|
||||
|
||||
// Validate directory exists
|
||||
info, err := os.Stat(absDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error accessing directory: %w", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return fmt.Errorf("not a directory: %s", absDir)
|
||||
}
|
||||
|
||||
// Create output file
|
||||
outFile, err := os.Create(output)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating output file: %w", err)
|
||||
}
|
||||
|
||||
// cleanup removes partial output on error
|
||||
cleanup := func() {
|
||||
outFile.Close()
|
||||
os.Remove(output)
|
||||
}
|
||||
|
||||
// Build streaming pipeline:
|
||||
// tar.Writer -> compressWriter -> pipeWriter -> pipeReader -> StreamEncrypt -> outFile
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
// Start encryption goroutine
|
||||
var encErr error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
encErr = tim.StreamEncrypt(pr, outFile, password)
|
||||
}()
|
||||
|
||||
// Create compression writer wrapping the pipe writer
|
||||
compWriter, err := compress.NewCompressWriter(pw, compression)
|
||||
if err != nil {
|
||||
pw.Close()
|
||||
wg.Wait()
|
||||
cleanup()
|
||||
return fmt.Errorf("error creating compression writer: %w", err)
|
||||
}
|
||||
|
||||
// Create tar writer wrapping the compression writer
|
||||
tw := tar.NewWriter(compWriter)
|
||||
|
||||
// Walk directory and write tar entries
|
||||
walkErr := filepath.WalkDir(absDir, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get relative path
|
||||
relPath, err := filepath.Rel(absDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip root
|
||||
if relPath == "." {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Normalize to forward slashes for tar
|
||||
relPath = filepath.ToSlash(relPath)
|
||||
|
||||
// Check if entry is a symlink using Lstat
|
||||
linfo, err := os.Lstat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
isSymlink := linfo.Mode()&fs.ModeSymlink != 0
|
||||
|
||||
if isSymlink {
|
||||
// Read symlink target
|
||||
linkTarget, err := os.Readlink(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Resolve to check if target exists
|
||||
absTarget := linkTarget
|
||||
if !filepath.IsAbs(absTarget) {
|
||||
absTarget = filepath.Join(filepath.Dir(path), linkTarget)
|
||||
}
|
||||
_, statErr := os.Stat(absTarget)
|
||||
if statErr != nil {
|
||||
// Broken symlink - skip silently
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write valid symlink as tar entry
|
||||
hdr := &tar.Header{
|
||||
Typeflag: tar.TypeSymlink,
|
||||
Name: relPath,
|
||||
Linkname: linkTarget,
|
||||
Mode: 0777,
|
||||
}
|
||||
return tw.WriteHeader(hdr)
|
||||
}
|
||||
|
||||
if d.IsDir() {
|
||||
// Write directory header
|
||||
hdr := &tar.Header{
|
||||
Typeflag: tar.TypeDir,
|
||||
Name: relPath + "/",
|
||||
Mode: 0755,
|
||||
}
|
||||
return tw.WriteHeader(hdr)
|
||||
}
|
||||
|
||||
// Regular file: write header + content
|
||||
finfo, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hdr := &tar.Header{
|
||||
Name: relPath,
|
||||
Mode: 0644,
|
||||
Size: finfo.Size(),
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening %s: %w", relPath, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := io.Copy(tw, f); err != nil {
|
||||
return fmt.Errorf("error streaming %s: %w", relPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Close pipeline layers in order: tar -> compress -> pipe
|
||||
// We must close even on error to unblock the encryption goroutine.
|
||||
twCloseErr := tw.Close()
|
||||
compCloseErr := compWriter.Close()
|
||||
|
||||
if walkErr != nil {
|
||||
pw.CloseWithError(walkErr)
|
||||
wg.Wait()
|
||||
cleanup()
|
||||
return fmt.Errorf("error walking directory: %w", walkErr)
|
||||
}
|
||||
|
||||
if twCloseErr != nil {
|
||||
pw.CloseWithError(twCloseErr)
|
||||
wg.Wait()
|
||||
cleanup()
|
||||
return fmt.Errorf("error closing tar writer: %w", twCloseErr)
|
||||
}
|
||||
|
||||
if compCloseErr != nil {
|
||||
pw.CloseWithError(compCloseErr)
|
||||
wg.Wait()
|
||||
cleanup()
|
||||
return fmt.Errorf("error closing compression writer: %w", compCloseErr)
|
||||
}
|
||||
|
||||
// Signal EOF to encryption goroutine
|
||||
pw.Close()
|
||||
|
||||
// Wait for encryption to finish
|
||||
wg.Wait()
|
||||
|
||||
if encErr != nil {
|
||||
cleanup()
|
||||
return fmt.Errorf("error encrypting data: %w", encErr)
|
||||
}
|
||||
|
||||
// Close output file
|
||||
if err := outFile.Close(); err != nil {
|
||||
os.Remove(output)
|
||||
return fmt.Errorf("error closing output file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecryptStimV2 decrypts a STIM v2 file back into a DataNode.
|
||||
// It opens the file, runs StreamDecrypt, decompresses the result,
|
||||
// and parses the tar archive into a DataNode.
|
||||
func DecryptStimV2(path, password string) (*datanode.DataNode, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Decrypt
|
||||
var decrypted bytes.Buffer
|
||||
if err := tim.StreamDecrypt(f, &decrypted, password); err != nil {
|
||||
return nil, fmt.Errorf("error decrypting: %w", err)
|
||||
}
|
||||
|
||||
// Decompress
|
||||
decompressed, err := compress.Decompress(decrypted.Bytes())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decompressing: %w", err)
|
||||
}
|
||||
|
||||
// Parse tar into DataNode
|
||||
dn, err := datanode.FromTar(decompressed)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing tar: %w", err)
|
||||
}
|
||||
|
||||
return dn, nil
|
||||
}
|
||||
|
|
|
|||
161
cmd/collect_local_test.go
Normal file
161
cmd/collect_local_test.go
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCollectLocalStreaming_Good(t *testing.T) {
|
||||
// Create a temp directory with some test files
|
||||
srcDir := t.TempDir()
|
||||
outDir := t.TempDir()
|
||||
|
||||
// Create files in subdirectories
|
||||
subDir := filepath.Join(srcDir, "subdir")
|
||||
if err := os.MkdirAll(subDir, 0755); err != nil {
|
||||
t.Fatalf("failed to create subdir: %v", err)
|
||||
}
|
||||
|
||||
files := map[string]string{
|
||||
"hello.txt": "hello world",
|
||||
"subdir/nested.go": "package main\n",
|
||||
}
|
||||
for name, content := range files {
|
||||
path := filepath.Join(srcDir, name)
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to write %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
output := filepath.Join(outDir, "test.stim")
|
||||
err := CollectLocalStreaming(srcDir, output, "gz", "test-password")
|
||||
if err != nil {
|
||||
t.Fatalf("CollectLocalStreaming() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists and is non-empty
|
||||
info, err := os.Stat(output)
|
||||
if err != nil {
|
||||
t.Fatalf("output file does not exist: %v", err)
|
||||
}
|
||||
if info.Size() == 0 {
|
||||
t.Fatal("output file is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectLocalStreaming_Decrypt_Good(t *testing.T) {
|
||||
// Create a temp directory with known files
|
||||
srcDir := t.TempDir()
|
||||
outDir := t.TempDir()
|
||||
|
||||
subDir := filepath.Join(srcDir, "pkg")
|
||||
if err := os.MkdirAll(subDir, 0755); err != nil {
|
||||
t.Fatalf("failed to create subdir: %v", err)
|
||||
}
|
||||
|
||||
expectedFiles := map[string]string{
|
||||
"README.md": "# Test Project\n",
|
||||
"pkg/main.go": "package main\n\nfunc main() {}\n",
|
||||
}
|
||||
for name, content := range expectedFiles {
|
||||
path := filepath.Join(srcDir, name)
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to write %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
password := "decrypt-test-pw"
|
||||
output := filepath.Join(outDir, "roundtrip.stim")
|
||||
|
||||
// Collect
|
||||
err := CollectLocalStreaming(srcDir, output, "gz", password)
|
||||
if err != nil {
|
||||
t.Fatalf("CollectLocalStreaming() error = %v", err)
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
dn, err := DecryptStimV2(output, password)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptStimV2() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify each expected file exists in the DataNode
|
||||
for name, wantContent := range expectedFiles {
|
||||
f, err := dn.Open(name)
|
||||
if err != nil {
|
||||
t.Errorf("file %q not found in DataNode: %v", name, err)
|
||||
continue
|
||||
}
|
||||
buf := make([]byte, 4096)
|
||||
n, _ := f.Read(buf)
|
||||
f.Close()
|
||||
got := string(buf[:n])
|
||||
if got != wantContent {
|
||||
t.Errorf("file %q content mismatch:\n got: %q\n want: %q", name, got, wantContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectLocalStreaming_BrokenSymlink_Good(t *testing.T) {
|
||||
srcDir := t.TempDir()
|
||||
outDir := t.TempDir()
|
||||
|
||||
// Create a regular file
|
||||
if err := os.WriteFile(filepath.Join(srcDir, "real.txt"), []byte("I exist"), 0644); err != nil {
|
||||
t.Fatalf("failed to write real.txt: %v", err)
|
||||
}
|
||||
|
||||
// Create a broken symlink pointing to a nonexistent target
|
||||
brokenLink := filepath.Join(srcDir, "broken-link")
|
||||
if err := os.Symlink("/nonexistent/target/file", brokenLink); err != nil {
|
||||
t.Fatalf("failed to create broken symlink: %v", err)
|
||||
}
|
||||
|
||||
output := filepath.Join(outDir, "symlink.stim")
|
||||
err := CollectLocalStreaming(srcDir, output, "none", "sym-password")
|
||||
if err != nil {
|
||||
t.Fatalf("CollectLocalStreaming() should skip broken symlinks, got error = %v", err)
|
||||
}
|
||||
|
||||
// Verify output exists and is non-empty
|
||||
info, err := os.Stat(output)
|
||||
if err != nil {
|
||||
t.Fatalf("output file does not exist: %v", err)
|
||||
}
|
||||
if info.Size() == 0 {
|
||||
t.Fatal("output file is empty")
|
||||
}
|
||||
|
||||
// Decrypt and verify the broken symlink was skipped
|
||||
dn, err := DecryptStimV2(output, "sym-password")
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptStimV2() error = %v", err)
|
||||
}
|
||||
|
||||
// real.txt should be present
|
||||
if _, err := dn.Stat("real.txt"); err != nil {
|
||||
t.Error("expected real.txt in DataNode but it's missing")
|
||||
}
|
||||
|
||||
// broken-link should NOT be present
|
||||
exists, _ := dn.Exists("broken-link")
|
||||
if exists {
|
||||
t.Error("broken symlink should have been skipped but was found in DataNode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectLocalStreaming_Bad(t *testing.T) {
|
||||
outDir := t.TempDir()
|
||||
output := filepath.Join(outDir, "should-not-exist.stim")
|
||||
|
||||
err := CollectLocalStreaming("/nonexistent/path/that/does/not/exist", output, "none", "password")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent directory, got nil")
|
||||
}
|
||||
|
||||
// Verify no partial output file was left behind
|
||||
if _, statErr := os.Stat(output); statErr == nil {
|
||||
t.Error("partial output file should have been cleaned up")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue