Compare commits

...
Sign in to create a new pull request.

1 commit

Author SHA1 Message Date
google-labs-jules[bot]
76a097295f feat: Add connection pooling and keep-alive
This change introduces connection pooling, keep-alive, and HTTP/2 support to the website collector. It adds a new httpclient package to create a configurable http.Client and exposes the configuration options as command-line flags. It also adds connection reuse metrics to the output of the collect website command.

Co-authored-by: Snider <631881+Snider@users.noreply.github.com>
2026-02-02 00:52:51 +00:00
6 changed files with 116 additions and 14 deletions

View file

@ -3,9 +3,11 @@ package cmd
import (
"fmt"
"os"
"time"
"github.com/schollz/progressbar/v3"
"github.com/Snider/Borg/pkg/compress"
"github.com/Snider/Borg/pkg/httpclient"
"github.com/Snider/Borg/pkg/tim"
"github.com/Snider/Borg/pkg/trix"
"github.com/Snider/Borg/pkg/ui"
@ -38,6 +40,11 @@ func NewCollectWebsiteCmd() *cobra.Command {
format, _ := cmd.Flags().GetString("format")
compression, _ := cmd.Flags().GetString("compression")
password, _ := cmd.Flags().GetString("password")
maxConnections, _ := cmd.Flags().GetInt("max-connections")
noKeepAlive, _ := cmd.Flags().GetBool("no-keepalive")
http1, _ := cmd.Flags().GetBool("http1")
idleTimeout, _ := cmd.Flags().GetDuration("idle-timeout")
maxIdle, _ := cmd.Flags().GetInt("max-idle")
if format != "datanode" && format != "tim" && format != "trix" {
return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', or 'trix')", format)
@ -51,11 +58,25 @@ func NewCollectWebsiteCmd() *cobra.Command {
bar = ui.NewProgressBar(-1, "Crawling website")
}
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
// Create a new HTTP client with the specified options.
client, metrics := httpclient.New(httpclient.Options{
MaxPerHost: maxConnections,
NoKeepAlive: noKeepAlive,
HTTP1: http1,
IdleTimeout: idleTimeout,
MaxIdle: maxIdle,
})
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar, client)
if err != nil {
return fmt.Errorf("error downloading and packaging website: %w", err)
}
// Display the connection reuse metrics.
fmt.Fprintln(cmd.OutOrStdout(), "Connection Metrics:")
fmt.Fprintf(cmd.OutOrStdout(), " Connections Reused: %d\n", metrics.ConnectionsReused)
fmt.Fprintf(cmd.OutOrStdout(), " Connections Created: %d\n", metrics.ConnectionsCreated)
var data []byte
if format == "tim" {
tim, err := tim.FromDataNode(dn)
@ -104,5 +125,10 @@ func NewCollectWebsiteCmd() *cobra.Command {
collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
collectWebsiteCmd.Flags().Int("max-connections", 6, "Max connections per domain")
collectWebsiteCmd.Flags().Bool("no-keepalive", false, "Disable keep-alive")
collectWebsiteCmd.Flags().Bool("http1", false, "Force HTTP/1.1")
collectWebsiteCmd.Flags().Duration("idle-timeout", 90*time.Second, "Close idle connections after")
collectWebsiteCmd.Flags().Int("max-idle", 100, "Max idle connections total")
return collectWebsiteCmd
}

View file

@ -2,6 +2,7 @@ package cmd
import (
"fmt"
"net/http"
"path/filepath"
"strings"
"testing"
@ -14,7 +15,7 @@ import (
func TestCollectWebsiteCmd_Good(t *testing.T) {
// Mock the website downloader
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
return datanode.New(), nil
}
defer func() {
@ -35,7 +36,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
func TestCollectWebsiteCmd_Bad(t *testing.T) {
// Mock the website downloader to return an error
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
return nil, fmt.Errorf("website error")
}
defer func() {

View file

@ -2,6 +2,7 @@ package main
import (
"log"
"net/http"
"os"
"github.com/Snider/Borg/pkg/website"
@ -11,7 +12,7 @@ func main() {
log.Println("Collecting website...")
// Download and package the website.
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil, http.DefaultClient)
if err != nil {
log.Fatalf("Failed to collect website: %v", err)
}

74
pkg/httpclient/client.go Normal file
View file

@ -0,0 +1,74 @@
package httpclient
import (
"crypto/tls"
"net"
"net/http"
"net/http/httptrace"
"sync/atomic"
"time"
)
// Metrics holds the connection reuse metrics. Counters are updated
// atomically and may be read concurrently while requests are in flight.
type Metrics struct {
	ConnectionsReused  int64 // connections obtained from the idle pool
	ConnectionsCreated int64 // freshly dialed connections
}

// Options represents the configuration for the HTTP client.
type Options struct {
	MaxPerHost  int           // limit on total connections per host (0 = no limit)
	MaxIdle     int           // limit on idle connections across all hosts
	IdleTimeout time.Duration // how long an idle connection is kept before closing
	NoKeepAlive bool          // disable keep-alive (new connection per request)
	HTTP1       bool          // force HTTP/1.1, disabling HTTP/2 negotiation
}

// New creates a new HTTP client with the given options, along with a Metrics
// value that records, per request, whether the connection was reused from the
// pool or freshly created.
func New(opts Options) (*http.Client, *Metrics) {
	metrics := &Metrics{}
	transport := &http.Transport{
		Proxy: http.ProxyFromEnvironment,
		DialContext: (&net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
		}).DialContext,
		// A custom DialContext disables the transport's automatic HTTP/2
		// upgrade, so HTTP/2 must be requested explicitly (this mirrors
		// http.DefaultTransport, which sets ForceAttemptHTTP2: true for
		// the same reason).
		ForceAttemptHTTP2: !opts.HTTP1,
		MaxIdleConns:      opts.MaxIdle,
		// The default MaxIdleConnsPerHost is only 2; without raising it to
		// match the per-host connection limit, most parallel connections to
		// a host would be closed instead of pooled for reuse.
		MaxIdleConnsPerHost: opts.MaxPerHost,
		IdleConnTimeout:     opts.IdleTimeout,
		TLSHandshakeTimeout: 10 * time.Second,
		MaxConnsPerHost:     opts.MaxPerHost,
		DisableKeepAlives:   opts.NoKeepAlive,
	}
	if opts.HTTP1 {
		// Disable HTTP/2 by preventing the TLS next protocol negotiation.
		transport.TLSNextProto = make(map[string]func(authority string, c *tls.Conn) http.RoundTripper)
	}
	return &http.Client{
		Transport: &metricsRoundTripper{
			transport: transport,
			metrics:   metrics,
		},
	}, metrics
}

// metricsRoundTripper is a custom http.RoundTripper that collects connection
// reuse metrics by tracing each request it forwards.
type metricsRoundTripper struct {
	transport http.RoundTripper // the real transport requests are delegated to
	metrics   *Metrics          // shared counters, updated atomically
}

// RoundTrip implements http.RoundTripper. It attaches an httptrace
// ClientTrace to the request context so the GotConn callback can classify
// the connection as reused or newly created, then delegates to the wrapped
// transport.
func (m *metricsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	trace := &httptrace.ClientTrace{
		GotConn: func(info httptrace.GotConnInfo) {
			if info.Reused {
				atomic.AddInt64(&m.metrics.ConnectionsReused, 1)
			} else {
				atomic.AddInt64(&m.metrics.ConnectionsCreated, 1)
			}
		},
	}
	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
	return m.transport.RoundTrip(req)
}

View file

@ -13,7 +13,7 @@ import (
"golang.org/x/net/html"
)
var DownloadAndPackageWebsite = downloadAndPackageWebsite
// DownloadAndPackageWebsite crawls startURL up to maxDepth, issuing requests
// with client and reporting progress on bar, and packages the result into a
// DataNode. It is declared as a package-level function variable so tests can
// swap in a mock implementation.
var DownloadAndPackageWebsite func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) = downloadAndPackageWebsite
// Downloader is a recursive website downloader.
type Downloader struct {
@ -43,13 +43,13 @@ func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
}
// downloadAndPackageWebsite downloads a website and packages it into a DataNode.
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
baseURL, err := url.Parse(startURL)
if err != nil {
return nil, err
}
d := NewDownloader(maxDepth)
d := NewDownloaderWithClient(maxDepth, client)
d.baseURL = baseURL
d.progressBar = bar
d.crawl(startURL, 0)

View file

@ -20,7 +20,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
defer server.Close()
bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
dn, err := DownloadAndPackageWebsite(server.URL, 2, bar)
dn, err := DownloadAndPackageWebsite(server.URL, 2, bar, http.DefaultClient)
if err != nil {
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
}
@ -52,7 +52,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
t.Run("Invalid Start URL", func(t *testing.T) {
_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil)
_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil, http.DefaultClient)
if err == nil {
t.Fatal("Expected an error for an invalid start URL, but got nil")
}
@ -63,7 +63,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}))
defer server.Close()
_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
_, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
if err == nil {
t.Fatal("Expected an error for a server error on the start URL, but got nil")
}
@ -80,7 +80,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
}))
defer server.Close()
// We expect an error because the link is broken.
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
if err == nil {
t.Fatal("Expected an error for a broken link, but got nil")
}
@ -99,7 +99,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
defer server.Close()
bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1
dn, err := DownloadAndPackageWebsite(server.URL, 1, bar, http.DefaultClient) // Max depth of 1
if err != nil {
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
}
@ -122,7 +122,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
fmt.Fprint(w, `<a href="http://externalsite.com/page.html">External</a>`)
}))
defer server.Close()
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
if err != nil {
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
}
@ -156,7 +156,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
// For now, we'll just test that it doesn't hang forever.
done := make(chan bool)
go func() {
_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
_, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
// We expect a timeout error, but other errors are failures.
t.Errorf("unexpected error: %v", err)