Compare commits
1 commit
main
...
feat-proxy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e057ddfa7f |
5 changed files with 293 additions and 21 deletions
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/schollz/progressbar/v3"
|
||||
"github.com/Snider/Borg/pkg/compress"
|
||||
"github.com/Snider/Borg/pkg/httpclient"
|
||||
"github.com/Snider/Borg/pkg/tim"
|
||||
"github.com/Snider/Borg/pkg/trix"
|
||||
"github.com/Snider/Borg/pkg/ui"
|
||||
|
|
@ -51,7 +52,34 @@ func NewCollectWebsiteCmd() *cobra.Command {
|
|||
bar = ui.NewProgressBar(-1, "Crawling website")
|
||||
}
|
||||
|
||||
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
|
||||
proxy, _ := cmd.Flags().GetString("proxy")
|
||||
proxyList, _ := cmd.Flags().GetString("proxy-list")
|
||||
tor, _ := cmd.Flags().GetBool("tor")
|
||||
|
||||
// Validate that only one proxy flag is used
|
||||
proxyFlags := 0
|
||||
if proxy != "" {
|
||||
proxyFlags++
|
||||
}
|
||||
if proxyList != "" {
|
||||
proxyFlags++
|
||||
}
|
||||
if tor {
|
||||
proxyFlags++
|
||||
}
|
||||
if proxyFlags > 1 {
|
||||
return fmt.Errorf("only one of --proxy, --proxy-list, or --tor can be used at a time")
|
||||
}
|
||||
|
||||
httpClient, err := httpclient.NewClient(proxy, proxyList, tor)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating http client: %w", err)
|
||||
}
|
||||
|
||||
downloader := website.NewDownloaderWithClient(depth, httpClient)
|
||||
downloader.SetProgressBar(bar)
|
||||
|
||||
dn, err := downloader.Download(websiteURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error downloading and packaging website: %w", err)
|
||||
}
|
||||
|
|
@ -104,5 +132,8 @@ func NewCollectWebsiteCmd() *cobra.Command {
|
|||
collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
|
||||
collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
|
||||
collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
|
||||
collectWebsiteCmd.PersistentFlags().String("proxy", "", "Proxy URL (e.g. http://proxy:8080)")
|
||||
collectWebsiteCmd.PersistentFlags().String("proxy-list", "", "Path to a file with a list of proxies (one randomly selected)")
|
||||
collectWebsiteCmd.PersistentFlags().Bool("tor", false, "Use Tor for requests")
|
||||
return collectWebsiteCmd
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package cmd
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
|
@ -32,15 +33,26 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// mockDownloader is a test double that satisfies website.Downloader
// (see the NewDownloaderWithClient override below) and always returns
// the configured error from Download.
type mockDownloader struct {
	err error // error returned by every Download call
}

// Download ignores startURL and returns (nil, m.err).
func (m *mockDownloader) Download(startURL string) (*datanode.DataNode, error) {
	return nil, m.err
}

// SetProgressBar satisfies the Downloader interface; the mock reports no progress.
func (m *mockDownloader) SetProgressBar(bar *progressbar.ProgressBar) {
	// do nothing
}
|
||||
|
||||
func TestCollectWebsiteCmd_Bad(t *testing.T) {
|
||||
// Mock the website downloader to return an error
|
||||
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
return nil, fmt.Errorf("website error")
|
||||
oldNewDownloader := website.NewDownloaderWithClient
|
||||
website.NewDownloaderWithClient = func(maxDepth int, client *http.Client) website.Downloader {
|
||||
return &mockDownloader{err: fmt.Errorf("website error")}
|
||||
}
|
||||
defer func() {
|
||||
website.DownloadAndPackageWebsite = oldDownloadAndPackageWebsite
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
website.NewDownloaderWithClient = oldNewDownloader
|
||||
})
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
rootCmd.AddCommand(GetCollectCmd())
|
||||
|
|
|
|||
90
pkg/httpclient/client.go
Normal file
90
pkg/httpclient/client.go
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
package httpclient
|
||||
|
||||
import (
	"bufio"
	"fmt"
	"math/rand"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"golang.org/x/net/proxy"
)
|
||||
|
||||
// init seeds the global math/rand source so random proxy selection in
// NewClient differs between runs.
// NOTE(review): rand.Seed is deprecated as of Go 1.20 (the global source is
// auto-seeded); confirm the module's Go version before removing this.
func init() {
	rand.Seed(time.Now().UnixNano())
}
|
||||
|
||||
// NewClient creates a new http.Client with the specified proxy settings.
|
||||
func NewClient(proxyURL, proxyList string, useTor bool) (*http.Client, error) {
|
||||
if useTor {
|
||||
proxyURL = "socks5://127.0.0.1:9050"
|
||||
}
|
||||
|
||||
if proxyList != "" {
|
||||
proxies, err := readProxyList(proxyList)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(proxies) > 0 {
|
||||
proxyURL = proxies[rand.Intn(len(proxies))]
|
||||
}
|
||||
}
|
||||
|
||||
if proxyURL != "" {
|
||||
proxyURLParsed, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing proxy URL: %w", err)
|
||||
}
|
||||
|
||||
var transport http.RoundTripper
|
||||
if proxyURLParsed.Scheme == "socks5" {
|
||||
var auth *proxy.Auth
|
||||
if proxyURLParsed.User != nil {
|
||||
password, _ := proxyURLParsed.User.Password()
|
||||
auth = &proxy.Auth{
|
||||
User: proxyURLParsed.User.Username(),
|
||||
Password: password,
|
||||
}
|
||||
}
|
||||
dialer, err := proxy.SOCKS5("tcp", proxyURLParsed.Host, auth, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating SOCKS5 dialer: %w", err)
|
||||
}
|
||||
transport = &http.Transport{
|
||||
Dial: dialer.Dial,
|
||||
}
|
||||
} else {
|
||||
transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURLParsed),
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &http.Client{}, nil
|
||||
}
|
||||
|
||||
func readProxyList(filename string) ([]string, error) {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening proxy list file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var proxies []string
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
proxies = append(proxies, scanner.Text())
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error reading proxy list file: %w", err)
|
||||
}
|
||||
|
||||
return proxies, nil
|
||||
}
|
||||
124
pkg/httpclient/client_test.go
Normal file
124
pkg/httpclient/client_test.go
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
package httpclient
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNewClient exercises NewClient across its proxy configurations: no
// proxy, Tor, a single HTTP proxy, a proxy-list file, and an authenticated
// SOCKS5 URL. No network connections are made; only the constructed
// client's transport is inspected.
func TestNewClient(t *testing.T) {
	// Test case for no proxy
	t.Run("NoProxy", func(t *testing.T) {
		client, err := NewClient("", "", false)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		// A nil Transport means the default (direct) transport is used.
		if client.Transport != nil {
			t.Errorf("Expected nil Transport, got %T", client.Transport)
		}
	})

	// Test case for Tor
	t.Run("Tor", func(t *testing.T) {
		client, err := NewClient("", "", true)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		transport, ok := client.Transport.(*http.Transport)
		if !ok {
			t.Fatalf("Expected http.Transport, got %T", client.Transport)
		}
		// Tor uses socks5, so NewClient must install a custom Dial func.
		if transport.Dial == nil {
			t.Error("Expected a custom dialer for Tor, got nil")
		}
	})

	// Test case for a single proxy
	t.Run("SingleProxy", func(t *testing.T) {
		proxyURL := "http://localhost:8080"
		client, err := NewClient(proxyURL, "", false)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		transport, ok := client.Transport.(*http.Transport)
		if !ok {
			t.Fatalf("Expected http.Transport, got %T", client.Transport)
		}
		proxyFunc := transport.Proxy
		if proxyFunc == nil {
			t.Fatal("Expected a proxy function, got nil")
		}
		// The proxy function must resolve any request to the configured URL.
		req, _ := http.NewRequest("GET", "http://example.com", nil)
		proxy, err := proxyFunc(req)
		if err != nil {
			t.Fatalf("Expected no error from proxy func, got %v", err)
		}
		expectedProxy, _ := url.Parse(proxyURL)
		if proxy.String() != expectedProxy.String() {
			t.Errorf("Expected proxy %s, got %s", expectedProxy, proxy)
		}
	})

	// Test case for a proxy list
	t.Run("ProxyList", func(t *testing.T) {
		proxies := []string{
			"http://localhost:8081",
			"http://localhost:8082",
			"http://localhost:8083",
		}
		// Write the candidate proxies to a temporary list file.
		proxyFile, err := os.CreateTemp("", "proxies.txt")
		if err != nil {
			t.Fatalf("Failed to create temp proxy file: %v", err)
		}
		defer os.Remove(proxyFile.Name())
		for _, p := range proxies {
			proxyFile.WriteString(p + "\n")
		}
		proxyFile.Close()

		client, err := NewClient("", proxyFile.Name(), false)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		transport, ok := client.Transport.(*http.Transport)
		if !ok {
			t.Fatalf("Expected http.Transport, got %T", client.Transport)
		}
		proxyFunc := transport.Proxy
		if proxyFunc == nil {
			t.Fatal("Expected a proxy function, got nil")
		}
		req, _ := http.NewRequest("GET", "http://example.com", nil)
		proxy, err := proxyFunc(req)
		if err != nil {
			t.Fatalf("Expected no error from proxy func, got %v", err)
		}

		// Selection is random, so accept any entry from the list.
		found := false
		for _, p := range proxies {
			if proxy.String() == p {
				found = true
				break
			}
		}
		if !found {
			t.Errorf("Expected proxy to be one of %v, got %s", proxies, proxy)
		}
	})

	t.Run("SOCKS5WithAuth", func(t *testing.T) {
		proxyURL := "socks5://user:password@localhost:1080"
		client, err := NewClient(proxyURL, "", false)
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		transport, ok := client.Transport.(*http.Transport)
		if !ok {
			t.Fatalf("Expected http.Transport, got %T", client.Transport)
		}
		// Credentials come from the URL; a custom dialer must be installed.
		// (Dialer creation does not connect, so no SOCKS server is needed.)
		if transport.Dial == nil {
			t.Fatal("Expected a custom dialer for SOCKS5, got nil")
		}
	})
}
|
||||
|
|
@ -15,8 +15,14 @@ import (
|
|||
|
||||
var DownloadAndPackageWebsite = downloadAndPackageWebsite
|
||||
|
||||
// Downloader is a recursive website downloader.
|
||||
type Downloader struct {
|
||||
// Downloader is an interface for a recursive website downloader.
|
||||
type Downloader interface {
|
||||
Download(startURL string) (*datanode.DataNode, error)
|
||||
SetProgressBar(bar *progressbar.ProgressBar)
|
||||
}
|
||||
|
||||
// downloader is a recursive website downloader.
|
||||
type downloader struct {
|
||||
baseURL *url.URL
|
||||
dn *datanode.DataNode
|
||||
visited map[string]bool
|
||||
|
|
@ -27,13 +33,13 @@ type Downloader struct {
|
|||
}
|
||||
|
||||
// NewDownloader creates a new Downloader.
|
||||
func NewDownloader(maxDepth int) *Downloader {
|
||||
func NewDownloader(maxDepth int) Downloader {
|
||||
return NewDownloaderWithClient(maxDepth, http.DefaultClient)
|
||||
}
|
||||
|
||||
// NewDownloaderWithClient creates a new Downloader with a custom http.Client.
|
||||
func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
|
||||
return &Downloader{
|
||||
var NewDownloaderWithClient = func(maxDepth int, client *http.Client) Downloader {
|
||||
return &downloader{
|
||||
dn: datanode.New(),
|
||||
visited: make(map[string]bool),
|
||||
maxDepth: maxDepth,
|
||||
|
|
@ -44,14 +50,23 @@ func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
|
|||
|
||||
// downloadAndPackageWebsite downloads a website and packages it into a DataNode.
// It is the default implementation behind the DownloadAndPackageWebsite
// package variable and builds its Downloader via NewDownloader.
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
	d := NewDownloader(maxDepth)
	d.SetProgressBar(bar)
	return d.Download(startURL)
}
|
||||
|
||||
// SetProgressBar sets the progress bar for the downloader.
// It implements the Downloader interface.
func (d *downloader) SetProgressBar(bar *progressbar.ProgressBar) {
	d.progressBar = bar
}
|
||||
|
||||
// Download downloads a website and packages it into a DataNode.
|
||||
func (d *downloader) Download(startURL string) (*datanode.DataNode, error) {
|
||||
baseURL, err := url.Parse(startURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := NewDownloader(maxDepth)
|
||||
d.baseURL = baseURL
|
||||
d.progressBar = bar
|
||||
d.crawl(startURL, 0)
|
||||
|
||||
if len(d.errors) > 0 {
|
||||
|
|
@ -65,7 +80,7 @@ func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.P
|
|||
return d.dn, nil
|
||||
}
|
||||
|
||||
func (d *Downloader) crawl(pageURL string, depth int) {
|
||||
func (d *downloader) crawl(pageURL string, depth int) {
|
||||
if depth > d.maxDepth || d.visited[pageURL] {
|
||||
return
|
||||
}
|
||||
|
|
@ -132,7 +147,7 @@ func (d *Downloader) crawl(pageURL string, depth int) {
|
|||
f(doc)
|
||||
}
|
||||
|
||||
func (d *Downloader) downloadAsset(assetURL string) {
|
||||
func (d *downloader) downloadAsset(assetURL string) {
|
||||
if d.visited[assetURL] {
|
||||
return
|
||||
}
|
||||
|
|
@ -163,7 +178,7 @@ func (d *Downloader) downloadAsset(assetURL string) {
|
|||
d.dn.AddData(relPath, body)
|
||||
}
|
||||
|
||||
func (d *Downloader) getRelativePath(pageURL string) string {
|
||||
func (d *downloader) getRelativePath(pageURL string) string {
|
||||
u, err := url.Parse(pageURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
|
|
@ -175,7 +190,7 @@ func (d *Downloader) getRelativePath(pageURL string) string {
|
|||
return path
|
||||
}
|
||||
|
||||
func (d *Downloader) resolveURL(base, ref string) (string, error) {
|
||||
func (d *downloader) resolveURL(base, ref string) (string, error) {
|
||||
baseURL, err := url.Parse(base)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -187,7 +202,7 @@ func (d *Downloader) resolveURL(base, ref string) (string, error) {
|
|||
return baseURL.ResolveReference(refURL).String(), nil
|
||||
}
|
||||
|
||||
func (d *Downloader) isLocal(pageURL string) bool {
|
||||
func (d *downloader) isLocal(pageURL string) bool {
|
||||
u, err := url.Parse(pageURL)
|
||||
if err != nil {
|
||||
return false
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue