Compare commits

3 commits: main...feat-paral

| Author | SHA1 | Date |
|---|---|---|
|  | cfd644b421 |  |
|  | 9baf943862 |  |
|  | 05bfafad2b |  |

11 changed files with 363 additions and 437 deletions
```diff
@@ -2,8 +2,14 @@ package cmd
 
 import (
 	"fmt"
 	"os"
 
+	"github.com/Snider/Borg/pkg/compress"
+	"github.com/Snider/Borg/pkg/github"
+	"github.com/Snider/Borg/pkg/tim"
+	"github.com/Snider/Borg/pkg/trix"
+	"github.com/Snider/Borg/pkg/ui"
+	"github.com/schollz/progressbar/v3"
 	"github.com/spf13/cobra"
 )
 
```
```diff
@@ -17,17 +23,80 @@ var collectGithubReposCmd = &cobra.Command{
 	Short: "Collects all public repositories for a user or organization",
 	Args:  cobra.ExactArgs(1),
 	RunE: func(cmd *cobra.Command, args []string) error {
+		parallel, _ := cmd.Flags().GetInt("parallel")
+		outputFile, _ := cmd.Flags().GetString("output")
+		format, _ := cmd.Flags().GetString("format")
+		compression, _ := cmd.Flags().GetString("compression")
+		password, _ := cmd.Flags().GetString("password")
+
 		repos, err := GithubClient.GetPublicRepos(cmd.Context(), args[0])
 		if err != nil {
 			return err
 		}
-		for _, repo := range repos {
-			fmt.Fprintln(cmd.OutOrStdout(), repo)
-		}
+
+		prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote)
+		prompter.Start()
+		defer prompter.Stop()
+		var bar *progressbar.ProgressBar
+		if prompter.IsInteractive() {
+			bar = ui.NewProgressBar(len(repos), "Cloning repositories")
+		}
+
+		downloader := github.NewDownloader(parallel, bar)
+		dn, err := downloader.DownloadRepositories(cmd.Context(), repos)
+		if err != nil {
+			return err
+		}
+
+		var data []byte
+		if format == "tim" {
+			tim, err := tim.FromDataNode(dn)
+			if err != nil {
+				return fmt.Errorf("error creating tim: %w", err)
+			}
+			data, err = tim.ToTar()
+			if err != nil {
+				return fmt.Errorf("error serializing tim: %w", err)
+			}
+		} else if format == "trix" {
+			data, err = trix.ToTrix(dn, password)
+			if err != nil {
+				return fmt.Errorf("error serializing trix: %w", err)
+			}
+		} else {
+			data, err = dn.ToTar()
+			if err != nil {
+				return fmt.Errorf("error serializing DataNode: %w", err)
+			}
+		}
+
+		compressedData, err := compress.Compress(data, compression)
+		if err != nil {
+			return fmt.Errorf("error compressing data: %w", err)
+		}
+
+		if outputFile == "" {
+			outputFile = args[0] + "." + format
+			if compression != "none" {
+				outputFile += "." + compression
+			}
+		}
+
+		err = os.WriteFile(outputFile, compressedData, 0644)
+		if err != nil {
+			return fmt.Errorf("error writing repos to file: %w", err)
+		}
+
+		fmt.Fprintln(cmd.OutOrStdout(), "Repositories saved to", outputFile)
 		return nil
 	},
 }
 
 func init() {
 	collectGithubCmd.AddCommand(collectGithubReposCmd)
+	collectGithubReposCmd.PersistentFlags().Int("parallel", 1, "Number of concurrent workers")
+	collectGithubReposCmd.PersistentFlags().String("output", "", "Output file for the DataNode")
+	collectGithubReposCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
+	collectGithubReposCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
+	collectGithubReposCmd.PersistentFlags().String("password", "", "Password for encryption")
 }
```
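The `RunE` body above serializes the merged DataNode one of three ways before compressing and writing it. A minimal sketch of that dispatch as a standalone helper (the `serialize` name is hypothetical and not part of the diff; the `tim`, `trix`, and `datanode` calls are exactly those used above):

```go
package cmd

import (
	"fmt"

	"github.com/Snider/Borg/pkg/datanode"
	"github.com/Snider/Borg/pkg/tim"
	"github.com/Snider/Borg/pkg/trix"
)

// serialize mirrors the format dispatch in the command above.
// Hypothetical helper; shown only to isolate the branching logic.
func serialize(dn *datanode.DataNode, format, password string) ([]byte, error) {
	switch format {
	case "tim":
		t, err := tim.FromDataNode(dn)
		if err != nil {
			return nil, fmt.Errorf("error creating tim: %w", err)
		}
		return t.ToTar()
	case "trix":
		return trix.ToTrix(dn, password)
	default: // "datanode": a plain tar of the node's contents
		return dn.ToTar()
	}
}
```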
@@ -1,333 +0,0 @@ — the `collect local` command file was deleted in its entirety; its former contents:

```go
package cmd

import (
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"strings"

	"github.com/Snider/Borg/pkg/compress"
	"github.com/Snider/Borg/pkg/datanode"
	"github.com/Snider/Borg/pkg/tim"
	"github.com/Snider/Borg/pkg/trix"
	"github.com/Snider/Borg/pkg/ui"

	"github.com/spf13/cobra"
)

type CollectLocalCmd struct {
	cobra.Command
}

// NewCollectLocalCmd creates a new collect local command
func NewCollectLocalCmd() *CollectLocalCmd {
	c := &CollectLocalCmd{}
	c.Command = cobra.Command{
		Use:   "local [directory]",
		Short: "Collect files from a local directory",
		Long: `Collect files from a local directory and store them in a DataNode.

If no directory is specified, the current working directory is used.

Examples:
  borg collect local
  borg collect local ./src
  borg collect local /path/to/project --output project.tar
  borg collect local . --format stim --password secret
  borg collect local . --exclude "*.log" --exclude "node_modules"`,
		Args: cobra.MaximumNArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			directory := "."
			if len(args) > 0 {
				directory = args[0]
			}

			outputFile, _ := cmd.Flags().GetString("output")
			format, _ := cmd.Flags().GetString("format")
			compression, _ := cmd.Flags().GetString("compression")
			password, _ := cmd.Flags().GetString("password")
			excludes, _ := cmd.Flags().GetStringSlice("exclude")
			includeHidden, _ := cmd.Flags().GetBool("hidden")
			respectGitignore, _ := cmd.Flags().GetBool("gitignore")

			finalPath, err := CollectLocal(directory, outputFile, format, compression, password, excludes, includeHidden, respectGitignore)
			if err != nil {
				return err
			}
			fmt.Fprintln(cmd.OutOrStdout(), "Files saved to", finalPath)
			return nil
		},
	}
	c.Flags().String("output", "", "Output file for the DataNode")
	c.Flags().String("format", "datanode", "Output format (datanode, tim, trix, or stim)")
	c.Flags().String("compression", "none", "Compression format (none, gz, or xz)")
	c.Flags().String("password", "", "Password for encryption (required for stim/trix format)")
	c.Flags().StringSlice("exclude", nil, "Patterns to exclude (can be specified multiple times)")
	c.Flags().Bool("hidden", false, "Include hidden files and directories")
	c.Flags().Bool("gitignore", true, "Respect .gitignore files (default: true)")
	return c
}

func init() {
	collectCmd.AddCommand(&NewCollectLocalCmd().Command)
}

// CollectLocal collects files from a local directory into a DataNode
func CollectLocal(directory string, outputFile string, format string, compression string, password string, excludes []string, includeHidden bool, respectGitignore bool) (string, error) {
	// Validate format
	if format != "datanode" && format != "tim" && format != "trix" && format != "stim" {
		return "", fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', 'trix', or 'stim')", format)
	}
	if (format == "stim" || format == "trix") && password == "" {
		return "", fmt.Errorf("password is required for %s format", format)
	}
	if compression != "none" && compression != "gz" && compression != "xz" {
		return "", fmt.Errorf("invalid compression: %s (must be 'none', 'gz', or 'xz')", compression)
	}

	// Resolve directory path
	absDir, err := filepath.Abs(directory)
	if err != nil {
		return "", fmt.Errorf("error resolving directory path: %w", err)
	}

	info, err := os.Stat(absDir)
	if err != nil {
		return "", fmt.Errorf("error accessing directory: %w", err)
	}
	if !info.IsDir() {
		return "", fmt.Errorf("not a directory: %s", absDir)
	}

	// Load gitignore patterns if enabled
	var gitignorePatterns []string
	if respectGitignore {
		gitignorePatterns = loadGitignore(absDir)
	}

	// Create DataNode and collect files
	dn := datanode.New()
	var fileCount int

	bar := ui.NewProgressBar(-1, "Scanning files")
	defer bar.Finish()

	err = filepath.WalkDir(absDir, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}

		// Get relative path
		relPath, err := filepath.Rel(absDir, path)
		if err != nil {
			return err
		}

		// Skip root
		if relPath == "." {
			return nil
		}

		// Skip hidden files/dirs unless explicitly included
		if !includeHidden && isHidden(relPath) {
			if d.IsDir() {
				return filepath.SkipDir
			}
			return nil
		}

		// Check gitignore patterns
		if respectGitignore && matchesGitignore(relPath, d.IsDir(), gitignorePatterns) {
			if d.IsDir() {
				return filepath.SkipDir
			}
			return nil
		}

		// Check exclude patterns
		if matchesExclude(relPath, excludes) {
			if d.IsDir() {
				return filepath.SkipDir
			}
			return nil
		}

		// Skip directories (they're implicit in DataNode)
		if d.IsDir() {
			return nil
		}

		// Read file content
		content, err := os.ReadFile(path)
		if err != nil {
			return fmt.Errorf("error reading %s: %w", relPath, err)
		}

		// Add to DataNode with forward slashes (tar convention)
		dn.AddData(filepath.ToSlash(relPath), content)
		fileCount++
		bar.Describe(fmt.Sprintf("Collected %d files", fileCount))

		return nil
	})

	if err != nil {
		return "", fmt.Errorf("error walking directory: %w", err)
	}

	if fileCount == 0 {
		return "", fmt.Errorf("no files found in %s", directory)
	}

	bar.Describe(fmt.Sprintf("Packaging %d files", fileCount))

	// Convert to output format
	var data []byte
	if format == "tim" {
		t, err := tim.FromDataNode(dn)
		if err != nil {
			return "", fmt.Errorf("error creating tim: %w", err)
		}
		data, err = t.ToTar()
		if err != nil {
			return "", fmt.Errorf("error serializing tim: %w", err)
		}
	} else if format == "stim" {
		t, err := tim.FromDataNode(dn)
		if err != nil {
			return "", fmt.Errorf("error creating tim: %w", err)
		}
		data, err = t.ToSigil(password)
		if err != nil {
			return "", fmt.Errorf("error encrypting stim: %w", err)
		}
	} else if format == "trix" {
		data, err = trix.ToTrix(dn, password)
		if err != nil {
			return "", fmt.Errorf("error serializing trix: %w", err)
		}
	} else {
		data, err = dn.ToTar()
		if err != nil {
			return "", fmt.Errorf("error serializing DataNode: %w", err)
		}
	}

	// Apply compression
	compressedData, err := compress.Compress(data, compression)
	if err != nil {
		return "", fmt.Errorf("error compressing data: %w", err)
	}

	// Determine output filename
	if outputFile == "" {
		baseName := filepath.Base(absDir)
		if baseName == "." || baseName == "/" {
			baseName = "local"
		}
		outputFile = baseName + "." + format
		if compression != "none" {
			outputFile += "." + compression
		}
	}

	err = os.WriteFile(outputFile, compressedData, 0644)
	if err != nil {
		return "", fmt.Errorf("error writing output file: %w", err)
	}

	return outputFile, nil
}

// isHidden checks if a path component starts with a dot
func isHidden(path string) bool {
	parts := strings.Split(filepath.ToSlash(path), "/")
	for _, part := range parts {
		if strings.HasPrefix(part, ".") {
			return true
		}
	}
	return false
}

// loadGitignore loads patterns from .gitignore if it exists
func loadGitignore(dir string) []string {
	var patterns []string

	gitignorePath := filepath.Join(dir, ".gitignore")
	content, err := os.ReadFile(gitignorePath)
	if err != nil {
		return patterns
	}

	lines := strings.Split(string(content), "\n")
	for _, line := range lines {
		line = strings.TrimSpace(line)
		// Skip empty lines and comments
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		patterns = append(patterns, line)
	}

	return patterns
}

// matchesGitignore checks if a path matches any gitignore pattern
func matchesGitignore(path string, isDir bool, patterns []string) bool {
	for _, pattern := range patterns {
		// Handle directory-only patterns
		if strings.HasSuffix(pattern, "/") {
			if !isDir {
				continue
			}
			pattern = strings.TrimSuffix(pattern, "/")
		}

		// Handle negation (simplified - just skip negated patterns)
		if strings.HasPrefix(pattern, "!") {
			continue
		}

		// Match against path components
		matched, _ := filepath.Match(pattern, filepath.Base(path))
		if matched {
			return true
		}

		// Also try matching the full path
		matched, _ = filepath.Match(pattern, path)
		if matched {
			return true
		}

		// Handle ** patterns (simplified)
		if strings.Contains(pattern, "**") {
			simplePattern := strings.ReplaceAll(pattern, "**", "*")
			matched, _ = filepath.Match(simplePattern, path)
			if matched {
				return true
			}
		}
	}
	return false
}

// matchesExclude checks if a path matches any exclude pattern
func matchesExclude(path string, excludes []string) bool {
	for _, pattern := range excludes {
		// Match against basename
		matched, _ := filepath.Match(pattern, filepath.Base(path))
		if matched {
			return true
		}

		// Match against full path
		matched, _ = filepath.Match(pattern, path)
		if matched {
			return true
		}
	}
	return false
}
```
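The deleted matchers tried `filepath.Match` against both the basename and the full relative path, with a simplified `**` fallback; the subtlety is that `*` in `filepath.Match` does not cross `/`. A small standalone demonstration of those semantics (hypothetical inputs, standard library only):

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

func main() {
	pattern, path := "*.log", "logs/app.log"

	// Basename match: "*.log" vs "app.log" -> true
	byBase, _ := filepath.Match(pattern, filepath.Base(path))

	// Full-path match: "*.log" vs "logs/app.log" -> false,
	// because * does not match the "/" separator.
	byPath, _ := filepath.Match(pattern, path)

	// The removed code's simplified ** handling: collapse ** to *.
	deep := strings.ReplaceAll("docs/**.md", "**", "*")

	fmt.Println(byBase, byPath, deep) // true false docs/*.md
}
```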
```diff
@@ -35,6 +35,8 @@ func NewCollectWebsiteCmd() *cobra.Command {
 			websiteURL := args[0]
 			outputFile, _ := cmd.Flags().GetString("output")
 			depth, _ := cmd.Flags().GetInt("depth")
+			parallel, _ := cmd.Flags().GetInt("parallel")
+			rateLimit, _ := cmd.Flags().GetFloat64("rate-limit")
 			format, _ := cmd.Flags().GetString("format")
 			compression, _ := cmd.Flags().GetString("compression")
 			password, _ := cmd.Flags().GetString("password")
@@ -51,7 +53,7 @@ func NewCollectWebsiteCmd() *cobra.Command {
 				bar = ui.NewProgressBar(-1, "Crawling website")
 			}
 
-			dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
+			dn, err := website.DownloadAndPackageWebsite(cmd.Context(), websiteURL, depth, parallel, rateLimit, bar)
 			if err != nil {
 				return fmt.Errorf("error downloading and packaging website: %w", err)
 			}
@@ -101,6 +103,8 @@ func NewCollectWebsiteCmd() *cobra.Command {
 }
 	collectWebsiteCmd.PersistentFlags().String("output", "", "Output file for the DataNode")
 	collectWebsiteCmd.PersistentFlags().Int("depth", 2, "Recursion depth for downloading")
+	collectWebsiteCmd.PersistentFlags().Int("parallel", 1, "Number of concurrent workers")
+	collectWebsiteCmd.PersistentFlags().Float64("rate-limit", 0, "Max requests per second per domain")
 	collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
 	collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
 	collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
```
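With the two new flags wired in, a crawl can be parallelized and throttled from the command line. A hedged example invocation (assuming the command is registered as `borg collect website`, by analogy with the `borg collect local` examples elsewhere in this diff):

```
borg collect website https://example.com --depth 2 --parallel 4 --rate-limit 2 --format tim --compression gz
```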
```diff
@@ -11,10 +11,14 @@ import (
 	"github.com/schollz/progressbar/v3"
 )
 
+import (
+	"context"
+)
+
 func TestCollectWebsiteCmd_Good(t *testing.T) {
 	// Mock the website downloader
 	oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
-	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+	website.DownloadAndPackageWebsite = func(ctx context.Context, startURL string, maxDepth, parallel int, rateLimit float64, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
 		return datanode.New(), nil
 	}
 	defer func() {
@@ -35,7 +39,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
 func TestCollectWebsiteCmd_Bad(t *testing.T) {
 	// Mock the website downloader to return an error
 	oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
-	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+	website.DownloadAndPackageWebsite = func(ctx context.Context, startURL string, maxDepth, parallel int, rateLimit float64, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
 		return nil, fmt.Errorf("website error")
 	}
 	defer func() {
```
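These tests work because `DownloadAndPackageWebsite` is a package-level function variable (`var DownloadAndPackageWebsite = downloadAndPackageWebsite` in `pkg/website`), so a test can swap in a stub and restore the original with `defer`. A minimal sketch of that seam pattern with hypothetical names:

```go
package fetcher

import "testing"

// fetch is a package-level seam: production code calls through the
// variable, tests may replace it. Hypothetical example, not from the diff.
var fetch = func(url string) (string, error) {
	// real network call elided
	return "", nil
}

func TestWithStub(t *testing.T) {
	old := fetch
	fetch = func(url string) (string, error) { return "stubbed", nil }
	defer func() { fetch = old }() // restore, exactly as the tests above do

	got, _ := fetch("https://example.com")
	if got != "stubbed" {
		t.Fatalf("stub not used: %q", got)
	}
}
```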
```diff
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"log"
 	"os"
 
@@ -11,7 +12,7 @@ func main() {
 	log.Println("Collecting website...")
 
 	// Download and package the website.
-	dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
+	dn, err := website.DownloadAndPackageWebsite(context.Background(), "https://example.com", 2, 1, 0, nil)
 	if err != nil {
 		log.Fatalf("Failed to collect website: %v", err)
 	}
```
BIN examples/demo-sample.smsg (new file; binary file not shown)
go.mod (3 changes)

```diff
@@ -60,9 +60,10 @@ require (
 	github.com/wailsapp/go-webview2 v1.0.22 // indirect
 	github.com/wailsapp/mimetype v1.4.1 // indirect
 	github.com/xanzy/ssh-agent v0.3.3 // indirect
-	golang.org/x/crypto v0.45.0 // indirect
+	golang.org/x/crypto v0.44.0 // indirect
 	golang.org/x/sys v0.38.0 // indirect
 	golang.org/x/term v0.37.0 // indirect
 	golang.org/x/text v0.31.0 // indirect
+	golang.org/x/time v0.8.0 // indirect
 	gopkg.in/warnings.v0 v0.1.2 // indirect
 )
```
go.sum (6 changes)

```diff
@@ -155,8 +155,8 @@ github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
-golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
+golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
+golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
 golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
 golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
 golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
@@ -192,6 +192,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
 golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
+golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
+golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
```
pkg/github/downloader.go (new file, 128 lines)

@@ -0,0 +1,128 @@

```go
package github

import (
	"context"
	"fmt"
	"io"
	"io/fs"
	"net/url"
	"strings"
	"sync"

	"github.com/Snider/Borg/pkg/datanode"
	"github.com/Snider/Borg/pkg/vcs"
	"github.com/schollz/progressbar/v3"
)

// Downloader manages a pool of workers for cloning repositories.
type Downloader struct {
	parallel int
	bar      *progressbar.ProgressBar
	cloner   vcs.GitCloner
}

// NewDownloader creates a new Downloader.
func NewDownloader(parallel int, bar *progressbar.ProgressBar) *Downloader {
	return &Downloader{
		parallel: parallel,
		bar:      bar,
		cloner:   vcs.NewGitCloner(),
	}
}

// DownloadRepositories downloads a list of repositories in parallel.
func (d *Downloader) DownloadRepositories(ctx context.Context, repos []string) (*datanode.DataNode, error) {
	var wg sync.WaitGroup
	repoChan := make(chan string, len(repos))
	errChan := make(chan error, len(repos))
	mergedDN := datanode.New()
	var mu sync.Mutex

	for i := 0; i < d.parallel; i++ {
		wg.Add(1)
		go d.worker(ctx, &wg, repoChan, mergedDN, &mu, errChan)
	}

	for _, repo := range repos {
		select {
		case repoChan <- repo:
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
	close(repoChan)

	wg.Wait()
	close(errChan)

	var errs []error
	for err := range errChan {
		errs = append(errs, err)
	}
	if len(errs) > 0 {
		return nil, fmt.Errorf("errors cloning repositories: %v", errs)
	}

	return mergedDN, nil
}

func (d *Downloader) worker(ctx context.Context, wg *sync.WaitGroup, repoChan <-chan string, mergedDN *datanode.DataNode, mu *sync.Mutex, errChan chan<- error) {
	defer wg.Done()
	for repoURL := range repoChan {
		select {
		case <-ctx.Done():
			return
		default:
		}

		repoName, err := GetRepoNameFromURL(repoURL)
		if err != nil {
			errChan <- err
			continue
		}

		dn, err := d.cloner.CloneGitRepository(repoURL, nil)
		if err != nil {
			errChan <- fmt.Errorf("error cloning %s: %w", repoURL, err)
			continue
		}

		err = dn.Walk(".", func(path string, de fs.DirEntry, err error) error {
			if err != nil {
				return err
			}
			if !de.IsDir() {
				file, err := dn.Open(path)
				if err != nil {
					return err
				}
				defer file.Close()
				content, err := io.ReadAll(file)
				if err != nil {
					return err
				}
				mu.Lock()
				mergedDN.AddData(fmt.Sprintf("%s/%s", repoName, path), content)
				mu.Unlock()
			}
			return nil
		})
		if err != nil {
			errChan <- err
		}

		if d.bar != nil {
			d.bar.Add(1)
		}
	}
}

// GetRepoNameFromURL extracts the repository name from a Git URL.
func GetRepoNameFromURL(repoURL string) (string, error) {
	u, err := url.Parse(repoURL)
	if err != nil {
		return "", err
	}
	path := strings.TrimSuffix(u.Path, ".git")
	return strings.TrimPrefix(path, "/"), nil
}
```
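A minimal sketch of driving this new API from calling code (the repository URL is illustrative, not from the diff; the bar parameter may be nil since every `Add` is guarded by a nil check):

```go
package main

import (
	"context"
	"log"

	"github.com/Snider/Borg/pkg/github"
)

func main() {
	// Two workers, no progress bar.
	d := github.NewDownloader(2, nil)
	dn, err := d.DownloadRepositories(context.Background(), []string{
		"https://github.com/Snider/Borg.git", // illustrative URL
	})
	if err != nil {
		log.Fatal(err)
	}
	_ = dn // merged DataNode: one top-level directory per repository
}
```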
```diff
@@ -1,6 +1,7 @@
 package website
 
 import (
+	"context"
 	"fmt"
 	"io"
 	"net/http"
 
@@ -9,8 +10,9 @@ import (
 	"github.com/Snider/Borg/pkg/datanode"
 	"github.com/schollz/progressbar/v3"
 
 	"golang.org/x/net/html"
+	"golang.org/x/time/rate"
 	"sync"
 )
 
 var DownloadAndPackageWebsite = downloadAndPackageWebsite
```
@ -21,38 +23,51 @@ type Downloader struct {
|
|||
dn *datanode.DataNode
|
||||
visited map[string]bool
|
||||
maxDepth int
|
||||
parallel int
|
||||
progressBar *progressbar.ProgressBar
|
||||
client *http.Client
|
||||
errors []error
|
||||
mu sync.Mutex
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
// NewDownloader creates a new Downloader.
|
||||
func NewDownloader(maxDepth int) *Downloader {
|
||||
return NewDownloaderWithClient(maxDepth, http.DefaultClient)
|
||||
func NewDownloader(maxDepth, parallel int, rateLimit float64) *Downloader {
|
||||
return NewDownloaderWithClient(maxDepth, parallel, rateLimit, http.DefaultClient)
|
||||
}
|
||||
|
||||
// NewDownloaderWithClient creates a new Downloader with a custom http.Client.
|
||||
func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
|
||||
func NewDownloaderWithClient(maxDepth, parallel int, rateLimit float64, client *http.Client) *Downloader {
|
||||
var limiter *rate.Limiter
|
||||
if rateLimit > 0 {
|
||||
limiter = rate.NewLimiter(rate.Limit(rateLimit), 1)
|
||||
}
|
||||
return &Downloader{
|
||||
dn: datanode.New(),
|
||||
visited: make(map[string]bool),
|
||||
maxDepth: maxDepth,
|
||||
parallel: parallel,
|
||||
client: client,
|
||||
errors: make([]error, 0),
|
||||
limiter: limiter,
|
||||
}
|
||||
}
|
||||
|
||||
// downloadAndPackageWebsite downloads a website and packages it into a DataNode.
|
||||
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
func downloadAndPackageWebsite(ctx context.Context, startURL string, maxDepth, parallel int, rateLimit float64, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
baseURL, err := url.Parse(startURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := NewDownloader(maxDepth)
|
||||
d := NewDownloader(maxDepth, parallel, rateLimit)
|
||||
d.baseURL = baseURL
|
||||
d.progressBar = bar
|
||||
d.crawl(startURL, 0)
|
||||
d.crawl(ctx, startURL)
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(d.errors) > 0 {
|
||||
var errs []string
|
||||
|
|
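`golang.org/x/time/rate` provides a token-bucket limiter; with a burst of 1, as constructed above, `Wait` spaces calls out at the configured rate. A standalone illustration (assumed rate of 2 requests per second):

```go
package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	// 2 tokens per second, burst of 1: Wait blocks roughly 500ms per call.
	limiter := rate.NewLimiter(rate.Limit(2), 1)
	ctx := context.Background()

	start := time.Now()
	for i := 0; i < 3; i++ {
		if err := limiter.Wait(ctx); err != nil {
			fmt.Println("context cancelled:", err)
			return
		}
		fmt.Printf("request %d at %v\n", i, time.Since(start).Round(100*time.Millisecond))
	}
}
```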
```diff
@@ -65,102 +80,136 @@ func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.P
 	return d.dn, nil
 }
 
-func (d *Downloader) crawl(pageURL string, depth int) {
-	if depth > d.maxDepth || d.visited[pageURL] {
-		return
-	}
-	d.visited[pageURL] = true
-	if d.progressBar != nil {
-		d.progressBar.Add(1)
-	}
-
-	resp, err := d.client.Get(pageURL)
-	if err != nil {
-		d.errors = append(d.errors, fmt.Errorf("Error getting %s: %w", pageURL, err))
-		return
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode >= 400 {
-		d.errors = append(d.errors, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status))
-		return
-	}
-
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		d.errors = append(d.errors, fmt.Errorf("Error reading body of %s: %w", pageURL, err))
-		return
-	}
-
-	relPath := d.getRelativePath(pageURL)
-	d.dn.AddData(relPath, body)
-
-	// Don't try to parse non-html content
-	if !strings.HasPrefix(resp.Header.Get("Content-Type"), "text/html") {
-		return
-	}
-
-	doc, err := html.Parse(strings.NewReader(string(body)))
-	if err != nil {
-		d.errors = append(d.errors, fmt.Errorf("Error parsing HTML of %s: %w", pageURL, err))
-		return
-	}
-
-	var f func(*html.Node)
-	f = func(n *html.Node) {
-		if n.Type == html.ElementNode {
-			for _, a := range n.Attr {
-				if a.Key == "href" || a.Key == "src" {
-					link, err := d.resolveURL(pageURL, a.Val)
-					if err != nil {
-						continue
-					}
-					if d.isLocal(link) {
-						if isAsset(link) {
-							d.downloadAsset(link)
-						} else {
-							d.crawl(link, depth+1)
-						}
-					}
-				}
-			}
-		}
-		for c := n.FirstChild; c != nil; c = c.NextSibling {
-			f(c)
-		}
-	}
-	f(doc)
-}
-
-func (d *Downloader) downloadAsset(assetURL string) {
-	if d.visited[assetURL] {
-		return
-	}
-	d.visited[assetURL] = true
-	if d.progressBar != nil {
-		d.progressBar.Add(1)
-	}
-
-	resp, err := d.client.Get(assetURL)
-	if err != nil {
-		d.errors = append(d.errors, fmt.Errorf("Error getting asset %s: %w", assetURL, err))
-		return
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode >= 400 {
-		d.errors = append(d.errors, fmt.Errorf("bad status for asset %s: %s", assetURL, resp.Status))
-		return
-	}
-
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		d.errors = append(d.errors, fmt.Errorf("Error reading body of asset %s: %w", assetURL, err))
-		return
-	}
-
-	relPath := d.getRelativePath(assetURL)
-	d.dn.AddData(relPath, body)
-}
+type crawlJob struct {
+	url   string
+	depth int
+}
+
+func (d *Downloader) crawl(ctx context.Context, startURL string) {
+	var wg sync.WaitGroup
+	var jobWg sync.WaitGroup
+	jobChan := make(chan crawlJob, 100)
+
+	for i := 0; i < d.parallel; i++ {
+		wg.Add(1)
+		go d.worker(ctx, &wg, &jobWg, jobChan)
+	}
+
+	jobWg.Add(1)
+	jobChan <- crawlJob{url: startURL, depth: 0}
+
+	go func() {
+		jobWg.Wait()
+		close(jobChan)
+	}()
+
+	wg.Wait()
+}
+
+func (d *Downloader) worker(ctx context.Context, wg *sync.WaitGroup, jobWg *sync.WaitGroup, jobChan chan crawlJob) {
+	defer wg.Done()
+	for job := range jobChan {
+		func() {
+			defer jobWg.Done()
+
+			select {
+			case <-ctx.Done():
+				return
+			default:
+			}
+
+			if job.depth > d.maxDepth {
+				return
+			}
+
+			d.mu.Lock()
+			if d.visited[job.url] {
+				d.mu.Unlock()
+				return
+			}
+			d.visited[job.url] = true
+			d.mu.Unlock()
+
+			if d.progressBar != nil {
+				d.progressBar.Add(1)
+			}
+
+			if d.limiter != nil {
+				d.limiter.Wait(ctx)
+			}
+
+			req, err := http.NewRequestWithContext(ctx, "GET", job.url, nil)
+			if err != nil {
+				d.addError(fmt.Errorf("Error creating request for %s: %w", job.url, err))
+				return
+			}
+			resp, err := d.client.Do(req)
+			if err != nil {
+				d.addError(fmt.Errorf("Error getting %s: %w", job.url, err))
+				return
+			}
+			defer resp.Body.Close()
+
+			if resp.StatusCode >= 400 {
+				d.addError(fmt.Errorf("bad status for %s: %s", job.url, resp.Status))
+				return
+			}
+
+			body, err := io.ReadAll(resp.Body)
+			if err != nil {
+				d.addError(fmt.Errorf("Error reading body of %s: %w", job.url, err))
+				return
+			}
+
+			relPath := d.getRelativePath(job.url)
+			d.mu.Lock()
+			d.dn.AddData(relPath, body)
+			d.mu.Unlock()
+
+			if !strings.HasPrefix(resp.Header.Get("Content-Type"), "text/html") {
+				return
+			}
+
+			doc, err := html.Parse(strings.NewReader(string(body)))
+			if err != nil {
+				d.addError(fmt.Errorf("Error parsing HTML of %s: %w", job.url, err))
+				return
+			}
+
+			var f func(*html.Node)
+			f = func(n *html.Node) {
+				if n.Type == html.ElementNode {
+					for _, a := range n.Attr {
+						if a.Key == "href" || a.Key == "src" {
+							link, err := d.resolveURL(job.url, a.Val)
+							if err != nil {
+								continue
+							}
+							if d.isLocal(link) {
+								select {
+								case <-ctx.Done():
+									return
+								default:
+									jobWg.Add(1)
+									jobChan <- crawlJob{url: link, depth: job.depth + 1}
+								}
+							}
+						}
+					}
+				}
+				for c := n.FirstChild; c != nil; c = c.NextSibling {
+					f(c)
+				}
+			}
+			f(doc)
+		}()
+	}
+}
+
+func (d *Downloader) addError(err error) {
+	d.mu.Lock()
+	d.errors = append(d.errors, err)
+	d.mu.Unlock()
+}
 
 func (d *Downloader) getRelativePath(pageURL string) string {
```
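The termination scheme above is the subtle part: `jobWg` counts outstanding jobs, including follow-up jobs that workers enqueue for themselves, and a helper goroutine closes the channel once that count drains, which in turn lets the `for ... range jobChan` loops and `wg.Wait` finish. A stripped-down sketch of the same pattern, detached from HTTP:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg, jobWg sync.WaitGroup
	jobs := make(chan int, 100)

	for w := 0; w < 3; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for n := range jobs {
				if n < 5 { // workers may enqueue follow-up jobs
					jobWg.Add(1) // Add before Done, so the count never hits zero early
					jobs <- n + 1
				}
				fmt.Println("processed", n)
				jobWg.Done()
			}
		}()
	}

	jobWg.Add(1)
	jobs <- 0 // seed job

	go func() {
		jobWg.Wait() // every enqueued job has finished...
		close(jobs)  // ...so none can arrive; let the workers exit
	}()

	wg.Wait()
}
```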
```diff
@@ -1,6 +1,7 @@
 package website
 
 import (
+	"context"
 	"fmt"
 	"io"
 	"io/fs"
 
@@ -20,7 +21,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
 	defer server.Close()
 
 	bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
-	dn, err := DownloadAndPackageWebsite(server.URL, 2, bar)
+	dn, err := DownloadAndPackageWebsite(context.TODO(), server.URL, 2, 1, 0, bar)
 	if err != nil {
 		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 	}
@@ -52,7 +53,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
 
 func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 	t.Run("Invalid Start URL", func(t *testing.T) {
-		_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil)
+		_, err := DownloadAndPackageWebsite(context.TODO(), "http://invalid-url", 1, 1, 0, nil)
 		if err == nil {
 			t.Fatal("Expected an error for an invalid start URL, but got nil")
 		}
@@ -63,7 +64,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
 		}))
 		defer server.Close()
-		_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		_, err := DownloadAndPackageWebsite(context.TODO(), server.URL, 1, 1, 0, nil)
 		if err == nil {
 			t.Fatal("Expected an error for a server error on the start URL, but got nil")
 		}
@@ -80,7 +81,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 		}))
 		defer server.Close()
 		// We expect an error because the link is broken.
-		dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		dn, err := DownloadAndPackageWebsite(context.TODO(), server.URL, 1, 1, 0, nil)
 		if err == nil {
 			t.Fatal("Expected an error for a broken link, but got nil")
 		}
@@ -99,7 +100,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 	defer server.Close()
 
 	bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
-	dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1
+	dn, err := DownloadAndPackageWebsite(context.TODO(), server.URL, 1, 1, 0, bar) // Max depth of 1
 	if err != nil {
 		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 	}
@@ -122,7 +123,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 		fmt.Fprint(w, `<a href="http://externalsite.com/page.html">External</a>`)
 	}))
 	defer server.Close()
-	dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+	dn, err := DownloadAndPackageWebsite(context.TODO(), server.URL, 1, 1, 0, nil)
 	if err != nil {
 		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 	}
@@ -156,7 +157,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 	// For now, we'll just test that it doesn't hang forever.
 	done := make(chan bool)
 	go func() {
-		_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		_, err := DownloadAndPackageWebsite(context.TODO(), server.URL, 1, 1, 0, nil)
 		if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
 			// We expect a timeout error, but other errors are failures.
 			t.Errorf("unexpected error: %v", err)
```