Compare commits
1 commit
main...feat-proxy

| Author | SHA1 | Date |
|---|---|---|
| | e057ddfa7f | |

9 changed files with 296 additions and 357 deletions
Deleted file (package cmd, collect local command), 333 lines removed:

```go
package cmd

import (
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"strings"

	"github.com/Snider/Borg/pkg/compress"
	"github.com/Snider/Borg/pkg/datanode"
	"github.com/Snider/Borg/pkg/tim"
	"github.com/Snider/Borg/pkg/trix"
	"github.com/Snider/Borg/pkg/ui"

	"github.com/spf13/cobra"
)

type CollectLocalCmd struct {
	cobra.Command
}

// NewCollectLocalCmd creates a new collect local command
func NewCollectLocalCmd() *CollectLocalCmd {
	c := &CollectLocalCmd{}
	c.Command = cobra.Command{
		Use:   "local [directory]",
		Short: "Collect files from a local directory",
		Long: `Collect files from a local directory and store them in a DataNode.

If no directory is specified, the current working directory is used.

Examples:
  borg collect local
  borg collect local ./src
  borg collect local /path/to/project --output project.tar
  borg collect local . --format stim --password secret
  borg collect local . --exclude "*.log" --exclude "node_modules"`,
		Args: cobra.MaximumNArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			directory := "."
			if len(args) > 0 {
				directory = args[0]
			}

			outputFile, _ := cmd.Flags().GetString("output")
			format, _ := cmd.Flags().GetString("format")
			compression, _ := cmd.Flags().GetString("compression")
			password, _ := cmd.Flags().GetString("password")
			excludes, _ := cmd.Flags().GetStringSlice("exclude")
			includeHidden, _ := cmd.Flags().GetBool("hidden")
			respectGitignore, _ := cmd.Flags().GetBool("gitignore")

			finalPath, err := CollectLocal(directory, outputFile, format, compression, password, excludes, includeHidden, respectGitignore)
			if err != nil {
				return err
			}
			fmt.Fprintln(cmd.OutOrStdout(), "Files saved to", finalPath)
			return nil
		},
	}
	c.Flags().String("output", "", "Output file for the DataNode")
	c.Flags().String("format", "datanode", "Output format (datanode, tim, trix, or stim)")
	c.Flags().String("compression", "none", "Compression format (none, gz, or xz)")
	c.Flags().String("password", "", "Password for encryption (required for stim/trix format)")
	c.Flags().StringSlice("exclude", nil, "Patterns to exclude (can be specified multiple times)")
	c.Flags().Bool("hidden", false, "Include hidden files and directories")
	c.Flags().Bool("gitignore", true, "Respect .gitignore files (default: true)")
	return c
}

func init() {
	collectCmd.AddCommand(&NewCollectLocalCmd().Command)
}

// CollectLocal collects files from a local directory into a DataNode
func CollectLocal(directory string, outputFile string, format string, compression string, password string, excludes []string, includeHidden bool, respectGitignore bool) (string, error) {
	// Validate format
	if format != "datanode" && format != "tim" && format != "trix" && format != "stim" {
		return "", fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', 'trix', or 'stim')", format)
	}
	if (format == "stim" || format == "trix") && password == "" {
		return "", fmt.Errorf("password is required for %s format", format)
	}
	if compression != "none" && compression != "gz" && compression != "xz" {
		return "", fmt.Errorf("invalid compression: %s (must be 'none', 'gz', or 'xz')", compression)
	}

	// Resolve directory path
	absDir, err := filepath.Abs(directory)
	if err != nil {
		return "", fmt.Errorf("error resolving directory path: %w", err)
	}

	info, err := os.Stat(absDir)
	if err != nil {
		return "", fmt.Errorf("error accessing directory: %w", err)
	}
	if !info.IsDir() {
		return "", fmt.Errorf("not a directory: %s", absDir)
	}

	// Load gitignore patterns if enabled
	var gitignorePatterns []string
	if respectGitignore {
		gitignorePatterns = loadGitignore(absDir)
	}

	// Create DataNode and collect files
	dn := datanode.New()
	var fileCount int

	bar := ui.NewProgressBar(-1, "Scanning files")
	defer bar.Finish()

	err = filepath.WalkDir(absDir, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}

		// Get relative path
		relPath, err := filepath.Rel(absDir, path)
		if err != nil {
			return err
		}

		// Skip root
		if relPath == "." {
			return nil
		}

		// Skip hidden files/dirs unless explicitly included
		if !includeHidden && isHidden(relPath) {
			if d.IsDir() {
				return filepath.SkipDir
			}
			return nil
		}

		// Check gitignore patterns
		if respectGitignore && matchesGitignore(relPath, d.IsDir(), gitignorePatterns) {
			if d.IsDir() {
				return filepath.SkipDir
			}
			return nil
		}

		// Check exclude patterns
		if matchesExclude(relPath, excludes) {
			if d.IsDir() {
				return filepath.SkipDir
			}
			return nil
		}

		// Skip directories (they're implicit in DataNode)
		if d.IsDir() {
			return nil
		}

		// Read file content
		content, err := os.ReadFile(path)
		if err != nil {
			return fmt.Errorf("error reading %s: %w", relPath, err)
		}

		// Add to DataNode with forward slashes (tar convention)
		dn.AddData(filepath.ToSlash(relPath), content)
		fileCount++
		bar.Describe(fmt.Sprintf("Collected %d files", fileCount))

		return nil
	})

	if err != nil {
		return "", fmt.Errorf("error walking directory: %w", err)
	}

	if fileCount == 0 {
		return "", fmt.Errorf("no files found in %s", directory)
	}

	bar.Describe(fmt.Sprintf("Packaging %d files", fileCount))

	// Convert to output format
	var data []byte
	if format == "tim" {
		t, err := tim.FromDataNode(dn)
		if err != nil {
			return "", fmt.Errorf("error creating tim: %w", err)
		}
		data, err = t.ToTar()
		if err != nil {
			return "", fmt.Errorf("error serializing tim: %w", err)
		}
	} else if format == "stim" {
		t, err := tim.FromDataNode(dn)
		if err != nil {
			return "", fmt.Errorf("error creating tim: %w", err)
		}
		data, err = t.ToSigil(password)
		if err != nil {
			return "", fmt.Errorf("error encrypting stim: %w", err)
		}
	} else if format == "trix" {
		data, err = trix.ToTrix(dn, password)
		if err != nil {
			return "", fmt.Errorf("error serializing trix: %w", err)
		}
	} else {
		data, err = dn.ToTar()
		if err != nil {
			return "", fmt.Errorf("error serializing DataNode: %w", err)
		}
	}

	// Apply compression
	compressedData, err := compress.Compress(data, compression)
	if err != nil {
		return "", fmt.Errorf("error compressing data: %w", err)
	}

	// Determine output filename
	if outputFile == "" {
		baseName := filepath.Base(absDir)
		if baseName == "." || baseName == "/" {
			baseName = "local"
		}
		outputFile = baseName + "." + format
		if compression != "none" {
			outputFile += "." + compression
		}
	}

	err = os.WriteFile(outputFile, compressedData, 0644)
	if err != nil {
		return "", fmt.Errorf("error writing output file: %w", err)
	}

	return outputFile, nil
}

// isHidden checks if a path component starts with a dot
func isHidden(path string) bool {
	parts := strings.Split(filepath.ToSlash(path), "/")
	for _, part := range parts {
		if strings.HasPrefix(part, ".") {
			return true
		}
	}
	return false
}

// loadGitignore loads patterns from .gitignore if it exists
func loadGitignore(dir string) []string {
	var patterns []string

	gitignorePath := filepath.Join(dir, ".gitignore")
	content, err := os.ReadFile(gitignorePath)
	if err != nil {
		return patterns
	}

	lines := strings.Split(string(content), "\n")
	for _, line := range lines {
		line = strings.TrimSpace(line)
		// Skip empty lines and comments
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		patterns = append(patterns, line)
	}

	return patterns
}

// matchesGitignore checks if a path matches any gitignore pattern
func matchesGitignore(path string, isDir bool, patterns []string) bool {
	for _, pattern := range patterns {
		// Handle directory-only patterns
		if strings.HasSuffix(pattern, "/") {
			if !isDir {
				continue
			}
			pattern = strings.TrimSuffix(pattern, "/")
		}

		// Handle negation (simplified - just skip negated patterns)
		if strings.HasPrefix(pattern, "!") {
			continue
		}

		// Match against path components
		matched, _ := filepath.Match(pattern, filepath.Base(path))
		if matched {
			return true
		}

		// Also try matching the full path
		matched, _ = filepath.Match(pattern, path)
		if matched {
			return true
		}

		// Handle ** patterns (simplified)
		if strings.Contains(pattern, "**") {
			simplePattern := strings.ReplaceAll(pattern, "**", "*")
			matched, _ = filepath.Match(simplePattern, path)
			if matched {
				return true
			}
		}
	}
	return false
}

// matchesExclude checks if a path matches any exclude pattern
func matchesExclude(path string, excludes []string) bool {
	for _, pattern := range excludes {
		// Match against basename
		matched, _ := filepath.Match(pattern, filepath.Base(path))
		if matched {
			return true
		}

		// Match against full path
		matched, _ = filepath.Match(pattern, path)
		if matched {
			return true
		}
	}
	return false
}
```
Changed file (package cmd, collect website command):

```diff
@@ -6,6 +6,7 @@ import (
 	"github.com/schollz/progressbar/v3"
 	"github.com/Snider/Borg/pkg/compress"
+	"github.com/Snider/Borg/pkg/httpclient"
 	"github.com/Snider/Borg/pkg/tim"
 	"github.com/Snider/Borg/pkg/trix"
 	"github.com/Snider/Borg/pkg/ui"
 
@@ -51,7 +52,34 @@ func NewCollectWebsiteCmd() *cobra.Command {
 				bar = ui.NewProgressBar(-1, "Crawling website")
 			}
 
-			dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
+			proxy, _ := cmd.Flags().GetString("proxy")
+			proxyList, _ := cmd.Flags().GetString("proxy-list")
+			tor, _ := cmd.Flags().GetBool("tor")
+
+			// Validate that only one proxy flag is used
+			proxyFlags := 0
+			if proxy != "" {
+				proxyFlags++
+			}
+			if proxyList != "" {
+				proxyFlags++
+			}
+			if tor {
+				proxyFlags++
+			}
+			if proxyFlags > 1 {
+				return fmt.Errorf("only one of --proxy, --proxy-list, or --tor can be used at a time")
+			}
+
+			httpClient, err := httpclient.NewClient(proxy, proxyList, tor)
+			if err != nil {
+				return fmt.Errorf("error creating http client: %w", err)
+			}
+
+			downloader := website.NewDownloaderWithClient(depth, httpClient)
+			downloader.SetProgressBar(bar)
+
+			dn, err := downloader.Download(websiteURL)
 			if err != nil {
 				return fmt.Errorf("error downloading and packaging website: %w", err)
 			}
@@ -104,5 +132,8 @@ func NewCollectWebsiteCmd() *cobra.Command {
 	collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
 	collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
 	collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
+	collectWebsiteCmd.PersistentFlags().String("proxy", "", "Proxy URL (e.g. http://proxy:8080)")
+	collectWebsiteCmd.PersistentFlags().String("proxy-list", "", "Path to a file with a list of proxies (one randomly selected)")
+	collectWebsiteCmd.PersistentFlags().Bool("tor", false, "Use Tor for requests")
 	return collectWebsiteCmd
 }
```
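Only one of the three new proxy flags may be given per run; the command returns an error otherwise. Assuming the subcommand is registered as `borg collect website` (its Use string is not part of this diff), a crawl through Tor would be invoked as `borg collect website https://example.com --tor`, and a single proxy as `borg collect website https://example.com --proxy http://127.0.0.1:8080`.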
Changed file (package cmd, collect website command test):

```diff
@@ -2,6 +2,7 @@ package cmd
 
 import (
 	"fmt"
+	"net/http"
 	"path/filepath"
 	"strings"
 	"testing"
@@ -32,15 +33,26 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
 	}
 }
 
+type mockDownloader struct {
+	err error
+}
+
+func (m *mockDownloader) Download(startURL string) (*datanode.DataNode, error) {
+	return nil, m.err
+}
+
+func (m *mockDownloader) SetProgressBar(bar *progressbar.ProgressBar) {
+	// do nothing
+}
+
 func TestCollectWebsiteCmd_Bad(t *testing.T) {
 	// Mock the website downloader to return an error
-	oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
-	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
-		return nil, fmt.Errorf("website error")
+	oldNewDownloader := website.NewDownloaderWithClient
+	website.NewDownloaderWithClient = func(maxDepth int, client *http.Client) website.Downloader {
+		return &mockDownloader{err: fmt.Errorf("website error")}
 	}
-	defer func() {
-		website.DownloadAndPackageWebsite = oldDownloadAndPackageWebsite
-	}()
+	t.Cleanup(func() {
+		website.NewDownloaderWithClient = oldNewDownloader
+	})
 
 	rootCmd := NewRootCmd()
 	rootCmd.AddCommand(GetCollectCmd())
```
examples/demo-sample.smsg (new binary file, contents not shown)
go.mod (2 lines changed):

```diff
@@ -60,7 +60,7 @@ require (
 	github.com/wailsapp/go-webview2 v1.0.22 // indirect
 	github.com/wailsapp/mimetype v1.4.1 // indirect
 	github.com/xanzy/ssh-agent v0.3.3 // indirect
-	golang.org/x/crypto v0.45.0 // indirect
+	golang.org/x/crypto v0.44.0 // indirect
 	golang.org/x/sys v0.38.0 // indirect
 	golang.org/x/term v0.37.0 // indirect
 	golang.org/x/text v0.31.0 // indirect
```
go.sum (4 lines changed):

```diff
@@ -155,8 +155,8 @@ github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
-golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
+golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
+golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
 golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
 golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
 golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
```
pkg/httpclient/client.go (new file, 90 lines):

```go
package httpclient

import (
	"bufio"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"math/rand"
	"time"

	"golang.org/x/net/proxy"
)

func init() {
	rand.Seed(time.Now().UnixNano())
}

// NewClient creates a new http.Client with the specified proxy settings.
func NewClient(proxyURL, proxyList string, useTor bool) (*http.Client, error) {
	if useTor {
		proxyURL = "socks5://127.0.0.1:9050"
	}

	if proxyList != "" {
		proxies, err := readProxyList(proxyList)
		if err != nil {
			return nil, err
		}
		if len(proxies) > 0 {
			proxyURL = proxies[rand.Intn(len(proxies))]
		}
	}

	if proxyURL != "" {
		proxyURLParsed, err := url.Parse(proxyURL)
		if err != nil {
			return nil, fmt.Errorf("error parsing proxy URL: %w", err)
		}

		var transport http.RoundTripper
		if proxyURLParsed.Scheme == "socks5" {
			var auth *proxy.Auth
			if proxyURLParsed.User != nil {
				password, _ := proxyURLParsed.User.Password()
				auth = &proxy.Auth{
					User:     proxyURLParsed.User.Username(),
					Password: password,
				}
			}
			dialer, err := proxy.SOCKS5("tcp", proxyURLParsed.Host, auth, proxy.Direct)
			if err != nil {
				return nil, fmt.Errorf("error creating SOCKS5 dialer: %w", err)
			}
			transport = &http.Transport{
				Dial: dialer.Dial,
			}
		} else {
			transport = &http.Transport{
				Proxy: http.ProxyURL(proxyURLParsed),
			}
		}

		return &http.Client{
			Transport: transport,
		}, nil
	}

	return &http.Client{}, nil
}

func readProxyList(filename string) ([]string, error) {
	file, err := os.Open(filename)
	if err != nil {
		return nil, fmt.Errorf("error opening proxy list file: %w", err)
	}
	defer file.Close()

	var proxies []string
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		proxies = append(proxies, scanner.Text())
	}

	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("error reading proxy list file: %w", err)
	}

	return proxies, nil
}
```
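A minimal consumption sketch (not part of the diff) for the new package, assuming only the NewClient signature shown above; the proxy address and target URL are placeholders:

```go
package main

import (
	"fmt"
	"io"

	"github.com/Snider/Borg/pkg/httpclient"
)

func main() {
	// Build a client that routes requests through a single HTTP proxy.
	// Passing ("", "", true) instead would route through Tor on 127.0.0.1:9050,
	// and ("", "proxies.txt", false) would pick one proxy at random from a list file.
	client, err := httpclient.NewClient("http://127.0.0.1:8080", "", false)
	if err != nil {
		panic(err)
	}

	resp, err := client.Get("https://example.com")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, len(body), "bytes")
}
```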
pkg/httpclient/client_test.go (new file, 124 lines):

```go
package httpclient

import (
	"net/http"
	"net/url"
	"os"
	"testing"
)

func TestNewClient(t *testing.T) {
	// Test case for no proxy
	t.Run("NoProxy", func(t *testing.T) {
		client, err := NewClient("", "", false)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		if client.Transport != nil {
			t.Errorf("Expected nil Transport, got %T", client.Transport)
		}
	})

	// Test case for Tor
	t.Run("Tor", func(t *testing.T) {
		client, err := NewClient("", "", true)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		transport, ok := client.Transport.(*http.Transport)
		if !ok {
			t.Fatalf("Expected http.Transport, got %T", client.Transport)
		}
		if transport.Dial == nil {
			t.Error("Expected a custom dialer for Tor, got nil")
		}
	})

	// Test case for a single proxy
	t.Run("SingleProxy", func(t *testing.T) {
		proxyURL := "http://localhost:8080"
		client, err := NewClient(proxyURL, "", false)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		transport, ok := client.Transport.(*http.Transport)
		if !ok {
			t.Fatalf("Expected http.Transport, got %T", client.Transport)
		}
		proxyFunc := transport.Proxy
		if proxyFunc == nil {
			t.Fatal("Expected a proxy function, got nil")
		}
		req, _ := http.NewRequest("GET", "http://example.com", nil)
		proxy, err := proxyFunc(req)
		if err != nil {
			t.Fatalf("Expected no error from proxy func, got %v", err)
		}
		expectedProxy, _ := url.Parse(proxyURL)
		if proxy.String() != expectedProxy.String() {
			t.Errorf("Expected proxy %s, got %s", expectedProxy, proxy)
		}
	})

	// Test case for a proxy list
	t.Run("ProxyList", func(t *testing.T) {
		proxies := []string{
			"http://localhost:8081",
			"http://localhost:8082",
			"http://localhost:8083",
		}
		proxyFile, err := os.CreateTemp("", "proxies.txt")
		if err != nil {
			t.Fatalf("Failed to create temp proxy file: %v", err)
		}
		defer os.Remove(proxyFile.Name())
		for _, p := range proxies {
			proxyFile.WriteString(p + "\n")
		}
		proxyFile.Close()

		client, err := NewClient("", proxyFile.Name(), false)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		transport, ok := client.Transport.(*http.Transport)
		if !ok {
			t.Fatalf("Expected http.Transport, got %T", client.Transport)
		}
		proxyFunc := transport.Proxy
		if proxyFunc == nil {
			t.Fatal("Expected a proxy function, got nil")
		}
		req, _ := http.NewRequest("GET", "http://example.com", nil)
		proxy, err := proxyFunc(req)
		if err != nil {
			t.Fatalf("Expected no error from proxy func, got %v", err)
		}

		found := false
		for _, p := range proxies {
			if proxy.String() == p {
				found = true
				break
			}
		}
		if !found {
			t.Errorf("Expected proxy to be one of %v, got %s", proxies, proxy)
		}
	})

	t.Run("SOCKS5WithAuth", func(t *testing.T) {
		proxyURL := "socks5://user:password@localhost:1080"
		client, err := NewClient(proxyURL, "", false)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		transport, ok := client.Transport.(*http.Transport)
		if !ok {
			t.Fatalf("Expected http.Transport, got %T", client.Transport)
		}
		if transport.Dial == nil {
			t.Fatal("Expected a custom dialer for SOCKS5, got nil")
		}
	})
}
```
Changed file (package website, downloader):

```diff
@@ -15,8 +15,14 @@ import (
 
 var DownloadAndPackageWebsite = downloadAndPackageWebsite
 
-// Downloader is a recursive website downloader.
-type Downloader struct {
+// Downloader is an interface for a recursive website downloader.
+type Downloader interface {
+	Download(startURL string) (*datanode.DataNode, error)
+	SetProgressBar(bar *progressbar.ProgressBar)
+}
+
+// downloader is a recursive website downloader.
+type downloader struct {
 	baseURL *url.URL
 	dn      *datanode.DataNode
 	visited map[string]bool
@@ -27,13 +33,13 @@ type Downloader struct {
 }
 
 // NewDownloader creates a new Downloader.
-func NewDownloader(maxDepth int) *Downloader {
+func NewDownloader(maxDepth int) Downloader {
 	return NewDownloaderWithClient(maxDepth, http.DefaultClient)
 }
 
 // NewDownloaderWithClient creates a new Downloader with a custom http.Client.
-func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
-	return &Downloader{
+var NewDownloaderWithClient = func(maxDepth int, client *http.Client) Downloader {
+	return &downloader{
 		dn:       datanode.New(),
 		visited:  make(map[string]bool),
 		maxDepth: maxDepth,
@@ -44,14 +50,23 @@ func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
 
 // downloadAndPackageWebsite downloads a website and packages it into a DataNode.
 func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+	d := NewDownloader(maxDepth)
+	d.SetProgressBar(bar)
+	return d.Download(startURL)
+}
+
+// SetProgressBar sets the progress bar for the downloader.
+func (d *downloader) SetProgressBar(bar *progressbar.ProgressBar) {
+	d.progressBar = bar
+}
+
+// Download downloads a website and packages it into a DataNode.
+func (d *downloader) Download(startURL string) (*datanode.DataNode, error) {
 	baseURL, err := url.Parse(startURL)
 	if err != nil {
 		return nil, err
 	}
 
-	d := NewDownloader(maxDepth)
 	d.baseURL = baseURL
-	d.progressBar = bar
 	d.crawl(startURL, 0)
 
 	if len(d.errors) > 0 {
@@ -65,7 +80,7 @@ func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
 	return d.dn, nil
 }
 
-func (d *Downloader) crawl(pageURL string, depth int) {
+func (d *downloader) crawl(pageURL string, depth int) {
 	if depth > d.maxDepth || d.visited[pageURL] {
 		return
 	}
@@ -132,7 +147,7 @@ func (d *Downloader) crawl(pageURL string, depth int) {
 		f(doc)
 	}
 
-func (d *Downloader) downloadAsset(assetURL string) {
+func (d *downloader) downloadAsset(assetURL string) {
 	if d.visited[assetURL] {
 		return
 	}
@@ -163,7 +178,7 @@ func (d *Downloader) downloadAsset(assetURL string) {
 	d.dn.AddData(relPath, body)
 }
 
-func (d *Downloader) getRelativePath(pageURL string) string {
+func (d *downloader) getRelativePath(pageURL string) string {
 	u, err := url.Parse(pageURL)
 	if err != nil {
 		return ""
@@ -175,7 +190,7 @@ func (d *Downloader) getRelativePath(pageURL string) string {
 	return path
 }
 
-func (d *Downloader) resolveURL(base, ref string) (string, error) {
+func (d *downloader) resolveURL(base, ref string) (string, error) {
 	baseURL, err := url.Parse(base)
 	if err != nil {
 		return "", err
@@ -187,7 +202,7 @@ func (d *Downloader) resolveURL(base, ref string) (string, error) {
 	return baseURL.ResolveReference(refURL).String(), nil
 }
 
-func (d *Downloader) isLocal(pageURL string) bool {
+func (d *downloader) isLocal(pageURL string) bool {
 	u, err := url.Parse(pageURL)
 	if err != nil {
 		return false
```
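The net effect of this file's changes: Downloader becomes an interface with Download and SetProgressBar, the concrete implementation moves to the unexported downloader type, and NewDownloaderWithClient becomes a swappable package-level variable. That is what lets the collect website command inject the proxy-aware *http.Client built by pkg/httpclient, and what the updated test uses to substitute a mockDownloader without touching the network.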