Compare commits
1 commit
main
...
feat/conne
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76a097295f |
6 changed files with 116 additions and 14 deletions
|
|
@ -3,9 +3,11 @@ package cmd
|
|||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/schollz/progressbar/v3"
|
||||
"github.com/Snider/Borg/pkg/compress"
|
||||
"github.com/Snider/Borg/pkg/httpclient"
|
||||
"github.com/Snider/Borg/pkg/tim"
|
||||
"github.com/Snider/Borg/pkg/trix"
|
||||
"github.com/Snider/Borg/pkg/ui"
|
||||
|
|
@ -38,6 +40,11 @@ func NewCollectWebsiteCmd() *cobra.Command {
|
|||
format, _ := cmd.Flags().GetString("format")
|
||||
compression, _ := cmd.Flags().GetString("compression")
|
||||
password, _ := cmd.Flags().GetString("password")
|
||||
maxConnections, _ := cmd.Flags().GetInt("max-connections")
|
||||
noKeepAlive, _ := cmd.Flags().GetBool("no-keepalive")
|
||||
http1, _ := cmd.Flags().GetBool("http1")
|
||||
idleTimeout, _ := cmd.Flags().GetDuration("idle-timeout")
|
||||
maxIdle, _ := cmd.Flags().GetInt("max-idle")
|
||||
|
||||
if format != "datanode" && format != "tim" && format != "trix" {
|
||||
return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', or 'trix')", format)
|
||||
|
|
@ -51,11 +58,25 @@ func NewCollectWebsiteCmd() *cobra.Command {
|
|||
bar = ui.NewProgressBar(-1, "Crawling website")
|
||||
}
|
||||
|
||||
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
|
||||
// Create a new HTTP client with the specified options.
|
||||
client, metrics := httpclient.New(httpclient.Options{
|
||||
MaxPerHost: maxConnections,
|
||||
NoKeepAlive: noKeepAlive,
|
||||
HTTP1: http1,
|
||||
IdleTimeout: idleTimeout,
|
||||
MaxIdle: maxIdle,
|
||||
})
|
||||
|
||||
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar, client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error downloading and packaging website: %w", err)
|
||||
}
|
||||
|
||||
// Display the connection reuse metrics.
|
||||
fmt.Fprintln(cmd.OutOrStdout(), "Connection Metrics:")
|
||||
fmt.Fprintf(cmd.OutOrStdout(), " Connections Reused: %d\n", metrics.ConnectionsReused)
|
||||
fmt.Fprintf(cmd.OutOrStdout(), " Connections Created: %d\n", metrics.ConnectionsCreated)
|
||||
|
||||
var data []byte
|
||||
if format == "tim" {
|
||||
tim, err := tim.FromDataNode(dn)
|
||||
|
|
@ -104,5 +125,10 @@ func NewCollectWebsiteCmd() *cobra.Command {
|
|||
collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
|
||||
collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
|
||||
collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
|
||||
collectWebsiteCmd.Flags().Int("max-connections", 6, "Max connections per domain")
|
||||
collectWebsiteCmd.Flags().Bool("no-keepalive", false, "Disable keep-alive")
|
||||
collectWebsiteCmd.Flags().Bool("http1", false, "Force HTTP/1.1")
|
||||
collectWebsiteCmd.Flags().Duration("idle-timeout", 90*time.Second, "Close idle connections after")
|
||||
collectWebsiteCmd.Flags().Int("max-idle", 100, "Max idle connections total")
|
||||
return collectWebsiteCmd
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package cmd
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
|
@ -14,7 +15,7 @@ import (
|
|||
func TestCollectWebsiteCmd_Good(t *testing.T) {
|
||||
// Mock the website downloader
|
||||
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
|
||||
return datanode.New(), nil
|
||||
}
|
||||
defer func() {
|
||||
|
|
@ -35,7 +36,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
|
|||
func TestCollectWebsiteCmd_Bad(t *testing.T) {
|
||||
// Mock the website downloader to return an error
|
||||
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
|
||||
return nil, fmt.Errorf("website error")
|
||||
}
|
||||
defer func() {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/Snider/Borg/pkg/website"
|
||||
|
|
@ -11,7 +12,7 @@ func main() {
|
|||
log.Println("Collecting website...")
|
||||
|
||||
// Download and package the website.
|
||||
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
|
||||
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil, http.DefaultClient)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to collect website: %v", err)
|
||||
}
|
||||
|
|
|
|||
74
pkg/httpclient/client.go
Normal file
74
pkg/httpclient/client.go
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
package httpclient
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Metrics holds the connection reuse metrics.
//
// Both counters are incremented via sync/atomic by the client's round
// tripper while requests are in flight; read them with atomic.LoadInt64
// if the client may still be issuing requests concurrently.
type Metrics struct {
	ConnectionsReused  int64 // connections obtained from the idle pool (GotConnInfo.Reused)
	ConnectionsCreated int64 // connections freshly dialed for a request
}
|
||||
|
||||
// Options represents the configuration for the HTTP client.
type Options struct {
	MaxPerHost  int           // max concurrent connections per host (maps to Transport.MaxConnsPerHost)
	MaxIdle     int           // max idle connections kept across all hosts (Transport.MaxIdleConns)
	IdleTimeout time.Duration // how long an idle connection stays open before being closed
	NoKeepAlive bool          // disable keep-alive, forcing a fresh connection per request
	HTTP1       bool          // force HTTP/1.1 by suppressing the h2 ALPN upgrade
}
|
||||
|
||||
// New creates a new HTTP client with the given options.
|
||||
func New(opts Options) (*http.Client, *Metrics) {
|
||||
metrics := &Metrics{}
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
MaxIdleConns: opts.MaxIdle,
|
||||
IdleConnTimeout: opts.IdleTimeout,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
MaxConnsPerHost: opts.MaxPerHost,
|
||||
DisableKeepAlives: opts.NoKeepAlive,
|
||||
}
|
||||
|
||||
if opts.HTTP1 {
|
||||
// Disable HTTP/2 by preventing the TLS next protocol negotiation.
|
||||
transport.TLSNextProto = make(map[string]func(authority string, c *tls.Conn) http.RoundTripper)
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: &metricsRoundTripper{
|
||||
transport: transport,
|
||||
metrics: metrics,
|
||||
},
|
||||
}, metrics
|
||||
}
|
||||
|
||||
// metricsRoundTripper is a custom http.RoundTripper that collects connection reuse metrics.
//
// It delegates every request to the wrapped transport, first attaching an
// httptrace.ClientTrace so the GotConn callback can classify the connection
// as reused or newly created.
type metricsRoundTripper struct {
	transport http.RoundTripper // underlying transport that actually performs the request
	metrics   *Metrics          // counters updated atomically on each GotConn event
}
|
||||
|
||||
func (m *metricsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
trace := &httptrace.ClientTrace{
|
||||
GotConn: func(info httptrace.GotConnInfo) {
|
||||
if info.Reused {
|
||||
atomic.AddInt64(&m.metrics.ConnectionsReused, 1)
|
||||
} else {
|
||||
atomic.AddInt64(&m.metrics.ConnectionsCreated, 1)
|
||||
}
|
||||
},
|
||||
}
|
||||
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
|
||||
return m.transport.RoundTrip(req)
|
||||
}
|
||||
|
|
@ -13,7 +13,7 @@ import (
|
|||
"golang.org/x/net/html"
|
||||
)
|
||||
|
||||
var DownloadAndPackageWebsite = downloadAndPackageWebsite
|
||||
var DownloadAndPackageWebsite func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) = downloadAndPackageWebsite
|
||||
|
||||
// Downloader is a recursive website downloader.
|
||||
type Downloader struct {
|
||||
|
|
@ -43,13 +43,13 @@ func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
|
|||
}
|
||||
|
||||
// downloadAndPackageWebsite downloads a website and packages it into a DataNode.
|
||||
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
|
||||
baseURL, err := url.Parse(startURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := NewDownloader(maxDepth)
|
||||
d := NewDownloaderWithClient(maxDepth, client)
|
||||
d.baseURL = baseURL
|
||||
d.progressBar = bar
|
||||
d.crawl(startURL, 0)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
|
|||
defer server.Close()
|
||||
|
||||
bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 2, bar)
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 2, bar, http.DefaultClient)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
|
|
@ -52,7 +52,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
|
|||
|
||||
func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
|
||||
t.Run("Invalid Start URL", func(t *testing.T) {
|
||||
_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil)
|
||||
_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil, http.DefaultClient)
|
||||
if err == nil {
|
||||
t.Fatal("Expected an error for an invalid start URL, but got nil")
|
||||
}
|
||||
|
|
@ -63,7 +63,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
|
|||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
_, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
|
||||
if err == nil {
|
||||
t.Fatal("Expected an error for a server error on the start URL, but got nil")
|
||||
}
|
||||
|
|
@ -80,7 +80,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
|
|||
}))
|
||||
defer server.Close()
|
||||
// We expect an error because the link is broken.
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
|
||||
if err == nil {
|
||||
t.Fatal("Expected an error for a broken link, but got nil")
|
||||
}
|
||||
|
|
@ -99,7 +99,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
defer server.Close()
|
||||
|
||||
bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, bar, http.DefaultClient) // Max depth of 1
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
|
|
@ -122,7 +122,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
fmt.Fprint(w, `<a href="http://externalsite.com/page.html">External</a>`)
|
||||
}))
|
||||
defer server.Close()
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
|
|
@ -156,7 +156,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
// For now, we'll just test that it doesn't hang forever.
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
_, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
|
||||
if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
|
||||
// We expect a timeout error, but other errors are failures.
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue