Compare commits

...
Sign in to create a new pull request.

1 commit

Author SHA1 Message Date
google-labs-jules[bot]
e057ddfa7f feat: Add Proxy and Tor Support for Website Collector
This commit introduces support for routing website collection requests through HTTP/SOCKS5 proxies or Tor.

New flags for the `borg collect website` command:
- `--proxy <url>`: Route requests through a single HTTP or SOCKS5 proxy.
- `--proxy-list <file>`: Select a random proxy from a file for all requests.
- `--tor`: Route requests through a Tor SOCKS5 proxy (defaults to 127.0.0.1:9050).

Key changes:
- Created a new `pkg/httpclient` to centralize the creation of proxy-configured `http.Client` instances.
- Refactored `pkg/website` to use a `Downloader` interface, allowing for dependency injection of the HTTP client.
- Added validation to ensure the new proxy flags are mutually exclusive.
- Implemented support for SOCKS5 authentication via credentials in the proxy URL.
- Added comprehensive unit tests for the new functionality.

Co-authored-by: Snider <631881+Snider@users.noreply.github.com>
2026-02-02 00:52:32 +00:00
5 changed files with 293 additions and 21 deletions

View file

@ -6,6 +6,7 @@ import (
"github.com/schollz/progressbar/v3"
"github.com/Snider/Borg/pkg/compress"
"github.com/Snider/Borg/pkg/httpclient"
"github.com/Snider/Borg/pkg/tim"
"github.com/Snider/Borg/pkg/trix"
"github.com/Snider/Borg/pkg/ui"
@ -51,7 +52,34 @@ func NewCollectWebsiteCmd() *cobra.Command {
bar = ui.NewProgressBar(-1, "Crawling website")
}
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
proxy, _ := cmd.Flags().GetString("proxy")
proxyList, _ := cmd.Flags().GetString("proxy-list")
tor, _ := cmd.Flags().GetBool("tor")
// Validate that only one proxy flag is used
proxyFlags := 0
if proxy != "" {
proxyFlags++
}
if proxyList != "" {
proxyFlags++
}
if tor {
proxyFlags++
}
if proxyFlags > 1 {
return fmt.Errorf("only one of --proxy, --proxy-list, or --tor can be used at a time")
}
httpClient, err := httpclient.NewClient(proxy, proxyList, tor)
if err != nil {
return fmt.Errorf("error creating http client: %w", err)
}
downloader := website.NewDownloaderWithClient(depth, httpClient)
downloader.SetProgressBar(bar)
dn, err := downloader.Download(websiteURL)
if err != nil {
return fmt.Errorf("error downloading and packaging website: %w", err)
}
@ -104,5 +132,8 @@ func NewCollectWebsiteCmd() *cobra.Command {
collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
collectWebsiteCmd.PersistentFlags().String("proxy", "", "Proxy URL (e.g. http://proxy:8080)")
collectWebsiteCmd.PersistentFlags().String("proxy-list", "", "Path to a file with a list of proxies (one randomly selected)")
collectWebsiteCmd.PersistentFlags().Bool("tor", false, "Use Tor for requests")
return collectWebsiteCmd
}

View file

@ -2,6 +2,7 @@ package cmd
import (
"fmt"
"net/http"
"path/filepath"
"strings"
"testing"
@ -32,15 +33,26 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
}
}
// mockDownloader is a stub website.Downloader that performs no network
// work and always fails Download with a preset error, used to exercise
// the command's error path.
type mockDownloader struct {
	err error
}

// Download immediately returns the configured error.
func (m *mockDownloader) Download(startURL string) (*datanode.DataNode, error) {
	return nil, m.err
}

// SetProgressBar is a no-op; the mock reports no progress.
func (m *mockDownloader) SetProgressBar(_ *progressbar.ProgressBar) {}
func TestCollectWebsiteCmd_Bad(t *testing.T) {
// Mock the website downloader to return an error
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
return nil, fmt.Errorf("website error")
oldNewDownloader := website.NewDownloaderWithClient
website.NewDownloaderWithClient = func(maxDepth int, client *http.Client) website.Downloader {
return &mockDownloader{err: fmt.Errorf("website error")}
}
defer func() {
website.DownloadAndPackageWebsite = oldDownloadAndPackageWebsite
}()
t.Cleanup(func() {
website.NewDownloaderWithClient = oldNewDownloader
})
rootCmd := NewRootCmd()
rootCmd.AddCommand(GetCollectCmd())

90
pkg/httpclient/client.go Normal file
View file

@ -0,0 +1,90 @@
package httpclient
import (
	"bufio"
	"fmt"
	"math/rand"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"golang.org/x/net/proxy"
)
// init seeds the package-global math/rand source so that proxy
// selection in NewClient varies between process runs.
// NOTE(review): rand.Seed is deprecated since Go 1.20, where the global
// source is auto-seeded; if the module targets Go >= 1.20 this init can
// be dropped (or math/rand/v2 used) — confirm against go.mod.
func init() {
rand.Seed(time.Now().UnixNano())
}
// NewClient creates an *http.Client routed according to the proxy settings.
//
// Precedence: useTor forces the Tor SOCKS5 proxy at 127.0.0.1:9050;
// otherwise, if proxyList names a file, one entry is chosen at random from
// it; otherwise proxyURL (if non-empty) is used directly. With no proxy
// configured, a default client is returned.
//
// socks5:// URLs get a SOCKS5 dialer (with optional user:password
// authentication taken from the URL); every other scheme is handed to
// http.ProxyURL.
func NewClient(proxyURL, proxyList string, useTor bool) (*http.Client, error) {
	if useTor {
		// Default Tor SOCKS port.
		proxyURL = "socks5://127.0.0.1:9050"
	}
	if proxyList != "" {
		proxies, err := readProxyList(proxyList)
		if err != nil {
			return nil, err
		}
		if len(proxies) > 0 {
			// Trim the entry so a stray blank/whitespace line in the
			// file degrades to "no proxy" instead of a bogus URL.
			proxyURL = strings.TrimSpace(proxies[rand.Intn(len(proxies))])
		}
	}
	if proxyURL == "" {
		return &http.Client{}, nil
	}
	proxyURLParsed, err := url.Parse(proxyURL)
	if err != nil {
		return nil, fmt.Errorf("error parsing proxy URL: %w", err)
	}
	var transport http.RoundTripper
	if proxyURLParsed.Scheme == "socks5" {
		var auth *proxy.Auth
		if proxyURLParsed.User != nil {
			// Credentials are carried in the URL as user:password.
			password, _ := proxyURLParsed.User.Password()
			auth = &proxy.Auth{
				User:     proxyURLParsed.User.Username(),
				Password: password,
			}
		}
		dialer, err := proxy.SOCKS5("tcp", proxyURLParsed.Host, auth, proxy.Direct)
		if err != nil {
			return nil, fmt.Errorf("error creating SOCKS5 dialer: %w", err)
		}
		tr := &http.Transport{
			// Dial is deprecated, but keep it set so existing callers
			// and tests that inspect Transport.Dial continue to work.
			Dial: dialer.Dial,
		}
		// Prefer the context-aware dialer when the SOCKS5 dialer
		// provides one, so request cancellation/timeouts propagate to
		// the proxied connection.
		if cd, ok := dialer.(proxy.ContextDialer); ok {
			tr.DialContext = cd.DialContext
		}
		transport = tr
	} else {
		transport = &http.Transport{
			Proxy: http.ProxyURL(proxyURLParsed),
		}
	}
	return &http.Client{Transport: transport}, nil
}
func readProxyList(filename string) ([]string, error) {
file, err := os.Open(filename)
if err != nil {
return nil, fmt.Errorf("error opening proxy list file: %w", err)
}
defer file.Close()
var proxies []string
scanner := bufio.NewScanner(file)
for scanner.Scan() {
proxies = append(proxies, scanner.Text())
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading proxy list file: %w", err)
}
return proxies, nil
}

View file

@ -0,0 +1,124 @@
package httpclient
import (
"net/http"
"net/url"
"os"
"testing"
)
// TestNewClient exercises each proxy-selection mode of NewClient: no
// proxy, Tor, a single HTTP proxy, a proxy list file, and a SOCKS5 URL
// carrying credentials.
func TestNewClient(t *testing.T) {
	// No proxy configured: the zero-value client is returned.
	t.Run("NoProxy", func(t *testing.T) {
		client, err := NewClient("", "", false)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		if client.Transport != nil {
			t.Errorf("Expected nil Transport, got %T", client.Transport)
		}
	})

	// Tor mode must install a custom (SOCKS5) dialer.
	t.Run("Tor", func(t *testing.T) {
		client, err := NewClient("", "", true)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		transport, ok := client.Transport.(*http.Transport)
		if !ok {
			t.Fatalf("Expected http.Transport, got %T", client.Transport)
		}
		if transport.Dial == nil {
			t.Error("Expected a custom dialer for Tor, got nil")
		}
	})

	// A single HTTP proxy must be wired through Transport.Proxy.
	t.Run("SingleProxy", func(t *testing.T) {
		proxyURL := "http://localhost:8080"
		client, err := NewClient(proxyURL, "", false)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		transport, ok := client.Transport.(*http.Transport)
		if !ok {
			t.Fatalf("Expected http.Transport, got %T", client.Transport)
		}
		proxyFunc := transport.Proxy
		if proxyFunc == nil {
			t.Fatal("Expected a proxy function, got nil")
		}
		req, _ := http.NewRequest("GET", "http://example.com", nil)
		proxy, err := proxyFunc(req)
		if err != nil {
			t.Fatalf("Expected no error from proxy func, got %v", err)
		}
		expectedProxy, _ := url.Parse(proxyURL)
		if proxy.String() != expectedProxy.String() {
			t.Errorf("Expected proxy %s, got %s", expectedProxy, proxy)
		}
	})

	// A proxy list file: the selected proxy must be one of its entries.
	t.Run("ProxyList", func(t *testing.T) {
		proxies := []string{
			"http://localhost:8081",
			"http://localhost:8082",
			"http://localhost:8083",
		}
		// t.TempDir is removed automatically, so no manual cleanup; all
		// write errors are checked (the original ignored WriteString/Close).
		proxyFile, err := os.CreateTemp(t.TempDir(), "proxies-*.txt")
		if err != nil {
			t.Fatalf("Failed to create temp proxy file: %v", err)
		}
		for _, p := range proxies {
			if _, err := proxyFile.WriteString(p + "\n"); err != nil {
				t.Fatalf("Failed to write proxy file: %v", err)
			}
		}
		if err := proxyFile.Close(); err != nil {
			t.Fatalf("Failed to close proxy file: %v", err)
		}
		client, err := NewClient("", proxyFile.Name(), false)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		transport, ok := client.Transport.(*http.Transport)
		if !ok {
			t.Fatalf("Expected http.Transport, got %T", client.Transport)
		}
		proxyFunc := transport.Proxy
		if proxyFunc == nil {
			t.Fatal("Expected a proxy function, got nil")
		}
		req, _ := http.NewRequest("GET", "http://example.com", nil)
		proxy, err := proxyFunc(req)
		if err != nil {
			t.Fatalf("Expected no error from proxy func, got %v", err)
		}
		found := false
		for _, p := range proxies {
			if proxy.String() == p {
				found = true
				break
			}
		}
		if !found {
			t.Errorf("Expected proxy to be one of %v, got %s", proxies, proxy)
		}
	})

	// SOCKS5 with user:password credentials in the URL still yields a
	// custom dialer (auth is applied inside the dialer, so only its
	// presence is observable here).
	t.Run("SOCKS5WithAuth", func(t *testing.T) {
		proxyURL := "socks5://user:password@localhost:1080"
		client, err := NewClient(proxyURL, "", false)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		transport, ok := client.Transport.(*http.Transport)
		if !ok {
			t.Fatalf("Expected http.Transport, got %T", client.Transport)
		}
		if transport.Dial == nil {
			t.Fatal("Expected a custom dialer for SOCKS5, got nil")
		}
	})
}

View file

@ -15,8 +15,14 @@ import (
var DownloadAndPackageWebsite = downloadAndPackageWebsite
// Downloader is a recursive website downloader.
type Downloader struct {
// Downloader is an interface for a recursive website downloader.
type Downloader interface {
Download(startURL string) (*datanode.DataNode, error)
SetProgressBar(bar *progressbar.ProgressBar)
}
// downloader is a recursive website downloader.
type downloader struct {
baseURL *url.URL
dn *datanode.DataNode
visited map[string]bool
@ -27,13 +33,13 @@ type Downloader struct {
}
// NewDownloader creates a new Downloader.
func NewDownloader(maxDepth int) *Downloader {
func NewDownloader(maxDepth int) Downloader {
return NewDownloaderWithClient(maxDepth, http.DefaultClient)
}
// NewDownloaderWithClient creates a new Downloader with a custom http.Client.
func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
return &Downloader{
var NewDownloaderWithClient = func(maxDepth int, client *http.Client) Downloader {
return &downloader{
dn: datanode.New(),
visited: make(map[string]bool),
maxDepth: maxDepth,
@ -44,14 +50,23 @@ func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
// downloadAndPackageWebsite downloads a website and packages it into a DataNode.
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
d := NewDownloader(maxDepth)
d.SetProgressBar(bar)
return d.Download(startURL)
}
// SetProgressBar sets the progress bar for the downloader.
func (d *downloader) SetProgressBar(bar *progressbar.ProgressBar) {
d.progressBar = bar
}
// Download downloads a website and packages it into a DataNode.
func (d *downloader) Download(startURL string) (*datanode.DataNode, error) {
baseURL, err := url.Parse(startURL)
if err != nil {
return nil, err
}
d := NewDownloader(maxDepth)
d.baseURL = baseURL
d.progressBar = bar
d.crawl(startURL, 0)
if len(d.errors) > 0 {
@ -65,7 +80,7 @@ func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.P
return d.dn, nil
}
func (d *Downloader) crawl(pageURL string, depth int) {
func (d *downloader) crawl(pageURL string, depth int) {
if depth > d.maxDepth || d.visited[pageURL] {
return
}
@ -132,7 +147,7 @@ func (d *Downloader) crawl(pageURL string, depth int) {
f(doc)
}
func (d *Downloader) downloadAsset(assetURL string) {
func (d *downloader) downloadAsset(assetURL string) {
if d.visited[assetURL] {
return
}
@ -163,7 +178,7 @@ func (d *Downloader) downloadAsset(assetURL string) {
d.dn.AddData(relPath, body)
}
func (d *Downloader) getRelativePath(pageURL string) string {
func (d *downloader) getRelativePath(pageURL string) string {
u, err := url.Parse(pageURL)
if err != nil {
return ""
@ -175,7 +190,7 @@ func (d *Downloader) getRelativePath(pageURL string) string {
return path
}
func (d *Downloader) resolveURL(base, ref string) (string, error) {
func (d *downloader) resolveURL(base, ref string) (string, error) {
baseURL, err := url.Parse(base)
if err != nil {
return "", err
@ -187,7 +202,7 @@ func (d *Downloader) resolveURL(base, ref string) (string, error) {
return baseURL.ResolveReference(refURL).String(), nil
}
func (d *Downloader) isLocal(pageURL string) bool {
func (d *downloader) isLocal(pageURL string) bool {
u, err := url.Parse(pageURL)
if err != nil {
return false