Compare commits
1 commit
main
...
feat/failu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
46ffec7071 |
31 changed files with 498 additions and 4802 deletions
|
|
@ -11,11 +11,14 @@ func init() {
|
|||
RootCmd.AddCommand(GetCollectCmd())
|
||||
}
|
||||
func NewCollectCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
cmd := &cobra.Command{
|
||||
Use: "collect",
|
||||
Short: "Collect a resource from a URI.",
|
||||
Long: `Collect a resource from a URI and store it in a DataNode.`,
|
||||
}
|
||||
cmd.PersistentFlags().String("on-failure", "continue", "Action to take on failure: continue, stop, prompt")
|
||||
cmd.PersistentFlags().String("failures-dir", ".borg-failures", "Directory to store failure reports")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func GetCollectCmd() *cobra.Command {
|
||||
|
|
|
|||
|
|
@ -37,81 +37,7 @@ func NewCollectGithubRepoCmd() *cobra.Command {
|
|||
compression, _ := cmd.Flags().GetString("compression")
|
||||
password, _ := cmd.Flags().GetString("password")
|
||||
|
||||
if format != "datanode" && format != "tim" && format != "trix" && format != "stim" {
|
||||
return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', 'trix', or 'stim')", format)
|
||||
}
|
||||
if compression != "none" && compression != "gz" && compression != "xz" {
|
||||
return fmt.Errorf("invalid compression: %s (must be 'none', 'gz', or 'xz')", compression)
|
||||
}
|
||||
|
||||
prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote)
|
||||
prompter.Start()
|
||||
defer prompter.Stop()
|
||||
|
||||
var progressWriter io.Writer
|
||||
if prompter.IsInteractive() {
|
||||
bar := ui.NewProgressBar(-1, "Cloning repository")
|
||||
progressWriter = ui.NewProgressWriter(bar)
|
||||
}
|
||||
|
||||
dn, err := GitCloner.CloneGitRepository(repoURL, progressWriter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error cloning repository: %w", err)
|
||||
}
|
||||
|
||||
var data []byte
|
||||
if format == "tim" {
|
||||
t, err := tim.FromDataNode(dn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating tim: %w", err)
|
||||
}
|
||||
data, err = t.ToTar()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error serializing tim: %w", err)
|
||||
}
|
||||
} else if format == "stim" {
|
||||
if password == "" {
|
||||
return fmt.Errorf("password required for stim format")
|
||||
}
|
||||
t, err := tim.FromDataNode(dn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating tim: %w", err)
|
||||
}
|
||||
data, err = t.ToSigil(password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error encrypting stim: %w", err)
|
||||
}
|
||||
} else if format == "trix" {
|
||||
data, err = trix.ToTrix(dn, password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error serializing trix: %w", err)
|
||||
}
|
||||
} else {
|
||||
data, err = dn.ToTar()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error serializing DataNode: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
compressedData, err := compress.Compress(data, compression)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error compressing data: %w", err)
|
||||
}
|
||||
|
||||
if outputFile == "" {
|
||||
outputFile = "repo." + format
|
||||
if compression != "none" {
|
||||
outputFile += "." + compression
|
||||
}
|
||||
}
|
||||
|
||||
err = os.WriteFile(outputFile, compressedData, defaultFilePermission)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing DataNode to file: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintln(cmd.OutOrStdout(), "Repository saved to", outputFile)
|
||||
return nil
|
||||
return collectRepo(repoURL, outputFile, format, compression, password, cmd)
|
||||
},
|
||||
}
|
||||
cmd.Flags().String("output", "", "Output file for the DataNode")
|
||||
|
|
@ -121,6 +47,84 @@ func NewCollectGithubRepoCmd() *cobra.Command {
|
|||
return cmd
|
||||
}
|
||||
|
||||
func collectRepo(repoURL, outputFile, format, compression, password string, cmd *cobra.Command) error {
|
||||
if format != "datanode" && format != "tim" && format != "trix" && format != "stim" {
|
||||
return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', 'trix', or 'stim')", format)
|
||||
}
|
||||
if compression != "none" && compression != "gz" && compression != "xz" {
|
||||
return fmt.Errorf("invalid compression: %s (must be 'none', 'gz', or 'xz')", compression)
|
||||
}
|
||||
|
||||
prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote)
|
||||
prompter.Start()
|
||||
defer prompter.Stop()
|
||||
|
||||
var progressWriter io.Writer
|
||||
if prompter.IsInteractive() {
|
||||
bar := ui.NewProgressBar(-1, "Cloning repository")
|
||||
progressWriter = ui.NewProgressWriter(bar)
|
||||
}
|
||||
|
||||
dn, err := GitCloner.CloneGitRepository(repoURL, progressWriter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error cloning repository: %w", err)
|
||||
}
|
||||
|
||||
var data []byte
|
||||
if format == "tim" {
|
||||
t, err := tim.FromDataNode(dn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating tim: %w", err)
|
||||
}
|
||||
data, err = t.ToTar()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error serializing tim: %w", err)
|
||||
}
|
||||
} else if format == "stim" {
|
||||
if password == "" {
|
||||
return fmt.Errorf("password required for stim format")
|
||||
}
|
||||
t, err := tim.FromDataNode(dn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating tim: %w", err)
|
||||
}
|
||||
data, err = t.ToSigil(password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error encrypting stim: %w", err)
|
||||
}
|
||||
} else if format == "trix" {
|
||||
data, err = trix.ToTrix(dn, password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error serializing trix: %w", err)
|
||||
}
|
||||
} else {
|
||||
data, err = dn.ToTar()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error serializing DataNode: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
compressedData, err := compress.Compress(data, compression)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error compressing data: %w", err)
|
||||
}
|
||||
|
||||
if outputFile == "" {
|
||||
outputFile = "repo." + format
|
||||
if compression != "none" {
|
||||
outputFile += "." + compression
|
||||
}
|
||||
}
|
||||
|
||||
err = os.WriteFile(outputFile, compressedData, defaultFilePermission)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing DataNode to file: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintln(cmd.OutOrStdout(), "Repository saved to", outputFile)
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
collectGithubCmd.AddCommand(NewCollectGithubRepoCmd())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ package cmd
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Snider/Borg/pkg/failures"
|
||||
"github.com/Snider/Borg/pkg/github"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
|
@ -17,13 +19,57 @@ var collectGithubReposCmd = &cobra.Command{
|
|||
Short: "Collects all public repositories for a user or organization",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
failuresDir, _ := cmd.Flags().GetString("failures-dir")
|
||||
onFailure, _ := cmd.Flags().GetString("on-failure")
|
||||
|
||||
manager, err := failures.NewManager(failuresDir, "github:repos:"+args[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create failure manager: %w", err)
|
||||
}
|
||||
defer manager.Finalize()
|
||||
|
||||
repos, err := GithubClient.GetPublicRepos(cmd.Context(), args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, repo := range repos {
|
||||
fmt.Fprintln(cmd.OutOrStdout(), repo)
|
||||
|
||||
manager.SetTotal(len(repos))
|
||||
|
||||
attempts := make(map[string]int)
|
||||
for i := 0; i < len(repos); i++ {
|
||||
repo := repos[i]
|
||||
attempts[repo]++
|
||||
|
||||
fmt.Fprintln(cmd.OutOrStdout(), "Collecting", repo)
|
||||
err := collectRepo(repo, "", "datanode", "none", "", cmd)
|
||||
if err != nil {
|
||||
retryable := !strings.Contains(err.Error(), "not found")
|
||||
manager.RecordFailure(&failures.Failure{
|
||||
URL: repo,
|
||||
Error: err.Error(),
|
||||
Retryable: retryable,
|
||||
Attempts: attempts[repo],
|
||||
})
|
||||
|
||||
if onFailure == "stop" {
|
||||
return fmt.Errorf("stopping on first failure: %w", err)
|
||||
} else if onFailure == "prompt" {
|
||||
fmt.Printf("Failed to collect %s. Would you like to (c)ontinue, (s)top, or (r)etry? ", repo)
|
||||
var response string
|
||||
fmt.Scanln(&response)
|
||||
switch response {
|
||||
case "s":
|
||||
return fmt.Errorf("stopping on user prompt")
|
||||
case "r":
|
||||
i-- // Retry the same repo
|
||||
continue
|
||||
default:
|
||||
// Continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,581 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Snider/Borg/pkg/compress"
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Borg/pkg/tim"
|
||||
"github.com/Snider/Borg/pkg/trix"
|
||||
"github.com/Snider/Borg/pkg/ui"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type CollectLocalCmd struct {
|
||||
cobra.Command
|
||||
}
|
||||
|
||||
// NewCollectLocalCmd creates a new collect local command
|
||||
func NewCollectLocalCmd() *CollectLocalCmd {
|
||||
c := &CollectLocalCmd{}
|
||||
c.Command = cobra.Command{
|
||||
Use: "local [directory]",
|
||||
Short: "Collect files from a local directory",
|
||||
Long: `Collect local files into a portable container.
|
||||
|
||||
For STIM format, uses streaming I/O — memory usage is constant
|
||||
(~2 MiB) regardless of input directory size. Other formats
|
||||
(datanode, tim, trix) load files into memory.
|
||||
|
||||
Examples:
|
||||
borg collect local
|
||||
borg collect local ./src
|
||||
borg collect local /path/to/project --output project.tar
|
||||
borg collect local . --format stim --password secret
|
||||
borg collect local . --exclude "*.log" --exclude "node_modules"`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
directory := "."
|
||||
if len(args) > 0 {
|
||||
directory = args[0]
|
||||
}
|
||||
|
||||
outputFile, _ := cmd.Flags().GetString("output")
|
||||
format, _ := cmd.Flags().GetString("format")
|
||||
compression, _ := cmd.Flags().GetString("compression")
|
||||
password, _ := cmd.Flags().GetString("password")
|
||||
excludes, _ := cmd.Flags().GetStringSlice("exclude")
|
||||
includeHidden, _ := cmd.Flags().GetBool("hidden")
|
||||
respectGitignore, _ := cmd.Flags().GetBool("gitignore")
|
||||
|
||||
progress := ProgressFromCmd(cmd)
|
||||
finalPath, err := CollectLocal(directory, outputFile, format, compression, password, excludes, includeHidden, respectGitignore, progress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintln(cmd.OutOrStdout(), "Files saved to", finalPath)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
c.Flags().String("output", "", "Output file for the DataNode")
|
||||
c.Flags().String("format", "datanode", "Output format (datanode, tim, trix, or stim)")
|
||||
c.Flags().String("compression", "none", "Compression format (none, gz, or xz)")
|
||||
c.Flags().String("password", "", "Password for encryption (required for stim/trix format)")
|
||||
c.Flags().StringSlice("exclude", nil, "Patterns to exclude (can be specified multiple times)")
|
||||
c.Flags().Bool("hidden", false, "Include hidden files and directories")
|
||||
c.Flags().Bool("gitignore", true, "Respect .gitignore files (default: true)")
|
||||
return c
|
||||
}
|
||||
|
||||
func init() {
|
||||
collectCmd.AddCommand(&NewCollectLocalCmd().Command)
|
||||
}
|
||||
|
||||
// CollectLocal collects files from a local directory into a DataNode
|
||||
func CollectLocal(directory string, outputFile string, format string, compression string, password string, excludes []string, includeHidden bool, respectGitignore bool, progress ui.Progress) (string, error) {
|
||||
// Validate format
|
||||
if format != "datanode" && format != "tim" && format != "trix" && format != "stim" {
|
||||
return "", fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', 'trix', or 'stim')", format)
|
||||
}
|
||||
if (format == "stim" || format == "trix") && password == "" {
|
||||
return "", fmt.Errorf("password is required for %s format", format)
|
||||
}
|
||||
if compression != "none" && compression != "gz" && compression != "xz" {
|
||||
return "", fmt.Errorf("invalid compression: %s (must be 'none', 'gz', or 'xz')", compression)
|
||||
}
|
||||
|
||||
// Resolve directory path
|
||||
absDir, err := filepath.Abs(directory)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error resolving directory path: %w", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(absDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error accessing directory: %w", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return "", fmt.Errorf("not a directory: %s", absDir)
|
||||
}
|
||||
|
||||
// Use streaming pipeline for STIM v2 format
|
||||
if format == "stim" {
|
||||
if outputFile == "" {
|
||||
baseName := filepath.Base(absDir)
|
||||
if baseName == "." || baseName == "/" {
|
||||
baseName = "local"
|
||||
}
|
||||
outputFile = baseName + ".stim"
|
||||
}
|
||||
if err := CollectLocalStreaming(absDir, outputFile, compression, password); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return outputFile, nil
|
||||
}
|
||||
|
||||
// Load gitignore patterns if enabled
|
||||
var gitignorePatterns []string
|
||||
if respectGitignore {
|
||||
gitignorePatterns = loadGitignore(absDir)
|
||||
}
|
||||
|
||||
// Create DataNode and collect files
|
||||
dn := datanode.New()
|
||||
var fileCount int
|
||||
|
||||
progress.Start("collecting " + directory)
|
||||
|
||||
err = filepath.WalkDir(absDir, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get relative path
|
||||
relPath, err := filepath.Rel(absDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip root
|
||||
if relPath == "." {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip hidden files/dirs unless explicitly included
|
||||
if !includeHidden && isHidden(relPath) {
|
||||
if d.IsDir() {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check gitignore patterns
|
||||
if respectGitignore && matchesGitignore(relPath, d.IsDir(), gitignorePatterns) {
|
||||
if d.IsDir() {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check exclude patterns
|
||||
if matchesExclude(relPath, excludes) {
|
||||
if d.IsDir() {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip directories (they're implicit in DataNode)
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read file content
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading %s: %w", relPath, err)
|
||||
}
|
||||
|
||||
// Add to DataNode with forward slashes (tar convention)
|
||||
dn.AddData(filepath.ToSlash(relPath), content)
|
||||
fileCount++
|
||||
progress.Update(int64(fileCount), 0)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error walking directory: %w", err)
|
||||
}
|
||||
|
||||
if fileCount == 0 {
|
||||
return "", fmt.Errorf("no files found in %s", directory)
|
||||
}
|
||||
|
||||
progress.Finish(fmt.Sprintf("collected %d files", fileCount))
|
||||
|
||||
// Convert to output format
|
||||
var data []byte
|
||||
if format == "tim" {
|
||||
t, err := tim.FromDataNode(dn)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating tim: %w", err)
|
||||
}
|
||||
data, err = t.ToTar()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error serializing tim: %w", err)
|
||||
}
|
||||
} else if format == "stim" {
|
||||
t, err := tim.FromDataNode(dn)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating tim: %w", err)
|
||||
}
|
||||
data, err = t.ToSigil(password)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error encrypting stim: %w", err)
|
||||
}
|
||||
} else if format == "trix" {
|
||||
data, err = trix.ToTrix(dn, password)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error serializing trix: %w", err)
|
||||
}
|
||||
} else {
|
||||
data, err = dn.ToTar()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error serializing DataNode: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply compression
|
||||
compressedData, err := compress.Compress(data, compression)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error compressing data: %w", err)
|
||||
}
|
||||
|
||||
// Determine output filename
|
||||
if outputFile == "" {
|
||||
baseName := filepath.Base(absDir)
|
||||
if baseName == "." || baseName == "/" {
|
||||
baseName = "local"
|
||||
}
|
||||
outputFile = baseName + "." + format
|
||||
if compression != "none" {
|
||||
outputFile += "." + compression
|
||||
}
|
||||
}
|
||||
|
||||
err = os.WriteFile(outputFile, compressedData, 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error writing output file: %w", err)
|
||||
}
|
||||
|
||||
return outputFile, nil
|
||||
}
|
||||
|
||||
// isHidden checks if a path component starts with a dot
|
||||
func isHidden(path string) bool {
|
||||
parts := strings.Split(filepath.ToSlash(path), "/")
|
||||
for _, part := range parts {
|
||||
if strings.HasPrefix(part, ".") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// loadGitignore loads patterns from .gitignore if it exists
|
||||
func loadGitignore(dir string) []string {
|
||||
var patterns []string
|
||||
|
||||
gitignorePath := filepath.Join(dir, ".gitignore")
|
||||
content, err := os.ReadFile(gitignorePath)
|
||||
if err != nil {
|
||||
return patterns
|
||||
}
|
||||
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
// Skip empty lines and comments
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
patterns = append(patterns, line)
|
||||
}
|
||||
|
||||
return patterns
|
||||
}
|
||||
|
||||
// matchesGitignore checks if a path matches any gitignore pattern
|
||||
func matchesGitignore(path string, isDir bool, patterns []string) bool {
|
||||
for _, pattern := range patterns {
|
||||
// Handle directory-only patterns
|
||||
if strings.HasSuffix(pattern, "/") {
|
||||
if !isDir {
|
||||
continue
|
||||
}
|
||||
pattern = strings.TrimSuffix(pattern, "/")
|
||||
}
|
||||
|
||||
// Handle negation (simplified - just skip negated patterns)
|
||||
if strings.HasPrefix(pattern, "!") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Match against path components
|
||||
matched, _ := filepath.Match(pattern, filepath.Base(path))
|
||||
if matched {
|
||||
return true
|
||||
}
|
||||
|
||||
// Also try matching the full path
|
||||
matched, _ = filepath.Match(pattern, path)
|
||||
if matched {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle ** patterns (simplified)
|
||||
if strings.Contains(pattern, "**") {
|
||||
simplePattern := strings.ReplaceAll(pattern, "**", "*")
|
||||
matched, _ = filepath.Match(simplePattern, path)
|
||||
if matched {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchesExclude checks if a path matches any exclude pattern
|
||||
func matchesExclude(path string, excludes []string) bool {
|
||||
for _, pattern := range excludes {
|
||||
// Match against basename
|
||||
matched, _ := filepath.Match(pattern, filepath.Base(path))
|
||||
if matched {
|
||||
return true
|
||||
}
|
||||
|
||||
// Match against full path
|
||||
matched, _ = filepath.Match(pattern, path)
|
||||
if matched {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CollectLocalStreaming collects files from a local directory using a streaming
|
||||
// pipeline: walk -> tar -> compress -> encrypt -> file.
|
||||
// The encryption runs in a goroutine, consuming from an io.Pipe that the
|
||||
// tar/compress writes feed into synchronously.
|
||||
func CollectLocalStreaming(dir, output, compression, password string) error {
|
||||
// Resolve to absolute path
|
||||
absDir, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error resolving directory path: %w", err)
|
||||
}
|
||||
|
||||
// Validate directory exists
|
||||
info, err := os.Stat(absDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error accessing directory: %w", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return fmt.Errorf("not a directory: %s", absDir)
|
||||
}
|
||||
|
||||
// Create output file
|
||||
outFile, err := os.Create(output)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating output file: %w", err)
|
||||
}
|
||||
|
||||
// cleanup removes partial output on error
|
||||
cleanup := func() {
|
||||
outFile.Close()
|
||||
os.Remove(output)
|
||||
}
|
||||
|
||||
// Build streaming pipeline:
|
||||
// tar.Writer -> compressWriter -> pipeWriter -> pipeReader -> StreamEncrypt -> outFile
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
// Start encryption goroutine
|
||||
var encErr error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
encErr = tim.StreamEncrypt(pr, outFile, password)
|
||||
}()
|
||||
|
||||
// Create compression writer wrapping the pipe writer
|
||||
compWriter, err := compress.NewCompressWriter(pw, compression)
|
||||
if err != nil {
|
||||
pw.Close()
|
||||
wg.Wait()
|
||||
cleanup()
|
||||
return fmt.Errorf("error creating compression writer: %w", err)
|
||||
}
|
||||
|
||||
// Create tar writer wrapping the compression writer
|
||||
tw := tar.NewWriter(compWriter)
|
||||
|
||||
// Walk directory and write tar entries
|
||||
walkErr := filepath.WalkDir(absDir, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get relative path
|
||||
relPath, err := filepath.Rel(absDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip root
|
||||
if relPath == "." {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Normalize to forward slashes for tar
|
||||
relPath = filepath.ToSlash(relPath)
|
||||
|
||||
// Check if entry is a symlink using Lstat
|
||||
linfo, err := os.Lstat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
isSymlink := linfo.Mode()&fs.ModeSymlink != 0
|
||||
|
||||
if isSymlink {
|
||||
// Read symlink target
|
||||
linkTarget, err := os.Readlink(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Resolve to check if target exists
|
||||
absTarget := linkTarget
|
||||
if !filepath.IsAbs(absTarget) {
|
||||
absTarget = filepath.Join(filepath.Dir(path), linkTarget)
|
||||
}
|
||||
_, statErr := os.Stat(absTarget)
|
||||
if statErr != nil {
|
||||
// Broken symlink - skip silently
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write valid symlink as tar entry
|
||||
hdr := &tar.Header{
|
||||
Typeflag: tar.TypeSymlink,
|
||||
Name: relPath,
|
||||
Linkname: linkTarget,
|
||||
Mode: 0777,
|
||||
}
|
||||
return tw.WriteHeader(hdr)
|
||||
}
|
||||
|
||||
if d.IsDir() {
|
||||
// Write directory header
|
||||
hdr := &tar.Header{
|
||||
Typeflag: tar.TypeDir,
|
||||
Name: relPath + "/",
|
||||
Mode: 0755,
|
||||
}
|
||||
return tw.WriteHeader(hdr)
|
||||
}
|
||||
|
||||
// Regular file: write header + content
|
||||
finfo, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hdr := &tar.Header{
|
||||
Name: relPath,
|
||||
Mode: 0644,
|
||||
Size: finfo.Size(),
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening %s: %w", relPath, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := io.Copy(tw, f); err != nil {
|
||||
return fmt.Errorf("error streaming %s: %w", relPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Close pipeline layers in order: tar -> compress -> pipe
|
||||
// We must close even on error to unblock the encryption goroutine.
|
||||
twCloseErr := tw.Close()
|
||||
compCloseErr := compWriter.Close()
|
||||
|
||||
if walkErr != nil {
|
||||
pw.CloseWithError(walkErr)
|
||||
wg.Wait()
|
||||
cleanup()
|
||||
return fmt.Errorf("error walking directory: %w", walkErr)
|
||||
}
|
||||
|
||||
if twCloseErr != nil {
|
||||
pw.CloseWithError(twCloseErr)
|
||||
wg.Wait()
|
||||
cleanup()
|
||||
return fmt.Errorf("error closing tar writer: %w", twCloseErr)
|
||||
}
|
||||
|
||||
if compCloseErr != nil {
|
||||
pw.CloseWithError(compCloseErr)
|
||||
wg.Wait()
|
||||
cleanup()
|
||||
return fmt.Errorf("error closing compression writer: %w", compCloseErr)
|
||||
}
|
||||
|
||||
// Signal EOF to encryption goroutine
|
||||
pw.Close()
|
||||
|
||||
// Wait for encryption to finish
|
||||
wg.Wait()
|
||||
|
||||
if encErr != nil {
|
||||
cleanup()
|
||||
return fmt.Errorf("error encrypting data: %w", encErr)
|
||||
}
|
||||
|
||||
// Close output file
|
||||
if err := outFile.Close(); err != nil {
|
||||
os.Remove(output)
|
||||
return fmt.Errorf("error closing output file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecryptStimV2 decrypts a STIM v2 file back into a DataNode.
|
||||
// It opens the file, runs StreamDecrypt, decompresses the result,
|
||||
// and parses the tar archive into a DataNode.
|
||||
func DecryptStimV2(path, password string) (*datanode.DataNode, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Decrypt
|
||||
var decrypted bytes.Buffer
|
||||
if err := tim.StreamDecrypt(f, &decrypted, password); err != nil {
|
||||
return nil, fmt.Errorf("error decrypting: %w", err)
|
||||
}
|
||||
|
||||
// Decompress
|
||||
decompressed, err := compress.Decompress(decrypted.Bytes())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decompressing: %w", err)
|
||||
}
|
||||
|
||||
// Parse tar into DataNode
|
||||
dn, err := datanode.FromTar(decompressed)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing tar: %w", err)
|
||||
}
|
||||
|
||||
return dn, nil
|
||||
}
|
||||
|
|
@ -1,161 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCollectLocalStreaming_Good(t *testing.T) {
|
||||
// Create a temp directory with some test files
|
||||
srcDir := t.TempDir()
|
||||
outDir := t.TempDir()
|
||||
|
||||
// Create files in subdirectories
|
||||
subDir := filepath.Join(srcDir, "subdir")
|
||||
if err := os.MkdirAll(subDir, 0755); err != nil {
|
||||
t.Fatalf("failed to create subdir: %v", err)
|
||||
}
|
||||
|
||||
files := map[string]string{
|
||||
"hello.txt": "hello world",
|
||||
"subdir/nested.go": "package main\n",
|
||||
}
|
||||
for name, content := range files {
|
||||
path := filepath.Join(srcDir, name)
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to write %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
output := filepath.Join(outDir, "test.stim")
|
||||
err := CollectLocalStreaming(srcDir, output, "gz", "test-password")
|
||||
if err != nil {
|
||||
t.Fatalf("CollectLocalStreaming() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists and is non-empty
|
||||
info, err := os.Stat(output)
|
||||
if err != nil {
|
||||
t.Fatalf("output file does not exist: %v", err)
|
||||
}
|
||||
if info.Size() == 0 {
|
||||
t.Fatal("output file is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectLocalStreaming_Decrypt_Good(t *testing.T) {
|
||||
// Create a temp directory with known files
|
||||
srcDir := t.TempDir()
|
||||
outDir := t.TempDir()
|
||||
|
||||
subDir := filepath.Join(srcDir, "pkg")
|
||||
if err := os.MkdirAll(subDir, 0755); err != nil {
|
||||
t.Fatalf("failed to create subdir: %v", err)
|
||||
}
|
||||
|
||||
expectedFiles := map[string]string{
|
||||
"README.md": "# Test Project\n",
|
||||
"pkg/main.go": "package main\n\nfunc main() {}\n",
|
||||
}
|
||||
for name, content := range expectedFiles {
|
||||
path := filepath.Join(srcDir, name)
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to write %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
password := "decrypt-test-pw"
|
||||
output := filepath.Join(outDir, "roundtrip.stim")
|
||||
|
||||
// Collect
|
||||
err := CollectLocalStreaming(srcDir, output, "gz", password)
|
||||
if err != nil {
|
||||
t.Fatalf("CollectLocalStreaming() error = %v", err)
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
dn, err := DecryptStimV2(output, password)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptStimV2() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify each expected file exists in the DataNode
|
||||
for name, wantContent := range expectedFiles {
|
||||
f, err := dn.Open(name)
|
||||
if err != nil {
|
||||
t.Errorf("file %q not found in DataNode: %v", name, err)
|
||||
continue
|
||||
}
|
||||
buf := make([]byte, 4096)
|
||||
n, _ := f.Read(buf)
|
||||
f.Close()
|
||||
got := string(buf[:n])
|
||||
if got != wantContent {
|
||||
t.Errorf("file %q content mismatch:\n got: %q\n want: %q", name, got, wantContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectLocalStreaming_BrokenSymlink_Good(t *testing.T) {
|
||||
srcDir := t.TempDir()
|
||||
outDir := t.TempDir()
|
||||
|
||||
// Create a regular file
|
||||
if err := os.WriteFile(filepath.Join(srcDir, "real.txt"), []byte("I exist"), 0644); err != nil {
|
||||
t.Fatalf("failed to write real.txt: %v", err)
|
||||
}
|
||||
|
||||
// Create a broken symlink pointing to a nonexistent target
|
||||
brokenLink := filepath.Join(srcDir, "broken-link")
|
||||
if err := os.Symlink("/nonexistent/target/file", brokenLink); err != nil {
|
||||
t.Fatalf("failed to create broken symlink: %v", err)
|
||||
}
|
||||
|
||||
output := filepath.Join(outDir, "symlink.stim")
|
||||
err := CollectLocalStreaming(srcDir, output, "none", "sym-password")
|
||||
if err != nil {
|
||||
t.Fatalf("CollectLocalStreaming() should skip broken symlinks, got error = %v", err)
|
||||
}
|
||||
|
||||
// Verify output exists and is non-empty
|
||||
info, err := os.Stat(output)
|
||||
if err != nil {
|
||||
t.Fatalf("output file does not exist: %v", err)
|
||||
}
|
||||
if info.Size() == 0 {
|
||||
t.Fatal("output file is empty")
|
||||
}
|
||||
|
||||
// Decrypt and verify the broken symlink was skipped
|
||||
dn, err := DecryptStimV2(output, "sym-password")
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptStimV2() error = %v", err)
|
||||
}
|
||||
|
||||
// real.txt should be present
|
||||
if _, err := dn.Stat("real.txt"); err != nil {
|
||||
t.Error("expected real.txt in DataNode but it's missing")
|
||||
}
|
||||
|
||||
// broken-link should NOT be present
|
||||
exists, _ := dn.Exists("broken-link")
|
||||
if exists {
|
||||
t.Error("broken symlink should have been skipped but was found in DataNode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectLocalStreaming_Bad(t *testing.T) {
|
||||
outDir := t.TempDir()
|
||||
output := filepath.Join(outDir, "should-not-exist.stim")
|
||||
|
||||
err := CollectLocalStreaming("/nonexistent/path/that/does/not/exist", output, "none", "password")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent directory, got nil")
|
||||
}
|
||||
|
||||
// Verify no partial output file was left behind
|
||||
if _, statErr := os.Stat(output); statErr == nil {
|
||||
t.Error("partial output file should have been cleaned up")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/Snider/Borg/pkg/ui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// ProgressFromCmd returns a Progress based on --quiet flag and TTY detection.
|
||||
func ProgressFromCmd(cmd *cobra.Command) ui.Progress {
|
||||
quiet, _ := cmd.Flags().GetBool("quiet")
|
||||
if quiet {
|
||||
return ui.NewQuietProgress(os.Stderr)
|
||||
}
|
||||
return ui.DefaultProgress()
|
||||
}
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestProgressFromCmd_Good(t *testing.T) {
|
||||
cmd := &cobra.Command{}
|
||||
cmd.PersistentFlags().BoolP("quiet", "q", false, "")
|
||||
|
||||
p := ProgressFromCmd(cmd)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil Progress")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressFromCmd_Quiet_Good(t *testing.T) {
|
||||
cmd := &cobra.Command{}
|
||||
cmd.PersistentFlags().BoolP("quiet", "q", true, "")
|
||||
_ = cmd.PersistentFlags().Set("quiet", "true")
|
||||
|
||||
p := ProgressFromCmd(cmd)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil Progress")
|
||||
}
|
||||
}
|
||||
105
cmd/failures.go
Normal file
105
cmd/failures.go
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/Snider/Borg/pkg/failures"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var failuresCmd = &cobra.Command{
|
||||
Use: "failures",
|
||||
Short: "Manage failures from collection runs",
|
||||
}
|
||||
|
||||
var failuresShowCmd = &cobra.Command{
|
||||
Use: "show [run-directory]",
|
||||
Short: "Show a summary of a failure report",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
reportPath := filepath.Join(args[0], "failures.json")
|
||||
data, err := os.ReadFile(reportPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read failure report: %w", err)
|
||||
}
|
||||
|
||||
var report failures.FailureReport
|
||||
if err := json.Unmarshal(data, &report); err != nil {
|
||||
return fmt.Errorf("failed to parse failure report: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Collection: %s\n", report.Collection)
|
||||
fmt.Printf("Started: %s\n", report.Started.Format(time.RFC3339))
|
||||
fmt.Printf("Completed: %s\n", report.Completed.Format(time.RFC3339))
|
||||
fmt.Printf("Total: %d\n", report.Stats.Total)
|
||||
fmt.Printf("Success: %d\n", report.Stats.Success)
|
||||
fmt.Printf("Failed: %d\n", report.Stats.Failed)
|
||||
|
||||
if len(report.Failures) > 0 {
|
||||
fmt.Println("\nFailures:")
|
||||
for _, f := range report.Failures {
|
||||
fmt.Printf(" - URL: %s\n", f.URL)
|
||||
fmt.Printf(" Error: %s\n", f.Error)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var failuresClearCmd = &cobra.Command{
|
||||
Use: "clear",
|
||||
Short: "Clear old failure reports",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
olderThan, _ := cmd.Flags().GetString("older-than")
|
||||
failuresDir, _ := cmd.Flags().GetString("failures-dir")
|
||||
if failuresDir == "" {
|
||||
failuresDir = ".borg-failures"
|
||||
}
|
||||
|
||||
duration, err := time.ParseDuration(olderThan)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid duration for --older-than: %w", err)
|
||||
}
|
||||
|
||||
cutoff := time.Now().Add(-duration)
|
||||
|
||||
entries, err := os.ReadDir(failuresDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read failures directory: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
runTime, err := time.Parse("2006-01-02T15-04-05", entry.Name())
|
||||
if err != nil {
|
||||
// Ignore directories that don't match the timestamp format
|
||||
continue
|
||||
}
|
||||
|
||||
if runTime.Before(cutoff) {
|
||||
runPath := filepath.Join(failuresDir, entry.Name())
|
||||
fmt.Printf("Removing old failure directory: %s\n", runPath)
|
||||
if err := os.RemoveAll(runPath); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to remove %s: %v\n", runPath, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
RootCmd.AddCommand(failuresCmd)
|
||||
failuresCmd.AddCommand(failuresShowCmd)
|
||||
failuresCmd.AddCommand(failuresClearCmd)
|
||||
|
||||
failuresClearCmd.Flags().String("older-than", "720h", "Clear failures older than this duration (e.g., 7d, 24h)")
|
||||
failuresClearCmd.Flags().String("failures-dir", ".borg-failures", "The directory where failures are stored")
|
||||
}
|
||||
|
|
@ -1,194 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestFullPipeline_Good exercises the complete streaming pipeline end-to-end
|
||||
// with realistic directory contents including nested dirs, a large file that
|
||||
// crosses the AEAD block boundary, valid and broken symlinks, and a hidden file.
|
||||
// Each compression mode (none, gz, xz) is tested as a subtest.
|
||||
func TestFullPipeline_Good(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// Build a realistic source directory.
|
||||
srcDir := t.TempDir()
|
||||
|
||||
// Regular files at root level.
|
||||
writeFile(t, srcDir, "readme.md", "# My Project\n\nA description.\n")
|
||||
writeFile(t, srcDir, "config.json", `{"version":"1.0","debug":false}`)
|
||||
|
||||
// Nested directories with source code.
|
||||
mkdirAll(t, srcDir, "src")
|
||||
mkdirAll(t, srcDir, "src/pkg")
|
||||
writeFile(t, srcDir, "src/main.go", "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"hello\")\n}\n")
|
||||
writeFile(t, srcDir, "src/pkg/lib.go", "package pkg\n\n// Lib is a library function.\nfunc Lib() string { return \"lib\" }\n")
|
||||
|
||||
// Large file: 1 MiB + 1 byte — crosses the 64 KiB block boundary used by
|
||||
// the chunked AEAD streaming encryption. Fill with a deterministic pattern
|
||||
// so we can verify content after round-trip.
|
||||
const largeSize = 1024*1024 + 1
|
||||
largeContent := make([]byte, largeSize)
|
||||
for i := range largeContent {
|
||||
largeContent[i] = byte(i % 251) // prime mod for non-trivial pattern
|
||||
}
|
||||
writeFileBytes(t, srcDir, "large.bin", largeContent)
|
||||
|
||||
// Valid symlink pointing at a relative target.
|
||||
if err := os.Symlink("readme.md", filepath.Join(srcDir, "link-to-readme")); err != nil {
|
||||
t.Fatalf("failed to create valid symlink: %v", err)
|
||||
}
|
||||
|
||||
// Broken symlink pointing at a nonexistent absolute path.
|
||||
if err := os.Symlink("/nonexistent/target", filepath.Join(srcDir, "broken-link")); err != nil {
|
||||
t.Fatalf("failed to create broken symlink: %v", err)
|
||||
}
|
||||
|
||||
// Hidden file (dot-prefixed).
|
||||
writeFile(t, srcDir, ".hidden", "secret stuff\n")
|
||||
|
||||
// Run each compression mode as a subtest.
|
||||
modes := []string{"none", "gz", "xz"}
|
||||
for _, comp := range modes {
|
||||
comp := comp // capture
|
||||
t.Run("compression="+comp, func(t *testing.T) {
|
||||
outDir := t.TempDir()
|
||||
outFile := filepath.Join(outDir, "pipeline-"+comp+".stim")
|
||||
password := "integration-test-pw-" + comp
|
||||
|
||||
// Step 1: Collect (walk -> tar -> compress -> encrypt -> file).
|
||||
if err := CollectLocalStreaming(srcDir, outFile, comp, password); err != nil {
|
||||
t.Fatalf("CollectLocalStreaming(%q) error = %v", comp, err)
|
||||
}
|
||||
|
||||
// Step 2: Verify output exists and is non-empty.
|
||||
info, err := os.Stat(outFile)
|
||||
if err != nil {
|
||||
t.Fatalf("output file does not exist: %v", err)
|
||||
}
|
||||
if info.Size() == 0 {
|
||||
t.Fatal("output file is empty")
|
||||
}
|
||||
|
||||
// Step 3: Decrypt back into a DataNode.
|
||||
dn, err := DecryptStimV2(outFile, password)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptStimV2() error = %v", err)
|
||||
}
|
||||
|
||||
// Step 4: Verify all regular files exist in the DataNode.
|
||||
expectedFiles := []string{
|
||||
"readme.md",
|
||||
"config.json",
|
||||
"src/main.go",
|
||||
"src/pkg/lib.go",
|
||||
"large.bin",
|
||||
".hidden",
|
||||
}
|
||||
for _, name := range expectedFiles {
|
||||
exists, eerr := dn.Exists(name)
|
||||
if eerr != nil {
|
||||
t.Errorf("Exists(%q) error = %v", name, eerr)
|
||||
continue
|
||||
}
|
||||
if !exists {
|
||||
t.Errorf("expected file %q in DataNode but it is missing", name)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the valid symlink was included.
|
||||
linkExists, _ := dn.Exists("link-to-readme")
|
||||
if !linkExists {
|
||||
t.Error("expected symlink link-to-readme in DataNode but it is missing")
|
||||
}
|
||||
|
||||
// Step 5: Verify large file has correct content (first byte check).
|
||||
f, err := dn.Open("large.bin")
|
||||
if err != nil {
|
||||
t.Fatalf("Open(large.bin) error = %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Read the entire large file and verify size and first byte.
|
||||
allData, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
t.Fatalf("reading large.bin: %v", err)
|
||||
}
|
||||
if len(allData) != largeSize {
|
||||
t.Errorf("large.bin size = %d, want %d", len(allData), largeSize)
|
||||
}
|
||||
if len(allData) > 0 && allData[0] != byte(0%251) {
|
||||
t.Errorf("large.bin first byte = %d, want %d", allData[0], byte(0%251))
|
||||
}
|
||||
|
||||
// Verify content integrity of the whole large file.
|
||||
if !bytes.Equal(allData, largeContent) {
|
||||
t.Error("large.bin content does not match original after round-trip")
|
||||
}
|
||||
|
||||
// Step 6: Verify broken symlink was skipped.
|
||||
brokenExists, _ := dn.Exists("broken-link")
|
||||
if brokenExists {
|
||||
t.Error("broken symlink should have been skipped but was found in DataNode")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFullPipeline_WrongPassword_Bad encrypts with one password and attempts
|
||||
// to decrypt with a different password, verifying that an error is returned.
|
||||
func TestFullPipeline_WrongPassword_Bad(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
srcDir := t.TempDir()
|
||||
outDir := t.TempDir()
|
||||
|
||||
writeFile(t, srcDir, "secret.txt", "this is confidential\n")
|
||||
|
||||
outFile := filepath.Join(outDir, "wrong-pw.stim")
|
||||
|
||||
// Encrypt with the correct password.
|
||||
if err := CollectLocalStreaming(srcDir, outFile, "none", "correct-password"); err != nil {
|
||||
t.Fatalf("CollectLocalStreaming() error = %v", err)
|
||||
}
|
||||
|
||||
// Attempt to decrypt with the wrong password.
|
||||
_, err := DecryptStimV2(outFile, "wrong-password")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when decrypting with wrong password, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func writeFile(t *testing.T, base, rel, content string) {
|
||||
t.Helper()
|
||||
path := filepath.Join(base, rel)
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to write %s: %v", rel, err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeFileBytes(t *testing.T, base, rel string, data []byte) {
|
||||
t.Helper()
|
||||
path := filepath.Join(base, rel)
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
t.Fatalf("failed to write %s: %v", rel, err)
|
||||
}
|
||||
}
|
||||
|
||||
func mkdirAll(t *testing.T, base, rel string) {
|
||||
t.Helper()
|
||||
path := filepath.Join(base, rel)
|
||||
if err := os.MkdirAll(path, 0755); err != nil {
|
||||
t.Fatalf("failed to mkdir %s: %v", rel, err)
|
||||
}
|
||||
}
|
||||
56
cmd/retry.go
Normal file
56
cmd/retry.go
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/Snider/Borg/pkg/failures"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var retryCmd = &cobra.Command{
|
||||
Use: "retry [run-directory]",
|
||||
Short: "Retry failures from a collection run",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
fmt.Printf("Retrying failures from %s...\n", args[0])
|
||||
|
||||
onlyRetryable, _ := cmd.Flags().GetBool("only-retryable")
|
||||
|
||||
reportPath := filepath.Join(args[0], "failures.json")
|
||||
data, err := os.ReadFile(reportPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read failure report: %w", err)
|
||||
}
|
||||
|
||||
var report failures.FailureReport
|
||||
if err := json.Unmarshal(data, &report); err != nil {
|
||||
return fmt.Errorf("failed to parse failure report: %w", err)
|
||||
}
|
||||
|
||||
for _, failure := range report.Failures {
|
||||
if onlyRetryable && !failure.Retryable {
|
||||
fmt.Printf("Skipping non-retryable failure: %s\n", failure.URL)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("Retrying %s...\n", failure.URL)
|
||||
retryCmd := exec.Command("borg", "collect", "github", "repo", failure.URL)
|
||||
retryCmd.Stdout = os.Stdout
|
||||
retryCmd.Stderr = os.Stderr
|
||||
if err := retryCmd.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to retry %s: %v\n", failure.URL, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
RootCmd.AddCommand(retryCmd)
|
||||
retryCmd.Flags().Bool("only-retryable", false, "Retry only failures marked as retryable")
|
||||
}
|
||||
|
|
@ -16,7 +16,6 @@ packaging their contents into a single file, and managing the data within.`,
|
|||
}
|
||||
|
||||
rootCmd.PersistentFlags().BoolP("verbose", "v", false, "Enable verbose logging")
|
||||
rootCmd.PersistentFlags().BoolP("quiet", "q", false, "Suppress non-error output")
|
||||
return rootCmd
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,209 +0,0 @@
|
|||
# Borg Production Backup Upgrade — Design Document
|
||||
|
||||
**Date:** 2026-02-21
|
||||
**Status:** Implemented
|
||||
**Approach:** Bottom-Up Refactor
|
||||
|
||||
## Problem Statement
|
||||
|
||||
Borg's `collect local` command fails on large directories because DataNode loads
|
||||
everything into RAM. The UI spinner floods non-TTY output. Broken symlinks crash
|
||||
the collection pipeline. Key derivation uses bare SHA-256. These issues prevent
|
||||
Borg from being used for production backup workflows.
|
||||
|
||||
## Goals
|
||||
|
||||
1. Make `collect local` work reliably on large directories (10GB+)
|
||||
2. Handle symlinks properly (skip broken, follow/store valid)
|
||||
3. Add quiet/scripted mode for cron and pipeline use
|
||||
4. Harden encryption key derivation (Argon2id)
|
||||
5. Clean up the library for external consumers
|
||||
|
||||
## Non-Goals
|
||||
|
||||
- Full core/go-* package integration (deferred — circular dependency risk since
|
||||
core imports Borg)
|
||||
- New CLI commands beyond fixing existing ones
|
||||
- Network transport or remote sync features
|
||||
- GUI or web interface
|
||||
|
||||
## Architecture
|
||||
|
||||
### Current Flow (Broken for Large Dirs)
|
||||
|
||||
```
|
||||
Walk directory → Load ALL files into DataNode (RAM) → Compress → Encrypt → Write
|
||||
```
|
||||
|
||||
### New Flow (Streaming)
|
||||
|
||||
```
|
||||
Walk directory → tar.Writer stream → compress stream → chunked encrypt → output file
|
||||
```
|
||||
|
||||
DataNode remains THE core abstraction — the I/O sandbox that keeps everything safe
|
||||
and portable. The streaming path bypasses DataNode for the `collect local` pipeline
|
||||
only, while DataNode continues to serve all other use cases (programmatic access,
|
||||
format conversion, inspection).
|
||||
|
||||
## Design Sections
|
||||
|
||||
### 1. DataNode Refactor
|
||||
|
||||
DataNode gains a `ToTarWriter(w io.Writer)` method for streaming out its contents
|
||||
without buffering the entire archive. This is the bridge between DataNode's sandbox
|
||||
model and streaming I/O.
|
||||
|
||||
New symlink handling:
|
||||
|
||||
| Symlink State | Behaviour |
|
||||
|---------------|-----------|
|
||||
| Valid, points inside DataNode root | Store as symlink entry |
|
||||
| Valid, points outside DataNode root | Follow and store target content |
|
||||
| Broken (dangling) | Skip with warning (configurable via `SkipBrokenSymlinks`) |
|
||||
|
||||
The `AddPath` method gets an options struct:
|
||||
|
||||
```go
|
||||
type AddPathOptions struct {
|
||||
SkipBrokenSymlinks bool // default: true
|
||||
FollowSymlinks bool // default: false (store as symlinks)
|
||||
ExcludePatterns []string
|
||||
}
|
||||
```
|
||||
|
||||
### 2. UI & Logger Cleanup
|
||||
|
||||
Replace direct spinner writes with a `Progress` interface:
|
||||
|
||||
```go
|
||||
type Progress interface {
|
||||
Start(label string)
|
||||
Update(current, total int64)
|
||||
Finish(label string)
|
||||
Log(level, msg string, args ...any)
|
||||
}
|
||||
```
|
||||
|
||||
Two implementations:
|
||||
- **InteractiveProgress** — spinner + progress bar (when `isatty(stdout)`)
|
||||
- **QuietProgress** — structured log lines only (cron, pipes, `--quiet` flag)
|
||||
|
||||
TTY detection at startup selects the implementation. All existing `ui.Spinner` and
|
||||
`fmt.Printf` calls in library code get replaced with `Progress` method calls.
|
||||
|
||||
New `--quiet` / `-q` flag on all commands suppresses non-error output.
|
||||
|
||||
### 3. TIM Streaming Encryption
|
||||
|
||||
ChaCha20-Poly1305 is AEAD — it needs the full plaintext to compute the auth tag.
|
||||
For streaming, we use a chunked block format:
|
||||
|
||||
```
|
||||
[magic: 4 bytes "STIM"]
|
||||
[version: 1 byte]
|
||||
[salt: 16 bytes] ← Argon2id salt
|
||||
[argon2 params: 12 bytes] ← time, memory, threads (uint32 LE each)
|
||||
|
||||
Per block (repeated):
|
||||
[nonce: 12 bytes]
|
||||
[length: 4 bytes LE] ← ciphertext length including 16-byte Poly1305 tag
|
||||
[ciphertext: N bytes] ← encrypted chunk + tag
|
||||
|
||||
Final block:
|
||||
[nonce: 12 bytes]
|
||||
[length: 4 bytes LE = 0] ← zero length signals EOF
|
||||
```
|
||||
|
||||
Block size: 1 MiB plaintext → ~1 MiB + 16 bytes ciphertext per block.
|
||||
|
||||
The `Sigil` (Enchantrix crypto handle) wraps this as `StreamEncrypt(r io.Reader,
|
||||
w io.Writer)` and `StreamDecrypt(r io.Reader, w io.Writer)`.
|
||||
|
||||
### 4. Key Derivation Hardening
|
||||
|
||||
Replace bare `SHA-256(password)` with Argon2id:
|
||||
|
||||
```go
|
||||
key := argon2.IDKey(password, salt, time=3, memory=64*1024, threads=4, keyLen=32)
|
||||
```
|
||||
|
||||
Parameters stored in the STIM header (section 3 above) so they can be tuned
|
||||
without breaking existing archives. Random 16-byte salt generated per archive.
|
||||
|
||||
Backward compatibility: detect old format by checking for "STIM" magic. Old files
|
||||
(no magic header) use legacy SHA-256 derivation with a deprecation warning.
|
||||
|
||||
### 5. Collect Local Streaming Pipeline
|
||||
|
||||
The new `collect local` pipeline for large directories:
|
||||
|
||||
```
|
||||
filepath.WalkDir
|
||||
→ tar.NewWriter (streaming)
|
||||
→ xz/gzip compressor (streaming)
|
||||
→ chunked AEAD encryptor (streaming)
|
||||
→ os.File output
|
||||
```
|
||||
|
||||
Memory usage: ~2 MiB regardless of input size (1 MiB compress buffer + 1 MiB
|
||||
encrypt block).
|
||||
|
||||
Error handling:
|
||||
- Broken symlinks: skip with warning (not fatal)
|
||||
- Permission denied: skip with warning, continue
|
||||
- Disk full on output: fatal, clean up partial file
|
||||
- Read errors mid-stream: fatal, clean up partial file
|
||||
|
||||
Compression selection: `--compress=xz` (default, best ratio) or `--compress=gzip`
|
||||
(faster). Matches existing Borg compression support.
|
||||
|
||||
### 6. Core Package Integration (Deferred)
|
||||
|
||||
Core imports Borg, so Borg cannot import core packages without creating a circular
|
||||
dependency. Integration points are marked with TODOs for when the dependency
|
||||
direction is resolved (likely by extracting shared interfaces to a common module):
|
||||
|
||||
- `core/go` config system → Borg config loading
|
||||
- `core/go` logging → Borg Progress interface backend
|
||||
- `core/go-store` → DataNode persistence
|
||||
- `core/go` io.Medium → DataNode filesystem abstraction
|
||||
|
||||
## File Impact Summary
|
||||
|
||||
| Area | Files | Change Type |
|
||||
|------|-------|-------------|
|
||||
| DataNode | `pkg/datanode/*.go` | Modify (ToTarWriter, symlinks, AddPathOptions) |
|
||||
| UI | `pkg/ui/*.go` | Rewrite (Progress interface, TTY detection) |
|
||||
| TIM/STIM | `pkg/tim/*.go` | Modify (streaming encrypt/decrypt, new header) |
|
||||
| Crypto | `pkg/tim/crypto.go` (new) | Create (Argon2id, chunked AEAD) |
|
||||
| Collect | `cmd/collect_local.go` | Rewrite (streaming pipeline) |
|
||||
| CLI | `cmd/root.go`, `cmd/*.go` | Modify (--quiet flag) |
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
- Unit tests for each component (DataNode, Progress, chunked AEAD, Argon2id)
|
||||
- Round-trip tests: encrypt → decrypt → compare original
|
||||
- Large file test: 100 MiB synthetic directory through full pipeline
|
||||
- Symlink matrix: valid internal, valid external, broken, nested
|
||||
- Backward compatibility: decrypt old-format STIM with new code
|
||||
- Race detector: `go test -race ./...`
|
||||
|
||||
## Dependencies
|
||||
|
||||
New:
|
||||
- `golang.org/x/crypto/argon2` (Argon2id key derivation)
|
||||
- `golang.org/x/term` (TTY detection via `term.IsTerminal`)
|
||||
|
||||
Existing (unchanged):
|
||||
- `github.com/snider/Enchantrix` (ChaCha20-Poly1305 via Sigil)
|
||||
- `github.com/ulikunitz/xz` (XZ compression)
|
||||
|
||||
## Risk Assessment
|
||||
|
||||
| Risk | Mitigation |
|
||||
|------|------------|
|
||||
| Breaking existing STIM format | Magic-byte detection for backward compat |
|
||||
| Chunked AEAD security | Standard construction (each block independent nonce) |
|
||||
| Circular dep with core | Deferred; TODO markers only |
|
||||
| Large directory edge cases | Extensive symlink + permission test matrix |
|
||||
File diff suppressed because it is too large
Load diff
BIN
examples/demo-sample.smsg
Normal file
BIN
examples/demo-sample.smsg
Normal file
Binary file not shown.
12
go.mod
12
go.mod
|
|
@ -13,9 +13,8 @@ require (
|
|||
github.com/spf13/cobra v1.10.1
|
||||
github.com/ulikunitz/xz v0.5.15
|
||||
github.com/wailsapp/wails/v2 v2.11.0
|
||||
golang.org/x/crypto v0.48.0
|
||||
golang.org/x/mod v0.32.0
|
||||
golang.org/x/net v0.49.0
|
||||
golang.org/x/mod v0.30.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/oauth2 v0.33.0
|
||||
)
|
||||
|
||||
|
|
@ -61,8 +60,9 @@ require (
|
|||
github.com/wailsapp/go-webview2 v1.0.22 // indirect
|
||||
github.com/wailsapp/mimetype v1.4.1 // indirect
|
||||
github.com/xanzy/ssh-agent v0.3.3 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/term v0.40.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
golang.org/x/crypto v0.44.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/term v0.37.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
gopkg.in/warnings.v0 v0.1.2 // indirect
|
||||
)
|
||||
|
|
|
|||
24
go.sum
24
go.sum
|
|
@ -155,18 +155,18 @@ github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI
|
|||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
|
||||
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
|
||||
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
|
||||
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
|
||||
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
|
||||
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210505024714-0287a6fb4125/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
||||
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo=
|
||||
golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
|
|
@ -181,17 +181,17 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
|
|
|
|||
|
|
@ -3,34 +3,11 @@ package compress
|
|||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ulikunitz/xz"
|
||||
)
|
||||
|
||||
// nopCloser wraps an io.Writer with a no-op Close method.
|
||||
type nopCloser struct{ io.Writer }
|
||||
|
||||
func (n *nopCloser) Close() error { return nil }
|
||||
|
||||
// NewCompressWriter returns a streaming io.WriteCloser that compresses data
|
||||
// written to it into the underlying writer w using the specified format.
|
||||
// Supported formats: "gz" (gzip), "xz", "none" or "" (passthrough).
|
||||
// Unknown formats return an error.
|
||||
func NewCompressWriter(w io.Writer, format string) (io.WriteCloser, error) {
|
||||
switch format {
|
||||
case "gz":
|
||||
return gzip.NewWriter(w), nil
|
||||
case "xz":
|
||||
return xz.NewWriter(w)
|
||||
case "none", "":
|
||||
return &nopCloser{w}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported compression format: %q", format)
|
||||
}
|
||||
}
|
||||
|
||||
// Compress compresses data using the specified format.
|
||||
func Compress(data []byte, format string) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
|
|
|
|||
|
|
@ -5,108 +5,6 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func TestNewCompressWriter_Gzip_Good(t *testing.T) {
|
||||
original := []byte("hello, streaming gzip world")
|
||||
var buf bytes.Buffer
|
||||
|
||||
w, err := NewCompressWriter(&buf, "gz")
|
||||
if err != nil {
|
||||
t.Fatalf("NewCompressWriter(gz) error: %v", err)
|
||||
}
|
||||
if _, err := w.Write(original); err != nil {
|
||||
t.Fatalf("Write error: %v", err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("Close error: %v", err)
|
||||
}
|
||||
|
||||
compressed := buf.Bytes()
|
||||
if bytes.Equal(original, compressed) {
|
||||
t.Fatal("compressed data should differ from original")
|
||||
}
|
||||
|
||||
decompressed, err := Decompress(compressed)
|
||||
if err != nil {
|
||||
t.Fatalf("Decompress error: %v", err)
|
||||
}
|
||||
if !bytes.Equal(original, decompressed) {
|
||||
t.Errorf("round-trip mismatch: got %q, want %q", decompressed, original)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCompressWriter_Xz_Good(t *testing.T) {
|
||||
original := []byte("hello, streaming xz world")
|
||||
var buf bytes.Buffer
|
||||
|
||||
w, err := NewCompressWriter(&buf, "xz")
|
||||
if err != nil {
|
||||
t.Fatalf("NewCompressWriter(xz) error: %v", err)
|
||||
}
|
||||
if _, err := w.Write(original); err != nil {
|
||||
t.Fatalf("Write error: %v", err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("Close error: %v", err)
|
||||
}
|
||||
|
||||
compressed := buf.Bytes()
|
||||
if bytes.Equal(original, compressed) {
|
||||
t.Fatal("compressed data should differ from original")
|
||||
}
|
||||
|
||||
decompressed, err := Decompress(compressed)
|
||||
if err != nil {
|
||||
t.Fatalf("Decompress error: %v", err)
|
||||
}
|
||||
if !bytes.Equal(original, decompressed) {
|
||||
t.Errorf("round-trip mismatch: got %q, want %q", decompressed, original)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCompressWriter_None_Good(t *testing.T) {
|
||||
original := []byte("hello, passthrough world")
|
||||
var buf bytes.Buffer
|
||||
|
||||
w, err := NewCompressWriter(&buf, "none")
|
||||
if err != nil {
|
||||
t.Fatalf("NewCompressWriter(none) error: %v", err)
|
||||
}
|
||||
if _, err := w.Write(original); err != nil {
|
||||
t.Fatalf("Write error: %v", err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("Close error: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(original, buf.Bytes()) {
|
||||
t.Errorf("passthrough mismatch: got %q, want %q", buf.Bytes(), original)
|
||||
}
|
||||
|
||||
// Also test empty string format
|
||||
var buf2 bytes.Buffer
|
||||
w2, err := NewCompressWriter(&buf2, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewCompressWriter('') error: %v", err)
|
||||
}
|
||||
if _, err := w2.Write(original); err != nil {
|
||||
t.Fatalf("Write error: %v", err)
|
||||
}
|
||||
if err := w2.Close(); err != nil {
|
||||
t.Fatalf("Close error: %v", err)
|
||||
}
|
||||
if !bytes.Equal(original, buf2.Bytes()) {
|
||||
t.Errorf("passthrough (empty string) mismatch: got %q, want %q", buf2.Bytes(), original)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCompressWriter_Bad(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
_, err := NewCompressWriter(&buf, "invalid-format")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown compression format, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGzip_Good(t *testing.T) {
|
||||
originalData := []byte("hello, gzip world")
|
||||
compressed, err := Compress(originalData, "gz")
|
||||
|
|
|
|||
|
|
@ -1,197 +0,0 @@
|
|||
package datanode
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAddPath_Good(t *testing.T) {
|
||||
// Create a temp directory with files and a nested subdirectory.
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hello"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
subdir := filepath.Join(dir, "sub")
|
||||
if err := os.Mkdir(subdir, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(subdir, "world.txt"), []byte("world"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dn := New()
|
||||
if err := dn.AddPath(dir, AddPathOptions{}); err != nil {
|
||||
t.Fatalf("AddPath failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify files are stored with paths relative to dir, using forward slashes.
|
||||
file, ok := dn.files["hello.txt"]
|
||||
if !ok {
|
||||
t.Fatal("hello.txt not found in datanode")
|
||||
}
|
||||
if string(file.content) != "hello" {
|
||||
t.Errorf("expected content 'hello', got %q", file.content)
|
||||
}
|
||||
|
||||
file, ok = dn.files["sub/world.txt"]
|
||||
if !ok {
|
||||
t.Fatal("sub/world.txt not found in datanode")
|
||||
}
|
||||
if string(file.content) != "world" {
|
||||
t.Errorf("expected content 'world', got %q", file.content)
|
||||
}
|
||||
|
||||
// Directories should not be stored explicitly.
|
||||
if _, ok := dn.files["sub"]; ok {
|
||||
t.Error("directories should not be stored as explicit entries")
|
||||
}
|
||||
if _, ok := dn.files["sub/"]; ok {
|
||||
t.Error("directories should not be stored as explicit entries")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddPath_SkipBrokenSymlinks_Good(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("symlinks not reliably supported on Windows")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create a real file.
|
||||
if err := os.WriteFile(filepath.Join(dir, "real.txt"), []byte("real"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a broken symlink (target does not exist).
|
||||
if err := os.Symlink("/nonexistent/target", filepath.Join(dir, "broken.txt")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dn := New()
|
||||
err := dn.AddPath(dir, AddPathOptions{SkipBrokenSymlinks: true})
|
||||
if err != nil {
|
||||
t.Fatalf("AddPath should not error with SkipBrokenSymlinks: %v", err)
|
||||
}
|
||||
|
||||
// The real file should be present.
|
||||
if _, ok := dn.files["real.txt"]; !ok {
|
||||
t.Error("real.txt should be present")
|
||||
}
|
||||
|
||||
// The broken symlink should be skipped.
|
||||
if _, ok := dn.files["broken.txt"]; ok {
|
||||
t.Error("broken.txt should have been skipped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddPath_ExcludePatterns_Good(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
if err := os.WriteFile(filepath.Join(dir, "app.go"), []byte("package main"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "debug.log"), []byte("log data"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "error.log"), []byte("error data"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dn := New()
|
||||
err := dn.AddPath(dir, AddPathOptions{
|
||||
ExcludePatterns: []string{"*.log"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("AddPath failed: %v", err)
|
||||
}
|
||||
|
||||
// app.go should be present.
|
||||
if _, ok := dn.files["app.go"]; !ok {
|
||||
t.Error("app.go should be present")
|
||||
}
|
||||
|
||||
// .log files should be excluded.
|
||||
if _, ok := dn.files["debug.log"]; ok {
|
||||
t.Error("debug.log should have been excluded")
|
||||
}
|
||||
if _, ok := dn.files["error.log"]; ok {
|
||||
t.Error("error.log should have been excluded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddPath_Bad(t *testing.T) {
|
||||
dn := New()
|
||||
err := dn.AddPath("/nonexistent/path/that/does/not/exist", AddPathOptions{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddPath_ValidSymlink_Good(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("symlinks not reliably supported on Windows")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create a real file.
|
||||
if err := os.WriteFile(filepath.Join(dir, "target.txt"), []byte("target content"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a valid symlink pointing to the real file.
|
||||
if err := os.Symlink("target.txt", filepath.Join(dir, "link.txt")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Default behavior (FollowSymlinks=false): store as symlink.
|
||||
dn := New()
|
||||
err := dn.AddPath(dir, AddPathOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("AddPath failed: %v", err)
|
||||
}
|
||||
|
||||
// The target file should be a regular file.
|
||||
targetFile, ok := dn.files["target.txt"]
|
||||
if !ok {
|
||||
t.Fatal("target.txt not found")
|
||||
}
|
||||
if targetFile.isSymlink() {
|
||||
t.Error("target.txt should not be a symlink")
|
||||
}
|
||||
if string(targetFile.content) != "target content" {
|
||||
t.Errorf("expected content 'target content', got %q", targetFile.content)
|
||||
}
|
||||
|
||||
// The symlink should be stored as a symlink entry.
|
||||
linkFile, ok := dn.files["link.txt"]
|
||||
if !ok {
|
||||
t.Fatal("link.txt not found")
|
||||
}
|
||||
if !linkFile.isSymlink() {
|
||||
t.Error("link.txt should be a symlink")
|
||||
}
|
||||
if linkFile.symlink != "target.txt" {
|
||||
t.Errorf("expected symlink target 'target.txt', got %q", linkFile.symlink)
|
||||
}
|
||||
|
||||
// Test with FollowSymlinks=true: store as regular file with target content.
|
||||
dn2 := New()
|
||||
err = dn2.AddPath(dir, AddPathOptions{FollowSymlinks: true})
|
||||
if err != nil {
|
||||
t.Fatalf("AddPath with FollowSymlinks failed: %v", err)
|
||||
}
|
||||
|
||||
linkFile2, ok := dn2.files["link.txt"]
|
||||
if !ok {
|
||||
t.Fatal("link.txt not found with FollowSymlinks")
|
||||
}
|
||||
if linkFile2.isSymlink() {
|
||||
t.Error("link.txt should NOT be a symlink when FollowSymlinks is true")
|
||||
}
|
||||
if string(linkFile2.content) != "target content" {
|
||||
t.Errorf("expected content 'target content', got %q", linkFile2.content)
|
||||
}
|
||||
}
|
||||
|
|
@ -8,7 +8,6 @@ import (
|
|||
"io/fs"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
|
@ -43,15 +42,12 @@ func FromTar(tarball []byte) (*DataNode, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
switch header.Typeflag {
|
||||
case tar.TypeReg:
|
||||
if header.Typeflag == tar.TypeReg {
|
||||
data, err := io.ReadAll(tarReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dn.AddData(header.Name, data)
|
||||
case tar.TypeSymlink:
|
||||
dn.AddSymlink(header.Name, header.Linkname)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -64,30 +60,17 @@ func (d *DataNode) ToTar() ([]byte, error) {
|
|||
tw := tar.NewWriter(buf)
|
||||
|
||||
for _, file := range d.files {
|
||||
var hdr *tar.Header
|
||||
if file.isSymlink() {
|
||||
hdr = &tar.Header{
|
||||
Typeflag: tar.TypeSymlink,
|
||||
Name: file.name,
|
||||
Linkname: file.symlink,
|
||||
Mode: 0777,
|
||||
ModTime: file.modTime,
|
||||
}
|
||||
} else {
|
||||
hdr = &tar.Header{
|
||||
Name: file.name,
|
||||
Mode: 0600,
|
||||
Size: int64(len(file.content)),
|
||||
ModTime: file.modTime,
|
||||
}
|
||||
hdr := &tar.Header{
|
||||
Name: file.name,
|
||||
Mode: 0600,
|
||||
Size: int64(len(file.content)),
|
||||
ModTime: file.modTime,
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !file.isSymlink() {
|
||||
if _, err := tw.Write(file.content); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := tw.Write(file.content); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -98,51 +81,6 @@ func (d *DataNode) ToTar() ([]byte, error) {
|
|||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// ToTarWriter streams the DataNode contents to a tar writer.
|
||||
// File keys are sorted for deterministic output.
|
||||
func (d *DataNode) ToTarWriter(w io.Writer) error {
|
||||
tw := tar.NewWriter(w)
|
||||
defer tw.Close()
|
||||
|
||||
// Sort keys for deterministic output.
|
||||
keys := make([]string, 0, len(d.files))
|
||||
for k := range d.files {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, k := range keys {
|
||||
file := d.files[k]
|
||||
var hdr *tar.Header
|
||||
if file.isSymlink() {
|
||||
hdr = &tar.Header{
|
||||
Typeflag: tar.TypeSymlink,
|
||||
Name: file.name,
|
||||
Linkname: file.symlink,
|
||||
Mode: 0777,
|
||||
ModTime: file.modTime,
|
||||
}
|
||||
} else {
|
||||
hdr = &tar.Header{
|
||||
Name: file.name,
|
||||
Mode: 0600,
|
||||
Size: int64(len(file.content)),
|
||||
ModTime: file.modTime,
|
||||
}
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
return err
|
||||
}
|
||||
if !file.isSymlink() {
|
||||
if _, err := tw.Write(file.content); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddData adds a file to the DataNode.
|
||||
func (d *DataNode) AddData(name string, content []byte) {
|
||||
name = strings.TrimPrefix(name, "/")
|
||||
|
|
@ -161,119 +99,6 @@ func (d *DataNode) AddData(name string, content []byte) {
|
|||
}
|
||||
}
|
||||
|
||||
// AddSymlink adds a symlink entry to the DataNode.
|
||||
func (d *DataNode) AddSymlink(name, target string) {
|
||||
name = strings.TrimPrefix(name, "/")
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
if strings.HasSuffix(name, "/") {
|
||||
return
|
||||
}
|
||||
d.files[name] = &dataFile{
|
||||
name: name,
|
||||
symlink: target,
|
||||
modTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// AddPathOptions configures the behaviour of AddPath.
|
||||
type AddPathOptions struct {
|
||||
SkipBrokenSymlinks bool // skip broken symlinks instead of erroring
|
||||
FollowSymlinks bool // follow symlinks and store target content (default false = store as symlinks)
|
||||
ExcludePatterns []string // glob patterns to exclude (matched against basename)
|
||||
}
|
||||
|
||||
// AddPath walks a real directory and adds its files to the DataNode.
|
||||
// Paths are stored relative to dir, normalized with forward slashes.
|
||||
// Directories are implicit and not stored.
|
||||
func (d *DataNode) AddPath(dir string, opts AddPathOptions) error {
|
||||
absDir, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return filepath.WalkDir(absDir, func(p string, entry fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip the root directory itself.
|
||||
if p == absDir {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Compute relative path and normalize to forward slashes.
|
||||
rel, err := filepath.Rel(absDir, p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rel = filepath.ToSlash(rel)
|
||||
|
||||
// Skip directories — they are implicit in DataNode.
|
||||
isSymlink := entry.Type()&fs.ModeSymlink != 0
|
||||
if entry.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Apply exclude patterns against basename.
|
||||
base := filepath.Base(p)
|
||||
for _, pattern := range opts.ExcludePatterns {
|
||||
matched, matchErr := filepath.Match(pattern, base)
|
||||
if matchErr != nil {
|
||||
return matchErr
|
||||
}
|
||||
if matched {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Handle symlinks.
|
||||
if isSymlink {
|
||||
linkTarget, err := os.Readlink(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Resolve the symlink target to check if it exists.
|
||||
absTarget := linkTarget
|
||||
if !filepath.IsAbs(absTarget) {
|
||||
absTarget = filepath.Join(filepath.Dir(p), linkTarget)
|
||||
}
|
||||
|
||||
_, statErr := os.Stat(absTarget)
|
||||
if statErr != nil {
|
||||
// Broken symlink.
|
||||
if opts.SkipBrokenSymlinks {
|
||||
return nil
|
||||
}
|
||||
return statErr
|
||||
}
|
||||
|
||||
if opts.FollowSymlinks {
|
||||
// Read the target content and store as regular file.
|
||||
content, err := os.ReadFile(absTarget)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.AddData(rel, content)
|
||||
} else {
|
||||
// Store as symlink.
|
||||
d.AddSymlink(rel, linkTarget)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Regular file: read content and add.
|
||||
content, err := os.ReadFile(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.AddData(rel, content)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Open opens a file from the DataNode.
|
||||
func (d *DataNode) Open(name string) (fs.File, error) {
|
||||
name = strings.TrimPrefix(name, "/")
|
||||
|
|
@ -474,11 +299,8 @@ type dataFile struct {
|
|||
name string
|
||||
content []byte
|
||||
modTime time.Time
|
||||
symlink string
|
||||
}
|
||||
|
||||
func (d *dataFile) isSymlink() bool { return d.symlink != "" }
|
||||
|
||||
func (d *dataFile) Stat() (fs.FileInfo, error) { return &dataFileInfo{file: d}, nil }
|
||||
func (d *dataFile) Read(p []byte) (int, error) { return 0, io.EOF }
|
||||
func (d *dataFile) Close() error { return nil }
|
||||
|
|
@ -488,12 +310,7 @@ type dataFileInfo struct{ file *dataFile }
|
|||
|
||||
func (d *dataFileInfo) Name() string { return path.Base(d.file.name) }
|
||||
func (d *dataFileInfo) Size() int64 { return int64(len(d.file.content)) }
|
||||
func (d *dataFileInfo) Mode() fs.FileMode {
|
||||
if d.file.isSymlink() {
|
||||
return os.ModeSymlink | 0777
|
||||
}
|
||||
return 0444
|
||||
}
|
||||
func (d *dataFileInfo) Mode() fs.FileMode { return 0444 }
|
||||
func (d *dataFileInfo) ModTime() time.Time { return d.file.modTime }
|
||||
func (d *dataFileInfo) IsDir() bool { return false }
|
||||
func (d *dataFileInfo) Sys() interface{} { return nil }
|
||||
|
|
|
|||
|
|
@ -580,273 +580,6 @@ func TestFromTar_Bad(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAddSymlink_Good(t *testing.T) {
|
||||
dn := New()
|
||||
dn.AddSymlink("link.txt", "target.txt")
|
||||
|
||||
file, ok := dn.files["link.txt"]
|
||||
if !ok {
|
||||
t.Fatal("symlink not found in datanode")
|
||||
}
|
||||
if file.symlink != "target.txt" {
|
||||
t.Errorf("expected symlink target 'target.txt', got %q", file.symlink)
|
||||
}
|
||||
if !file.isSymlink() {
|
||||
t.Error("expected isSymlink() to return true")
|
||||
}
|
||||
|
||||
// Stat should return ModeSymlink
|
||||
info, err := dn.Stat("link.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("Stat failed: %v", err)
|
||||
}
|
||||
if info.Mode()&os.ModeSymlink == 0 {
|
||||
t.Error("expected ModeSymlink to be set in file mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSymlinkTarRoundTrip_Good(t *testing.T) {
|
||||
dn1 := New()
|
||||
dn1.AddData("real.txt", []byte("real content"))
|
||||
dn1.AddSymlink("link.txt", "real.txt")
|
||||
|
||||
tarball, err := dn1.ToTar()
|
||||
if err != nil {
|
||||
t.Fatalf("ToTar failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the tar contains a symlink entry
|
||||
tr := tar.NewReader(bytes.NewReader(tarball))
|
||||
foundSymlink := false
|
||||
foundFile := false
|
||||
for {
|
||||
header, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("tar.Next failed: %v", err)
|
||||
}
|
||||
switch header.Name {
|
||||
case "link.txt":
|
||||
foundSymlink = true
|
||||
if header.Typeflag != tar.TypeSymlink {
|
||||
t.Errorf("expected TypeSymlink, got %d", header.Typeflag)
|
||||
}
|
||||
if header.Linkname != "real.txt" {
|
||||
t.Errorf("expected Linkname 'real.txt', got %q", header.Linkname)
|
||||
}
|
||||
if header.Mode != 0777 {
|
||||
t.Errorf("expected mode 0777, got %o", header.Mode)
|
||||
}
|
||||
case "real.txt":
|
||||
foundFile = true
|
||||
if header.Typeflag != tar.TypeReg {
|
||||
t.Errorf("expected TypeReg for real.txt, got %d", header.Typeflag)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundSymlink {
|
||||
t.Error("symlink entry not found in tarball")
|
||||
}
|
||||
if !foundFile {
|
||||
t.Error("regular file entry not found in tarball")
|
||||
}
|
||||
|
||||
// Round-trip: FromTar should restore the symlink
|
||||
dn2, err := FromTar(tarball)
|
||||
if err != nil {
|
||||
t.Fatalf("FromTar failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the regular file survived
|
||||
exists, _ := dn2.Exists("real.txt")
|
||||
if !exists {
|
||||
t.Error("real.txt missing after round-trip")
|
||||
}
|
||||
|
||||
// Verify the symlink survived
|
||||
linkFile, ok := dn2.files["link.txt"]
|
||||
if !ok {
|
||||
t.Fatal("link.txt missing after round-trip")
|
||||
}
|
||||
if !linkFile.isSymlink() {
|
||||
t.Error("expected link.txt to be a symlink after round-trip")
|
||||
}
|
||||
if linkFile.symlink != "real.txt" {
|
||||
t.Errorf("expected symlink target 'real.txt', got %q", linkFile.symlink)
|
||||
}
|
||||
|
||||
// Stat should still report ModeSymlink
|
||||
info, err := dn2.Stat("link.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("Stat failed: %v", err)
|
||||
}
|
||||
if info.Mode()&os.ModeSymlink == 0 {
|
||||
t.Error("expected ModeSymlink after round-trip")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddSymlink_Bad(t *testing.T) {
|
||||
dn := New()
|
||||
|
||||
// Empty name should be ignored
|
||||
dn.AddSymlink("", "target.txt")
|
||||
if len(dn.files) != 0 {
|
||||
t.Error("expected empty name to be ignored")
|
||||
}
|
||||
|
||||
// Leading slash should be stripped
|
||||
dn.AddSymlink("/link.txt", "target.txt")
|
||||
if _, ok := dn.files["link.txt"]; !ok {
|
||||
t.Error("expected leading slash to be stripped")
|
||||
}
|
||||
|
||||
// Directory-like name (trailing slash) should be ignored
|
||||
dn2 := New()
|
||||
dn2.AddSymlink("dir/", "target")
|
||||
if len(dn2.files) != 0 {
|
||||
t.Error("expected directory-like name to be ignored")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToTarWriter_Good(t *testing.T) {
|
||||
dn := New()
|
||||
dn.AddData("foo.txt", []byte("hello"))
|
||||
dn.AddData("bar/baz.txt", []byte("world"))
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := dn.ToTarWriter(&buf); err != nil {
|
||||
t.Fatalf("ToTarWriter failed: %v", err)
|
||||
}
|
||||
|
||||
// Round-trip through FromTar to verify contents survived.
|
||||
dn2, err := FromTar(buf.Bytes())
|
||||
if err != nil {
|
||||
t.Fatalf("FromTar failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify foo.txt
|
||||
f1, ok := dn2.files["foo.txt"]
|
||||
if !ok {
|
||||
t.Fatal("foo.txt missing after round-trip")
|
||||
}
|
||||
if string(f1.content) != "hello" {
|
||||
t.Errorf("expected foo.txt content 'hello', got %q", f1.content)
|
||||
}
|
||||
|
||||
// Verify bar/baz.txt
|
||||
f2, ok := dn2.files["bar/baz.txt"]
|
||||
if !ok {
|
||||
t.Fatal("bar/baz.txt missing after round-trip")
|
||||
}
|
||||
if string(f2.content) != "world" {
|
||||
t.Errorf("expected bar/baz.txt content 'world', got %q", f2.content)
|
||||
}
|
||||
|
||||
// Verify deterministic ordering: bar/baz.txt should come before foo.txt.
|
||||
tr := tar.NewReader(bytes.NewReader(buf.Bytes()))
|
||||
header1, err := tr.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("tar.Next failed: %v", err)
|
||||
}
|
||||
header2, err := tr.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("tar.Next failed: %v", err)
|
||||
}
|
||||
if header1.Name != "bar/baz.txt" || header2.Name != "foo.txt" {
|
||||
t.Errorf("expected sorted order [bar/baz.txt, foo.txt], got [%s, %s]",
|
||||
header1.Name, header2.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToTarWriter_Symlinks_Good(t *testing.T) {
|
||||
dn := New()
|
||||
dn.AddData("real.txt", []byte("real content"))
|
||||
dn.AddSymlink("link.txt", "real.txt")
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := dn.ToTarWriter(&buf); err != nil {
|
||||
t.Fatalf("ToTarWriter failed: %v", err)
|
||||
}
|
||||
|
||||
// Round-trip through FromTar.
|
||||
dn2, err := FromTar(buf.Bytes())
|
||||
if err != nil {
|
||||
t.Fatalf("FromTar failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify regular file survived.
|
||||
realFile, ok := dn2.files["real.txt"]
|
||||
if !ok {
|
||||
t.Fatal("real.txt missing after round-trip")
|
||||
}
|
||||
if string(realFile.content) != "real content" {
|
||||
t.Errorf("expected 'real content', got %q", realFile.content)
|
||||
}
|
||||
|
||||
// Verify symlink survived.
|
||||
linkFile, ok := dn2.files["link.txt"]
|
||||
if !ok {
|
||||
t.Fatal("link.txt missing after round-trip")
|
||||
}
|
||||
if !linkFile.isSymlink() {
|
||||
t.Error("expected link.txt to be a symlink")
|
||||
}
|
||||
if linkFile.symlink != "real.txt" {
|
||||
t.Errorf("expected symlink target 'real.txt', got %q", linkFile.symlink)
|
||||
}
|
||||
|
||||
// Also verify the raw tar entries have correct types and modes.
|
||||
tr := tar.NewReader(bytes.NewReader(buf.Bytes()))
|
||||
for {
|
||||
header, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("tar.Next failed: %v", err)
|
||||
}
|
||||
switch header.Name {
|
||||
case "link.txt":
|
||||
if header.Typeflag != tar.TypeSymlink {
|
||||
t.Errorf("expected TypeSymlink for link.txt, got %d", header.Typeflag)
|
||||
}
|
||||
if header.Linkname != "real.txt" {
|
||||
t.Errorf("expected Linkname 'real.txt', got %q", header.Linkname)
|
||||
}
|
||||
if header.Mode != 0777 {
|
||||
t.Errorf("expected mode 0777 for symlink, got %o", header.Mode)
|
||||
}
|
||||
case "real.txt":
|
||||
if header.Typeflag != tar.TypeReg {
|
||||
t.Errorf("expected TypeReg for real.txt, got %d", header.Typeflag)
|
||||
}
|
||||
if header.Mode != 0600 {
|
||||
t.Errorf("expected mode 0600 for regular file, got %o", header.Mode)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToTarWriter_Empty_Good(t *testing.T) {
|
||||
dn := New()
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := dn.ToTarWriter(&buf); err != nil {
|
||||
t.Fatalf("ToTarWriter on empty DataNode should not error, got: %v", err)
|
||||
}
|
||||
|
||||
// The buffer should contain a valid (empty) tar archive.
|
||||
dn2, err := FromTar(buf.Bytes())
|
||||
if err != nil {
|
||||
t.Fatalf("FromTar on empty tar failed: %v", err)
|
||||
}
|
||||
if len(dn2.files) != 0 {
|
||||
t.Errorf("expected 0 files in empty round-trip, got %d", len(dn2.files))
|
||||
}
|
||||
}
|
||||
|
||||
func toSortedNames(entries []fs.DirEntry) []string {
|
||||
var names []string
|
||||
for _, e := range entries {
|
||||
|
|
|
|||
81
pkg/failures/manager.go
Normal file
81
pkg/failures/manager.go
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
package failures
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Manager handles the lifecycle of a failure report.
|
||||
type Manager struct {
|
||||
failuresDir string
|
||||
runDir string
|
||||
report *FailureReport
|
||||
}
|
||||
|
||||
// NewManager creates a new failure manager for a given collection.
|
||||
func NewManager(failuresDir, collection string) (*Manager, error) {
|
||||
if failuresDir == "" {
|
||||
failuresDir = ".borg-failures"
|
||||
}
|
||||
runDir := filepath.Join(failuresDir, time.Now().Format("2006-01-02T15-04-05"))
|
||||
if err := os.MkdirAll(runDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create failures directory: %w", err)
|
||||
}
|
||||
|
||||
return &Manager{
|
||||
failuresDir: failuresDir,
|
||||
runDir: runDir,
|
||||
report: &FailureReport{
|
||||
Collection: collection,
|
||||
Started: time.Now(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RecordFailure records a single failure.
|
||||
func (m *Manager) RecordFailure(failure *Failure) {
|
||||
m.report.Failures = append(m.report.Failures, failure)
|
||||
m.report.Stats.Failed++
|
||||
}
|
||||
|
||||
// SetTotal sets the total number of items to be processed.
|
||||
func (m *Manager) SetTotal(total int) {
|
||||
m.report.Stats.Total = total
|
||||
}
|
||||
|
||||
// Finalize completes the failure report, writing it to disk.
|
||||
func (m *Manager) Finalize() error {
|
||||
m.report.Completed = time.Now()
|
||||
m.report.Stats.Success = m.report.Stats.Total - m.report.Stats.Failed
|
||||
|
||||
// Write failures.json
|
||||
reportPath := filepath.Join(m.runDir, "failures.json")
|
||||
reportFile, err := os.Create(reportPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create failures.json: %w", err)
|
||||
}
|
||||
defer reportFile.Close()
|
||||
|
||||
encoder := json.NewEncoder(reportFile)
|
||||
encoder.SetIndent("", " ")
|
||||
if err := encoder.Encode(m.report); err != nil {
|
||||
return fmt.Errorf("failed to write failures.json: %w", err)
|
||||
}
|
||||
|
||||
// Write retry.sh
|
||||
var retryScript strings.Builder
|
||||
retryScript.WriteString("#!/bin/bash\n\n")
|
||||
for _, failure := range m.report.Failures {
|
||||
retryScript.WriteString(fmt.Sprintf("borg collect github repo %s\n", failure.URL))
|
||||
}
|
||||
retryPath := filepath.Join(m.runDir, "retry.sh")
|
||||
if err := os.WriteFile(retryPath, []byte(retryScript.String()), 0755); err != nil {
|
||||
return fmt.Errorf("failed to write retry.sh: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
74
pkg/failures/manager_test.go
Normal file
74
pkg/failures/manager_test.go
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
package failures
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestManager(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "borg-failures-test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
manager, err := NewManager(tempDir, "test-collection")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
manager.SetTotal(1)
|
||||
manager.RecordFailure(&Failure{
|
||||
URL: "http://example.com/failed",
|
||||
Error: "test error",
|
||||
Retryable: true,
|
||||
})
|
||||
|
||||
if err := manager.Finalize(); err != nil {
|
||||
t.Fatalf("failed to finalize manager: %v", err)
|
||||
}
|
||||
|
||||
// Verify failures.json
|
||||
reportPath := filepath.Join(manager.runDir, "failures.json")
|
||||
if _, err := os.Stat(reportPath); os.IsNotExist(err) {
|
||||
t.Fatalf("failures.json was not created")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(reportPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read failures.json: %v", err)
|
||||
}
|
||||
|
||||
var report FailureReport
|
||||
if err := json.Unmarshal(data, &report); err != nil {
|
||||
t.Fatalf("failed to unmarshal failures.json: %v", err)
|
||||
}
|
||||
|
||||
if report.Collection != "test-collection" {
|
||||
t.Errorf("expected collection 'test-collection', got '%s'", report.Collection)
|
||||
}
|
||||
if len(report.Failures) != 1 {
|
||||
t.Fatalf("expected 1 failure, got %d", len(report.Failures))
|
||||
}
|
||||
if report.Failures[0].URL != "http://example.com/failed" {
|
||||
t.Errorf("unexpected failure URL: %s", report.Failures[0].URL)
|
||||
}
|
||||
|
||||
// Verify retry.sh
|
||||
retryPath := filepath.Join(manager.runDir, "retry.sh")
|
||||
if _, err := os.Stat(retryPath); os.IsNotExist(err) {
|
||||
t.Fatalf("retry.sh was not created")
|
||||
}
|
||||
|
||||
retryScript, err := os.ReadFile(retryPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read retry.sh: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(retryScript), "http://example.com/failed") {
|
||||
t.Errorf("retry.sh does not contain the failed URL")
|
||||
}
|
||||
}
|
||||
24
pkg/failures/types.go
Normal file
24
pkg/failures/types.go
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
package failures
|
||||
|
||||
import "time"
|
||||
|
||||
// Failure represents a single failure event.
|
||||
type Failure struct {
|
||||
URL string `json:"url"`
|
||||
Error string `json:"error"`
|
||||
Attempts int `json:"attempts"`
|
||||
Retryable bool `json:"retryable"`
|
||||
}
|
||||
|
||||
// FailureReport represents a collection of failures for a specific run.
|
||||
type FailureReport struct {
|
||||
Collection string `json:"collection"`
|
||||
Started time.Time `json:"started"`
|
||||
Completed time.Time `json:"completed"`
|
||||
Stats struct {
|
||||
Total int `json:"total"`
|
||||
Success int `json:"success"`
|
||||
Failed int `json:"failed"`
|
||||
} `json:"stats"`
|
||||
Failures []*Failure `json:"failures"`
|
||||
}
|
||||
|
|
@ -217,9 +217,7 @@ func (p *pwaClient) DownloadAndPackagePWA(pwaURL, manifestURL string, bar *progr
|
|||
if path == "" {
|
||||
path = "index.html"
|
||||
}
|
||||
mu.Lock()
|
||||
dn.AddData(path, body)
|
||||
mu.Unlock()
|
||||
|
||||
// Parse HTML for additional assets
|
||||
if parseHTML && isHTMLContent(resp.Header.Get("Content-Type"), body) {
|
||||
|
|
|
|||
|
|
@ -1,198 +0,0 @@
|
|||
package tim
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
|
||||
borgtrix "github.com/Snider/Borg/pkg/trix"
|
||||
)
|
||||
|
||||
const (
|
||||
blockSize = 1024 * 1024 // 1 MiB plaintext blocks
|
||||
saltSize = 16
|
||||
nonceSize = 12 // chacha20poly1305.NonceSize
|
||||
lengthSize = 4
|
||||
headerSize = 33 // 4 (magic) + 1 (version) + 16 (salt) + 12 (argon2 params)
|
||||
)
|
||||
|
||||
var (
|
||||
stimMagic = [4]byte{'S', 'T', 'I', 'M'}
|
||||
|
||||
ErrInvalidMagic = errors.New("invalid STIM magic header")
|
||||
ErrUnsupportedVersion = errors.New("unsupported STIM version")
|
||||
ErrStreamDecrypt = errors.New("stream decryption failed")
|
||||
)
|
||||
|
||||
// StreamEncrypt reads plaintext from r and writes STIM v2 chunked AEAD
|
||||
// encrypted data to w. Each 1 MiB block is independently encrypted with
|
||||
// ChaCha20-Poly1305 using a unique random nonce.
|
||||
func StreamEncrypt(r io.Reader, w io.Writer, password string) error {
|
||||
// Generate random salt
|
||||
salt := make([]byte, saltSize)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
// Derive key using Argon2id with default params
|
||||
params := borgtrix.DefaultArgon2Params()
|
||||
key := borgtrix.DeriveKeyArgon2(password, salt)
|
||||
|
||||
// Create AEAD cipher
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create AEAD: %w", err)
|
||||
}
|
||||
|
||||
// Write header: magic(4) + version(1) + salt(16) + argon2params(12) = 33 bytes
|
||||
header := make([]byte, headerSize)
|
||||
copy(header[0:4], stimMagic[:])
|
||||
header[4] = 2 // version
|
||||
copy(header[5:21], salt)
|
||||
copy(header[21:33], params.Encode())
|
||||
|
||||
if _, err := w.Write(header); err != nil {
|
||||
return fmt.Errorf("failed to write header: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt data in blocks
|
||||
buf := make([]byte, blockSize)
|
||||
nonce := make([]byte, nonceSize)
|
||||
|
||||
for {
|
||||
n, readErr := io.ReadFull(r, buf)
|
||||
|
||||
if n > 0 {
|
||||
// Generate unique nonce for this block
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt: ciphertext includes the Poly1305 auth tag (16 bytes)
|
||||
ciphertext := aead.Seal(nil, nonce, buf[:n], nil)
|
||||
|
||||
// Write [nonce(12)][length(4)][ciphertext(n+16)]
|
||||
if _, err := w.Write(nonce); err != nil {
|
||||
return fmt.Errorf("failed to write nonce: %w", err)
|
||||
}
|
||||
|
||||
lenBuf := make([]byte, lengthSize)
|
||||
binary.LittleEndian.PutUint32(lenBuf, uint32(len(ciphertext)))
|
||||
if _, err := w.Write(lenBuf); err != nil {
|
||||
return fmt.Errorf("failed to write length: %w", err)
|
||||
}
|
||||
|
||||
if _, err := w.Write(ciphertext); err != nil {
|
||||
return fmt.Errorf("failed to write ciphertext: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if readErr != nil {
|
||||
if readErr == io.EOF || readErr == io.ErrUnexpectedEOF {
|
||||
break
|
||||
}
|
||||
return fmt.Errorf("failed to read input: %w", readErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Write EOF marker: [nonce(12)][length=0(4)]
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return fmt.Errorf("failed to generate EOF nonce: %w", err)
|
||||
}
|
||||
if _, err := w.Write(nonce); err != nil {
|
||||
return fmt.Errorf("failed to write EOF nonce: %w", err)
|
||||
}
|
||||
|
||||
eofLen := make([]byte, lengthSize)
|
||||
// length is already zero (zero-value)
|
||||
if _, err := w.Write(eofLen); err != nil {
|
||||
return fmt.Errorf("failed to write EOF length: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StreamDecrypt reads STIM v2 chunked AEAD encrypted data from r and writes
|
||||
// the decrypted plaintext to w. Returns an error if the header is invalid,
|
||||
// the password is wrong, or data has been tampered with.
|
||||
func StreamDecrypt(r io.Reader, w io.Writer, password string) error {
|
||||
// Read header
|
||||
header := make([]byte, headerSize)
|
||||
if _, err := io.ReadFull(r, header); err != nil {
|
||||
return fmt.Errorf("failed to read header: %w", err)
|
||||
}
|
||||
|
||||
// Validate magic
|
||||
if header[0] != stimMagic[0] || header[1] != stimMagic[1] ||
|
||||
header[2] != stimMagic[2] || header[3] != stimMagic[3] {
|
||||
return ErrInvalidMagic
|
||||
}
|
||||
|
||||
// Validate version
|
||||
if header[4] != 2 {
|
||||
return fmt.Errorf("%w: got %d", ErrUnsupportedVersion, header[4])
|
||||
}
|
||||
|
||||
// Extract salt and params
|
||||
salt := header[5:21]
|
||||
params := borgtrix.DecodeArgon2Params(header[21:33])
|
||||
|
||||
// Derive key using stored params
|
||||
key := deriveKeyWithParams(password, salt, params)
|
||||
|
||||
// Create AEAD cipher
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create AEAD: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt blocks
|
||||
nonce := make([]byte, nonceSize)
|
||||
lenBuf := make([]byte, lengthSize)
|
||||
|
||||
for {
|
||||
// Read nonce
|
||||
if _, err := io.ReadFull(r, nonce); err != nil {
|
||||
return fmt.Errorf("failed to read block nonce: %w", err)
|
||||
}
|
||||
|
||||
// Read length
|
||||
if _, err := io.ReadFull(r, lenBuf); err != nil {
|
||||
return fmt.Errorf("failed to read block length: %w", err)
|
||||
}
|
||||
|
||||
ctLen := binary.LittleEndian.Uint32(lenBuf)
|
||||
|
||||
// EOF marker: length == 0
|
||||
if ctLen == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read ciphertext
|
||||
ciphertext := make([]byte, ctLen)
|
||||
if _, err := io.ReadFull(r, ciphertext); err != nil {
|
||||
return fmt.Errorf("failed to read ciphertext: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt and authenticate
|
||||
plaintext, err := aead.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrStreamDecrypt, err)
|
||||
}
|
||||
|
||||
if _, err := w.Write(plaintext); err != nil {
|
||||
return fmt.Errorf("failed to write plaintext: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// deriveKeyWithParams derives a 32-byte key using Argon2id with specific
|
||||
// parameters read from the STIM header (rather than using defaults).
|
||||
func deriveKeyWithParams(password string, salt []byte, params borgtrix.Argon2Params) []byte {
|
||||
return argon2.IDKey([]byte(password), salt, params.Time, params.Memory, uint8(params.Threads), 32)
|
||||
}
|
||||
|
|
@ -1,203 +0,0 @@
|
|||
package tim
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStreamRoundTrip_Good(t *testing.T) {
|
||||
plaintext := []byte("Hello, STIM v2 streaming encryption!")
|
||||
password := "test-password-123"
|
||||
|
||||
// Encrypt
|
||||
var cipherBuf bytes.Buffer
|
||||
if err := StreamEncrypt(bytes.NewReader(plaintext), &cipherBuf, password); err != nil {
|
||||
t.Fatalf("StreamEncrypt() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify header magic
|
||||
encrypted := cipherBuf.Bytes()
|
||||
if len(encrypted) < 5 {
|
||||
t.Fatal("encrypted output too short for header")
|
||||
}
|
||||
if string(encrypted[:4]) != "STIM" {
|
||||
t.Errorf("expected magic 'STIM', got %q", string(encrypted[:4]))
|
||||
}
|
||||
if encrypted[4] != 2 {
|
||||
t.Errorf("expected version 2, got %d", encrypted[4])
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
var plainBuf bytes.Buffer
|
||||
if err := StreamDecrypt(bytes.NewReader(encrypted), &plainBuf, password); err != nil {
|
||||
t.Fatalf("StreamDecrypt() error = %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(plainBuf.Bytes(), plaintext) {
|
||||
t.Errorf("round-trip mismatch:\n got: %q\n want: %q", plainBuf.Bytes(), plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamRoundTrip_Large_Good(t *testing.T) {
|
||||
// 3 MiB of pseudo-random data spans multiple 1 MiB blocks
|
||||
plaintext := make([]byte, 3*1024*1024)
|
||||
if _, err := rand.Read(plaintext); err != nil {
|
||||
t.Fatalf("failed to generate random data: %v", err)
|
||||
}
|
||||
|
||||
password := "large-data-password"
|
||||
|
||||
// Encrypt
|
||||
var cipherBuf bytes.Buffer
|
||||
if err := StreamEncrypt(bytes.NewReader(plaintext), &cipherBuf, password); err != nil {
|
||||
t.Fatalf("StreamEncrypt() error = %v", err)
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
var plainBuf bytes.Buffer
|
||||
if err := StreamDecrypt(bytes.NewReader(cipherBuf.Bytes()), &plainBuf, password); err != nil {
|
||||
t.Fatalf("StreamDecrypt() error = %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(plainBuf.Bytes(), plaintext) {
|
||||
t.Errorf("round-trip mismatch: got %d bytes, want %d bytes", plainBuf.Len(), len(plaintext))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamEncrypt_Empty_Good(t *testing.T) {
|
||||
password := "empty-test"
|
||||
|
||||
// Encrypt empty input
|
||||
var cipherBuf bytes.Buffer
|
||||
if err := StreamEncrypt(bytes.NewReader(nil), &cipherBuf, password); err != nil {
|
||||
t.Fatalf("StreamEncrypt() error = %v", err)
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
var plainBuf bytes.Buffer
|
||||
if err := StreamDecrypt(bytes.NewReader(cipherBuf.Bytes()), &plainBuf, password); err != nil {
|
||||
t.Fatalf("StreamDecrypt() error = %v", err)
|
||||
}
|
||||
|
||||
if plainBuf.Len() != 0 {
|
||||
t.Errorf("expected empty output, got %d bytes", plainBuf.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamDecrypt_WrongPassword_Bad(t *testing.T) {
|
||||
plaintext := []byte("secret data that should not decrypt with wrong key")
|
||||
correctPassword := "correct-password"
|
||||
wrongPassword := "wrong-password"
|
||||
|
||||
// Encrypt with correct password
|
||||
var cipherBuf bytes.Buffer
|
||||
if err := StreamEncrypt(bytes.NewReader(plaintext), &cipherBuf, correctPassword); err != nil {
|
||||
t.Fatalf("StreamEncrypt() error = %v", err)
|
||||
}
|
||||
|
||||
// Attempt decrypt with wrong password
|
||||
var plainBuf bytes.Buffer
|
||||
err := StreamDecrypt(bytes.NewReader(cipherBuf.Bytes()), &plainBuf, wrongPassword)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when decrypting with wrong password, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamDecrypt_Truncated_Bad(t *testing.T) {
|
||||
plaintext := []byte("data that will be truncated after encryption")
|
||||
password := "truncation-test"
|
||||
|
||||
// Encrypt
|
||||
var cipherBuf bytes.Buffer
|
||||
if err := StreamEncrypt(bytes.NewReader(plaintext), &cipherBuf, password); err != nil {
|
||||
t.Fatalf("StreamEncrypt() error = %v", err)
|
||||
}
|
||||
|
||||
encrypted := cipherBuf.Bytes()
|
||||
|
||||
// Truncate to just past the header (33 bytes) but before the full first block
|
||||
if len(encrypted) > 40 {
|
||||
truncated := encrypted[:40]
|
||||
var plainBuf bytes.Buffer
|
||||
err := StreamDecrypt(bytes.NewReader(truncated), &plainBuf, password)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when decrypting truncated data, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// Truncate mid-way through the ciphertext
|
||||
if len(encrypted) > headerSize+nonceSize+lengthSize+5 {
|
||||
midpoint := headerSize + nonceSize + lengthSize + 5
|
||||
truncated := encrypted[:midpoint]
|
||||
var plainBuf bytes.Buffer
|
||||
err := StreamDecrypt(bytes.NewReader(truncated), &plainBuf, password)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when decrypting mid-block truncated data, got nil")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamDecrypt_InvalidMagic_Bad(t *testing.T) {
|
||||
// Construct data with wrong magic
|
||||
data := []byte("NOPE\x02")
|
||||
data = append(data, make([]byte, 28)...) // pad to header size
|
||||
|
||||
var plainBuf bytes.Buffer
|
||||
err := StreamDecrypt(bytes.NewReader(data), &plainBuf, "password")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid magic, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamDecrypt_InvalidVersion_Bad(t *testing.T) {
|
||||
// Construct data with wrong version
|
||||
data := []byte("STIM\x01")
|
||||
data = append(data, make([]byte, 28)...) // pad to header size
|
||||
|
||||
var plainBuf bytes.Buffer
|
||||
err := StreamDecrypt(bytes.NewReader(data), &plainBuf, "password")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unsupported version, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamDecrypt_ShortHeader_Bad(t *testing.T) {
|
||||
// Too short to contain full header
|
||||
data := []byte("STIM\x02")
|
||||
var plainBuf bytes.Buffer
|
||||
err := StreamDecrypt(bytes.NewReader(data), &plainBuf, "password")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for short header, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamEncrypt_WriterError_Bad(t *testing.T) {
|
||||
plaintext := []byte("test data")
|
||||
// Use a writer that fails after a few bytes
|
||||
w := &limitedWriter{limit: 5}
|
||||
err := StreamEncrypt(bytes.NewReader(plaintext), w, "password")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when writer fails, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// limitedWriter fails after writing limit bytes.
|
||||
type limitedWriter struct {
|
||||
limit int
|
||||
written int
|
||||
}
|
||||
|
||||
func (w *limitedWriter) Write(p []byte) (int, error) {
|
||||
remaining := w.limit - w.written
|
||||
if remaining <= 0 {
|
||||
return 0, io.ErrShortWrite
|
||||
}
|
||||
if len(p) > remaining {
|
||||
w.written += remaining
|
||||
return remaining, io.ErrShortWrite
|
||||
}
|
||||
w.written += len(p)
|
||||
return len(p), nil
|
||||
}
|
||||
|
|
@ -2,12 +2,9 @@ package trix
|
|||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Enchantrix/pkg/crypt"
|
||||
"github.com/Snider/Enchantrix/pkg/enchantrix"
|
||||
|
|
@ -64,53 +61,11 @@ func FromTrix(data []byte, password string) (*datanode.DataNode, error) {
|
|||
|
||||
// DeriveKey derives a 32-byte key from a password using SHA-256.
|
||||
// This is used for ChaCha20-Poly1305 encryption which requires a 32-byte key.
|
||||
// Deprecated: Use DeriveKeyArgon2 for new code; this remains for backward compatibility.
|
||||
func DeriveKey(password string) []byte {
|
||||
hash := sha256.Sum256([]byte(password))
|
||||
return hash[:]
|
||||
}
|
||||
|
||||
// Argon2Params holds the tunable parameters for Argon2id key derivation.
|
||||
type Argon2Params struct {
|
||||
Time uint32
|
||||
Memory uint32 // in KiB
|
||||
Threads uint32
|
||||
}
|
||||
|
||||
// DefaultArgon2Params returns sensible default parameters for Argon2id.
|
||||
func DefaultArgon2Params() Argon2Params {
|
||||
return Argon2Params{
|
||||
Time: 3,
|
||||
Memory: 64 * 1024,
|
||||
Threads: 4,
|
||||
}
|
||||
}
|
||||
|
||||
// Encode serialises the Argon2Params as 12 bytes (3 x uint32 little-endian).
|
||||
func (p Argon2Params) Encode() []byte {
|
||||
buf := make([]byte, 12)
|
||||
binary.LittleEndian.PutUint32(buf[0:4], p.Time)
|
||||
binary.LittleEndian.PutUint32(buf[4:8], p.Memory)
|
||||
binary.LittleEndian.PutUint32(buf[8:12], p.Threads)
|
||||
return buf
|
||||
}
|
||||
|
||||
// DecodeArgon2Params reads 12 bytes (3 x uint32 little-endian) into Argon2Params.
|
||||
func DecodeArgon2Params(data []byte) Argon2Params {
|
||||
return Argon2Params{
|
||||
Time: binary.LittleEndian.Uint32(data[0:4]),
|
||||
Memory: binary.LittleEndian.Uint32(data[4:8]),
|
||||
Threads: binary.LittleEndian.Uint32(data[8:12]),
|
||||
}
|
||||
}
|
||||
|
||||
// DeriveKeyArgon2 derives a 32-byte key from a password and salt using Argon2id
|
||||
// with DefaultArgon2Params. This is the recommended key derivation for new code.
|
||||
func DeriveKeyArgon2(password string, salt []byte) []byte {
|
||||
p := DefaultArgon2Params()
|
||||
return argon2.IDKey([]byte(password), salt, p.Time, p.Memory, uint8(p.Threads), 32)
|
||||
}
|
||||
|
||||
// ToTrixChaCha converts a DataNode to encrypted Trix format using ChaCha20-Poly1305.
|
||||
func ToTrixChaCha(dn *datanode.DataNode, password string) ([]byte, error) {
|
||||
if password == "" {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
package trix
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
|
|
@ -238,85 +236,3 @@ func TestToTrixChaChaWithLargeData(t *testing.T) {
|
|||
t.Fatalf("Failed to open large.bin: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Argon2id key derivation tests ---
|
||||
|
||||
func TestDeriveKeyArgon2_Good(t *testing.T) {
|
||||
salt := make([]byte, 16)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
t.Fatalf("failed to generate salt: %v", err)
|
||||
}
|
||||
|
||||
key := DeriveKeyArgon2("test-password", salt)
|
||||
if len(key) != 32 {
|
||||
t.Fatalf("expected 32-byte key, got %d bytes", len(key))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveKeyArgon2_Deterministic_Good(t *testing.T) {
|
||||
salt := []byte("fixed-salt-value")
|
||||
|
||||
key1 := DeriveKeyArgon2("same-password", salt)
|
||||
key2 := DeriveKeyArgon2("same-password", salt)
|
||||
|
||||
if !bytes.Equal(key1, key2) {
|
||||
t.Fatal("same password and salt must produce the same key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveKeyArgon2_DifferentSalt_Good(t *testing.T) {
|
||||
salt1 := []byte("salt-one-value!!")
|
||||
salt2 := []byte("salt-two-value!!")
|
||||
|
||||
key1 := DeriveKeyArgon2("same-password", salt1)
|
||||
key2 := DeriveKeyArgon2("same-password", salt2)
|
||||
|
||||
if bytes.Equal(key1, key2) {
|
||||
t.Fatal("different salts must produce different keys")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveKeyLegacy_Good(t *testing.T) {
|
||||
key1 := DeriveKey("backward-compat")
|
||||
key2 := DeriveKey("backward-compat")
|
||||
|
||||
if len(key1) != 32 {
|
||||
t.Fatalf("expected 32-byte key, got %d bytes", len(key1))
|
||||
}
|
||||
if !bytes.Equal(key1, key2) {
|
||||
t.Fatal("legacy DeriveKey must be deterministic")
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgon2Params_Good(t *testing.T) {
|
||||
params := DefaultArgon2Params()
|
||||
|
||||
// Non-zero values
|
||||
if params.Time == 0 {
|
||||
t.Fatal("Time must be non-zero")
|
||||
}
|
||||
if params.Memory == 0 {
|
||||
t.Fatal("Memory must be non-zero")
|
||||
}
|
||||
if params.Threads == 0 {
|
||||
t.Fatal("Threads must be non-zero")
|
||||
}
|
||||
|
||||
// Encode produces 12 bytes (3 x uint32 LE)
|
||||
encoded := params.Encode()
|
||||
if len(encoded) != 12 {
|
||||
t.Fatalf("expected 12-byte encoding, got %d bytes", len(encoded))
|
||||
}
|
||||
|
||||
// Round-trip: Decode must recover original values
|
||||
decoded := DecodeArgon2Params(encoded)
|
||||
if decoded.Time != params.Time {
|
||||
t.Fatalf("Time mismatch: got %d, want %d", decoded.Time, params.Time)
|
||||
}
|
||||
if decoded.Memory != params.Memory {
|
||||
t.Fatalf("Memory mismatch: got %d, want %d", decoded.Memory, params.Memory)
|
||||
}
|
||||
if decoded.Threads != params.Threads {
|
||||
t.Fatalf("Threads mismatch: got %d, want %d", decoded.Threads, params.Threads)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,93 +0,0 @@
|
|||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/mattn/go-isatty"
|
||||
)
|
||||
|
||||
// Progress abstracts output for both interactive and scripted use.
|
||||
type Progress interface {
|
||||
Start(label string)
|
||||
Update(current, total int64)
|
||||
Finish(label string)
|
||||
Log(level, msg string, args ...any)
|
||||
}
|
||||
|
||||
// QuietProgress writes structured log lines. For cron, pipes, --quiet.
|
||||
type QuietProgress struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func NewQuietProgress(w io.Writer) *QuietProgress {
|
||||
return &QuietProgress{w: w}
|
||||
}
|
||||
|
||||
func (q *QuietProgress) Start(label string) {
|
||||
fmt.Fprintf(q.w, "[START] %s\n", label)
|
||||
}
|
||||
|
||||
func (q *QuietProgress) Update(current, total int64) {
|
||||
if total > 0 {
|
||||
fmt.Fprintf(q.w, "[PROGRESS] %d/%d\n", current, total)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *QuietProgress) Finish(label string) {
|
||||
fmt.Fprintf(q.w, "[DONE] %s\n", label)
|
||||
}
|
||||
|
||||
func (q *QuietProgress) Log(level, msg string, args ...any) {
|
||||
fmt.Fprintf(q.w, "[%s] %s", level, msg)
|
||||
for i := 0; i+1 < len(args); i += 2 {
|
||||
fmt.Fprintf(q.w, " %v=%v", args[i], args[i+1])
|
||||
}
|
||||
fmt.Fprintln(q.w)
|
||||
}
|
||||
|
||||
// InteractiveProgress uses simple terminal output for TTY sessions.
|
||||
type InteractiveProgress struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func NewInteractiveProgress(w io.Writer) *InteractiveProgress {
|
||||
return &InteractiveProgress{w: w}
|
||||
}
|
||||
|
||||
func (p *InteractiveProgress) Start(label string) {
|
||||
fmt.Fprintf(p.w, "→ %s\n", label)
|
||||
}
|
||||
|
||||
func (p *InteractiveProgress) Update(current, total int64) {
|
||||
if total > 0 {
|
||||
pct := current * 100 / total
|
||||
fmt.Fprintf(p.w, "\r %d%%", pct)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *InteractiveProgress) Finish(label string) {
|
||||
fmt.Fprintf(p.w, "\r✓ %s\n", label)
|
||||
}
|
||||
|
||||
func (p *InteractiveProgress) Log(level, msg string, args ...any) {
|
||||
fmt.Fprintf(p.w, " %s", msg)
|
||||
for i := 0; i+1 < len(args); i += 2 {
|
||||
fmt.Fprintf(p.w, " %v=%v", args[i], args[i+1])
|
||||
}
|
||||
fmt.Fprintln(p.w)
|
||||
}
|
||||
|
||||
// IsTTY returns true if the given file descriptor is a terminal.
|
||||
func IsTTY(fd uintptr) bool {
|
||||
return isatty.IsTerminal(fd) || isatty.IsCygwinTerminal(fd)
|
||||
}
|
||||
|
||||
// DefaultProgress returns InteractiveProgress for TTYs, QuietProgress otherwise.
|
||||
func DefaultProgress() Progress {
|
||||
if IsTTY(os.Stdout.Fd()) {
|
||||
return NewInteractiveProgress(os.Stdout)
|
||||
}
|
||||
return NewQuietProgress(os.Stdout)
|
||||
}
|
||||
|
|
@ -1,63 +0,0 @@
|
|||
package ui
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestQuietProgress_Log_Good(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
p := NewQuietProgress(&buf)
|
||||
p.Log("info", "test message", "key", "val")
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "test message") {
|
||||
t.Fatalf("expected log output to contain 'test message', got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuietProgress_StartFinish_Good(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
p := NewQuietProgress(&buf)
|
||||
p.Start("collecting")
|
||||
p.Update(50, 100)
|
||||
p.Finish("done")
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "collecting") {
|
||||
t.Fatalf("expected 'collecting' in output, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "done") {
|
||||
t.Fatalf("expected 'done' in output, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuietProgress_Update_Ugly(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
p := NewQuietProgress(&buf)
|
||||
// Should not panic with zero total
|
||||
p.Update(0, 0)
|
||||
p.Update(5, 0)
|
||||
}
|
||||
|
||||
func TestInteractiveProgress_StartFinish_Good(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
p := NewInteractiveProgress(&buf)
|
||||
p.Start("collecting")
|
||||
p.Finish("done")
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "collecting") {
|
||||
t.Fatalf("expected 'collecting', got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "done") {
|
||||
t.Fatalf("expected 'done', got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInteractiveProgress_Update_Good(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
p := NewInteractiveProgress(&buf)
|
||||
p.Update(50, 100)
|
||||
if !strings.Contains(buf.String(), "50%") {
|
||||
t.Fatalf("expected '50%%', got: %s", buf.String())
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue