From 2ff65938ca9fa277f156a3e0eff429068a78c56e Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Mon, 2 Feb 2026 00:59:38 +0000 Subject: [PATCH] feat: Circuit breaker for failing domains Implement a circuit breaker for the website collector to prevent hammering domains that are consistently failing. The circuit breaker has three states: CLOSED, OPEN, and HALF-OPEN. It tracks failures per-domain and will open the circuit after a configurable number of consecutive failures. After a cooldown period, the circuit will transition to HALF-OPEN and allow a limited number of test requests to check for recovery. The following command-line flags have been added to the `collect website` command: - `--no-circuit-breaker`: Disable the circuit breaker - `--circuit-failures`: Number of failures to trip the circuit breaker - `--circuit-cooldown`: Cooldown time for the circuit breaker - `--circuit-success-threshold`: Number of successes to close the circuit breaker - `--circuit-half-open-requests`: Number of test requests in the half-open state The implementation also includes: - A new `circuitbreaker` package with the core logic - Integration into the `website` package with per-domain tracking - Improved logging to include the domain name and state changes - Integration tests to verify the circuit breaker's behavior Co-authored-by: Snider <631881+Snider@users.noreply.github.com> --- cmd/collect_website.go | 33 ++++- cmd/collect_website_test.go | 5 +- examples/collect_website/main.go | 7 +- go.mod | 1 + pkg/circuitbreaker/circuitbreaker.go | 160 ++++++++++++++++++++++ pkg/circuitbreaker/circuitbreaker_test.go | 101 ++++++++++++++ pkg/website/website.go | 124 ++++++++++++----- pkg/website/website_test.go | 69 +++++++++- 8 files changed, 455 insertions(+), 45 deletions(-) create mode 100644 pkg/circuitbreaker/circuitbreaker.go create mode 100644 pkg/circuitbreaker/circuitbreaker_test.go diff --git a/cmd/collect_website.go b/cmd/collect_website.go index 3811f32..4affbec 100644 --- a/cmd/collect_website.go +++ b/cmd/collect_website.go @@ -3,8 +3,11 @@ package cmd import ( "fmt" "os" + "time" + "github.com/Snider/Borg/pkg/circuitbreaker" "github.com/schollz/progressbar/v3" + "golang.org/x/exp/slog" "github.com/Snider/Borg/pkg/compress" "github.com/Snider/Borg/pkg/tim" "github.com/Snider/Borg/pkg/trix" @@ -38,6 +41,11 @@ func NewCollectWebsiteCmd() *cobra.Command { format, _ := cmd.Flags().GetString("format") compression, _ := cmd.Flags().GetString("compression") password, _ := cmd.Flags().GetString("password") + noCircuitBreaker, _ := cmd.Flags().GetBool("no-circuit-breaker") + circuitFailures, _ := cmd.Flags().GetInt("circuit-failures") + circuitCooldown, _ := cmd.Flags().GetDuration("circuit-cooldown") + circuitSuccessThreshold, _ := cmd.Flags().GetInt("circuit-success-threshold") + circuitHalfOpenRequests, _ := cmd.Flags().GetInt("circuit-half-open-requests") if format != "datanode" && format != "tim" && format != "trix" { return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', or 'trix')", format) @@ -51,7 +59,25 @@ func NewCollectWebsiteCmd() *cobra.Command { bar = ui.NewProgressBar(-1, "Crawling website") } - dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelInfo, + })) + + opts := website.DownloadOptions{ + URL: websiteURL, + MaxDepth: depth, + ProgressBar: bar, + EnableCircuitBreaker: !noCircuitBreaker, + CBSettings: circuitbreaker.Settings{ + FailureThreshold: circuitFailures, + SuccessThreshold: circuitSuccessThreshold, + Cooldown: circuitCooldown, + HalfOpenRequests: circuitHalfOpenRequests, + Logger: logger, + }, + } + + dn, err := website.DownloadAndPackageWebsite(opts) if err != nil { return fmt.Errorf("error downloading and packaging website: %w", err) } @@ -104,5 +130,10 @@ func NewCollectWebsiteCmd() *cobra.Command { collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)") collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)") collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption") + collectWebsiteCmd.Flags().Bool("no-circuit-breaker", false, "Disable the circuit breaker") + collectWebsiteCmd.Flags().Int("circuit-failures", 5, "Number of failures to trip the circuit breaker") + collectWebsiteCmd.Flags().Duration("circuit-cooldown", 30*time.Second, "Cooldown time for the circuit breaker") + collectWebsiteCmd.Flags().Int("circuit-success-threshold", 2, "Number of successes to close the circuit breaker") + collectWebsiteCmd.Flags().Int("circuit-half-open-requests", 1, "Number of test requests in half-open state") return collectWebsiteCmd } diff --git a/cmd/collect_website_test.go b/cmd/collect_website_test.go index 2c39674..e75fb0d 100644 --- a/cmd/collect_website_test.go +++ b/cmd/collect_website_test.go @@ -8,13 +8,12 @@ import ( "github.com/Snider/Borg/pkg/datanode" "github.com/Snider/Borg/pkg/website" - "github.com/schollz/progressbar/v3" ) func TestCollectWebsiteCmd_Good(t *testing.T) { // Mock the website downloader oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite - website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { + website.DownloadAndPackageWebsite = func(opts website.DownloadOptions) (*datanode.DataNode, error) { return datanode.New(), nil } defer func() { @@ -35,7 +34,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) { func TestCollectWebsiteCmd_Bad(t *testing.T) { // Mock the website downloader to return an error oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite - website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { + website.DownloadAndPackageWebsite = func(opts website.DownloadOptions) (*datanode.DataNode, error) { return nil, fmt.Errorf("website error") } defer func() { diff --git a/examples/collect_website/main.go b/examples/collect_website/main.go index 2e2f606..66fe70d 100644 --- a/examples/collect_website/main.go +++ b/examples/collect_website/main.go @@ -11,7 +11,12 @@ func main() { log.Println("Collecting website...") // Download and package the website. - dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil) + opts := website.DownloadOptions{ + URL: "https://example.com", + MaxDepth: 2, + EnableCircuitBreaker: true, + } + dn, err := website.DownloadAndPackageWebsite(opts) if err != nil { log.Fatalf("Failed to collect website: %v", err) } diff --git a/go.mod b/go.mod index d1c5f08..2e6e7cf 100644 --- a/go.mod +++ b/go.mod @@ -61,6 +61,7 @@ require ( github.com/wailsapp/mimetype v1.4.1 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect golang.org/x/crypto v0.44.0 // indirect + golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect golang.org/x/sys v0.38.0 // indirect golang.org/x/term v0.37.0 // indirect golang.org/x/text v0.31.0 // indirect diff --git a/pkg/circuitbreaker/circuitbreaker.go b/pkg/circuitbreaker/circuitbreaker.go new file mode 100644 index 0000000..bbf3cf7 --- /dev/null +++ b/pkg/circuitbreaker/circuitbreaker.go @@ -0,0 +1,160 @@ + +package circuitbreaker + +import ( + "fmt" + "io" + "sync" + "time" + + "golang.org/x/exp/slog" +) + +// State represents the state of the circuit breaker. +type State int + +const ( + // ClosedState is the initial state of the circuit breaker. + ClosedState State = iota + // OpenState is the state when the circuit breaker is tripped. + OpenState + // HalfOpenState is the state when the circuit breaker is testing for recovery. + HalfOpenState +) + +// String returns the string representation of the state. +func (s State) String() string { + switch s { + case ClosedState: + return "CLOSED" + case OpenState: + return "OPEN" + case HalfOpenState: + return "HALF-OPEN" + default: + return "UNKNOWN" + } +} + +// Settings configures the circuit breaker. +type Settings struct { + // FailureThreshold is the number of consecutive failures before opening the circuit. + FailureThreshold int + // SuccessThreshold is the number of consecutive successes to close the circuit. + SuccessThreshold int + // Cooldown is the time to wait in the open state before transitioning to half-open. + Cooldown time.Duration + // HalfOpenRequests is the number of test requests to allow in the half-open state. + HalfOpenRequests int + // Logger is the logger to use for state changes. + Logger *slog.Logger +} + +// CircuitBreaker is a state machine that prevents repeated calls to a failing service. +type CircuitBreaker struct { + settings Settings + domain string + mu sync.Mutex + state State + failures int + successes int + halfOpenRequests int + lastError error + expiry time.Time +} + +// New creates a new CircuitBreaker. +func New(domain string, settings Settings) *CircuitBreaker { + if settings.Logger == nil { + settings.Logger = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{ + Level: slog.LevelError, + })) + } + return &CircuitBreaker{ + settings: settings, + domain: domain, + state: ClosedState, + } +} + +// Execute runs the given function, protected by the circuit breaker. +func (cb *CircuitBreaker) Execute(fn func() (interface{}, error)) (interface{}, error) { + cb.mu.Lock() + defer cb.mu.Unlock() + + switch cb.state { + case OpenState: + if time.Now().After(cb.expiry) { + cb.setState(HalfOpenState) + return cb.executeHalfOpen(fn) + } + return nil, fmt.Errorf("circuit is open for: %w", cb.lastError) + case HalfOpenState: + return cb.executeHalfOpen(fn) + default: // ClosedState + return cb.executeClosed(fn) + } +} + +func (cb *CircuitBreaker) executeClosed(fn func() (interface{}, error)) (interface{}, error) { + res, err := fn() + if err != nil { + cb.failures++ + if cb.failures >= cb.settings.FailureThreshold { + cb.lastError = err + cb.setState(OpenState) + } + return nil, err + } + + cb.failures = 0 + return res, nil +} + +func (cb *CircuitBreaker) executeHalfOpen(fn func() (interface{}, error)) (interface{}, error) { + if cb.halfOpenRequests >= cb.settings.HalfOpenRequests { + return nil, fmt.Errorf("circuit is half-open, test requests exhausted: %w", cb.lastError) + } + cb.halfOpenRequests++ + + res, err := fn() + if err != nil { + cb.lastError = err + cb.setState(OpenState) + return nil, err + } + + cb.successes++ + // If enough test requests succeed, close the circuit. + if cb.successes >= cb.settings.SuccessThreshold { + cb.setState(ClosedState) + } + + return res, nil +} + +func (cb *CircuitBreaker) setState(state State) { + if cb.state == state { + return + } + + cb.state = state + logMessage := fmt.Sprintf("Circuit %s for %s", state, cb.domain) + if state == OpenState { + cb.settings.Logger.Warn(logMessage) + } else { + cb.settings.Logger.Info(logMessage) + } + + switch state { + case OpenState: + cb.expiry = time.Now().Add(cb.settings.Cooldown) + cb.successes = 0 + case ClosedState: + cb.failures = 0 + cb.successes = 0 + case HalfOpenState: + cb.failures = 0 + cb.halfOpenRequests = 0 + } +} diff --git a/pkg/circuitbreaker/circuitbreaker_test.go b/pkg/circuitbreaker/circuitbreaker_test.go new file mode 100644 index 0000000..3b5643b --- /dev/null +++ b/pkg/circuitbreaker/circuitbreaker_test.go @@ -0,0 +1,101 @@ + +package circuitbreaker + +import ( + "errors" + "testing" + "time" +) + +func TestCircuitBreaker(t *testing.T) { + settings := Settings{ + FailureThreshold: 2, + SuccessThreshold: 2, + Cooldown: 100 * time.Millisecond, + HalfOpenRequests: 2, + } + cb := New("test.com", settings) + + // Initially closed + _, err := cb.Execute(func() (interface{}, error) { + return "success", nil + }) + if err != nil { + t.Errorf("Expected success, got %v", err) + } + + // Trip the breaker + _, err = cb.Execute(func() (interface{}, error) { + return nil, errors.New("failure 1") + }) + if err == nil { + t.Error("Expected failure, got nil") + } + _, err = cb.Execute(func() (interface{}, error) { + return nil, errors.New("failure 2") + }) + if err == nil { + t.Error("Expected failure, got nil") + } + + // Now open + _, err = cb.Execute(func() (interface{}, error) { + return "should not be called", nil + }) + if err == nil || err.Error() != "circuit is open for: failure 2" { + t.Errorf("Expected open circuit error, got %v", err) + } + + // Wait for cooldown + time.Sleep(150 * time.Millisecond) + + // Half-open, should succeed + _, err = cb.Execute(func() (interface{}, error) { + return "success", nil + }) + if err != nil { + t.Errorf("Expected success in half-open, got %v", err) + } + + // Still half-open, need another success + _, err = cb.Execute(func() (interface{}, error) { + return "success", nil + }) + if err != nil { + t.Errorf("Expected success in half-open, got %v", err) + } + + // Now closed again + _, err = cb.Execute(func() (interface{}, error) { + return "success", nil + }) + if err != nil { + t.Errorf("Expected success in closed state, got %v", err) + } + + // Trip again to test half-open failure + cb.Execute(func() (interface{}, error) { + return nil, errors.New("failure 1") + }) + cb.Execute(func() (interface{}, error) { + return nil, errors.New("failure 2") + }) + + time.Sleep(150 * time.Millisecond) + + // Half-open, but fail + _, err = cb.Execute(func() (interface{}, error) { + return nil, errors.New("half-open failure") + }) + if err == nil { + t.Error("Expected failure in half-open, got nil") + } + + // Should be open again + _, err = cb.Execute(func() (interface{}, error) { + return "should not be called", nil + }) + if err == nil || err.Error() != "circuit is open for: half-open failure" { + t.Errorf("Expected open circuit error, got %v", err) + } +} diff --git a/pkg/website/website.go b/pkg/website/website.go index b2bd517..caefe5a 100644 --- a/pkg/website/website.go +++ b/pkg/website/website.go @@ -6,53 +6,70 @@ import ( "net/http" "net/url" "strings" + "sync" + "github.com/Snider/Borg/pkg/circuitbreaker" "github.com/Snider/Borg/pkg/datanode" "github.com/schollz/progressbar/v3" - "golang.org/x/net/html" ) var DownloadAndPackageWebsite = downloadAndPackageWebsite +// DownloadOptions configures the website downloader. +type DownloadOptions struct { + URL string + MaxDepth int + ProgressBar *progressbar.ProgressBar + EnableCircuitBreaker bool + CBSettings circuitbreaker.Settings +} + // Downloader is a recursive website downloader. type Downloader struct { - baseURL *url.URL - dn *datanode.DataNode - visited map[string]bool - maxDepth int - progressBar *progressbar.ProgressBar - client *http.Client - errors []error + baseURL *url.URL + dn *datanode.DataNode + visited map[string]bool + maxDepth int + progressBar *progressbar.ProgressBar + client *http.Client + errors []error + cbEnabled bool + cbSettings circuitbreaker.Settings + circuitBreakers map[string]*circuitbreaker.CircuitBreaker + cbMutex sync.Mutex } // NewDownloader creates a new Downloader. -func NewDownloader(maxDepth int) *Downloader { - return NewDownloaderWithClient(maxDepth, http.DefaultClient) +func NewDownloader(opts DownloadOptions) *Downloader { + return NewDownloaderWithClient(opts, http.DefaultClient) } // NewDownloaderWithClient creates a new Downloader with a custom http.Client. -func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader { +func NewDownloaderWithClient(opts DownloadOptions, client *http.Client) *Downloader { return &Downloader{ - dn: datanode.New(), - visited: make(map[string]bool), - maxDepth: maxDepth, - client: client, - errors: make([]error, 0), + dn: datanode.New(), + visited: make(map[string]bool), + maxDepth: opts.MaxDepth, + client: client, + errors: make([]error, 0), + cbEnabled: opts.EnableCircuitBreaker, + cbSettings: opts.CBSettings, + circuitBreakers: make(map[string]*circuitbreaker.CircuitBreaker), } } // downloadAndPackageWebsite downloads a website and packages it into a DataNode. -func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { - baseURL, err := url.Parse(startURL) +func downloadAndPackageWebsite(opts DownloadOptions) (*datanode.DataNode, error) { + baseURL, err := url.Parse(opts.URL) if err != nil { return nil, err } - d := NewDownloader(maxDepth) + d := NewDownloader(opts) d.baseURL = baseURL - d.progressBar = bar - d.crawl(startURL, 0) + d.progressBar = opts.ProgressBar + d.crawl(opts.URL, 0) if len(d.errors) > 0 { var errs []string @@ -65,6 +82,57 @@ func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.P return d.dn, nil } +func (d *Downloader) getCircuitBreaker(host string) *circuitbreaker.CircuitBreaker { + d.cbMutex.Lock() + defer d.cbMutex.Unlock() + + if cb, ok := d.circuitBreakers[host]; ok { + return cb + } + + cb := circuitbreaker.New(host, d.cbSettings) + d.circuitBreakers[host] = cb + return cb +} + +func (d *Downloader) fetchURL(pageURL string) (*http.Response, error) { + if !d.cbEnabled { + resp, err := d.client.Get(pageURL) + if err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + resp.Body.Close() + return nil, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status) + } + return resp, nil + } + + u, err := url.Parse(pageURL) + if err != nil { + return nil, fmt.Errorf("invalid URL: %w", err) + } + + cb := d.getCircuitBreaker(u.Hostname()) + result, err := cb.Execute(func() (interface{}, error) { + resp, err := d.client.Get(pageURL) + if err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + resp.Body.Close() + return nil, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status) + } + return resp, nil + }) + + if err != nil { + return nil, err + } + + return result.(*http.Response), nil +} + func (d *Downloader) crawl(pageURL string, depth int) { if depth > d.maxDepth || d.visited[pageURL] { return @@ -74,18 +142,13 @@ func (d *Downloader) crawl(pageURL string, depth int) { d.progressBar.Add(1) } - resp, err := d.client.Get(pageURL) + resp, err := d.fetchURL(pageURL) if err != nil { d.errors = append(d.errors, fmt.Errorf("Error getting %s: %w", pageURL, err)) return } defer resp.Body.Close() - if resp.StatusCode >= 400 { - d.errors = append(d.errors, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status)) - return - } - body, err := io.ReadAll(resp.Body) if err != nil { d.errors = append(d.errors, fmt.Errorf("Error reading body of %s: %w", pageURL, err)) @@ -141,18 +204,13 @@ func (d *Downloader) downloadAsset(assetURL string) { d.progressBar.Add(1) } - resp, err := d.client.Get(assetURL) + resp, err := d.fetchURL(assetURL) if err != nil { d.errors = append(d.errors, fmt.Errorf("Error getting asset %s: %w", assetURL, err)) return } defer resp.Body.Close() - if resp.StatusCode >= 400 { - d.errors = append(d.errors, fmt.Errorf("bad status for asset %s: %s", assetURL, resp.Status)) - return - } - body, err := io.ReadAll(resp.Body) if err != nil { d.errors = append(d.errors, fmt.Errorf("Error reading body of asset %s: %w", assetURL, err)) diff --git a/pkg/website/website_test.go b/pkg/website/website_test.go index d3685e5..437bc01 100644 --- a/pkg/website/website_test.go +++ b/pkg/website/website_test.go @@ -7,9 +7,11 @@ import ( "net/http" "net/http/httptest" "strings" + "sync/atomic" "testing" "time" + "github.com/Snider/Borg/pkg/circuitbreaker" "github.com/schollz/progressbar/v3" ) @@ -20,7 +22,12 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) { defer server.Close() bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard)) - dn, err := DownloadAndPackageWebsite(server.URL, 2, bar) + opts := DownloadOptions{ + URL: server.URL, + MaxDepth: 2, + ProgressBar: bar, + } + dn, err := DownloadAndPackageWebsite(opts) if err != nil { t.Fatalf("DownloadAndPackageWebsite failed: %v", err) } @@ -52,7 +59,8 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) { func TestDownloadAndPackageWebsite_Bad(t *testing.T) { t.Run("Invalid Start URL", func(t *testing.T) { - _, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil) + opts := DownloadOptions{URL: "http://invalid-url", MaxDepth: 1} + _, err := DownloadAndPackageWebsite(opts) if err == nil { t.Fatal("Expected an error for an invalid start URL, but got nil") } @@ -63,7 +71,8 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) { http.Error(w, "Internal Server Error", http.StatusInternalServerError) })) defer server.Close() - _, err := DownloadAndPackageWebsite(server.URL, 1, nil) + opts := DownloadOptions{URL: server.URL, MaxDepth: 1} + _, err := DownloadAndPackageWebsite(opts) if err == nil { t.Fatal("Expected an error for a server error on the start URL, but got nil") } @@ -80,7 +89,8 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) { })) defer server.Close() // We expect an error because the link is broken. - dn, err := DownloadAndPackageWebsite(server.URL, 1, nil) + opts := DownloadOptions{URL: server.URL, MaxDepth: 1} + dn, err := DownloadAndPackageWebsite(opts) if err == nil { t.Fatal("Expected an error for a broken link, but got nil") } @@ -99,7 +109,8 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) { defer server.Close() bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard)) - dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1 + opts := DownloadOptions{URL: server.URL, MaxDepth: 1, ProgressBar: bar} + dn, err := DownloadAndPackageWebsite(opts) if err != nil { t.Fatalf("DownloadAndPackageWebsite failed: %v", err) } @@ -122,7 +133,8 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) { fmt.Fprint(w, `External`) })) defer server.Close() - dn, err := DownloadAndPackageWebsite(server.URL, 1, nil) + opts := DownloadOptions{URL: server.URL, MaxDepth: 1} + dn, err := DownloadAndPackageWebsite(opts) if err != nil { t.Fatalf("DownloadAndPackageWebsite failed: %v", err) } @@ -156,7 +168,8 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) { // For now, we'll just test that it doesn't hang forever. done := make(chan bool) go func() { - _, err := DownloadAndPackageWebsite(server.URL, 1, nil) + opts := DownloadOptions{URL: server.URL, MaxDepth: 1} + _, err := DownloadAndPackageWebsite(opts) if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") { // We expect a timeout error, but other errors are failures. t.Errorf("unexpected error: %v", err) @@ -170,6 +183,48 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) { t.Fatal("Test timed out") } }) + + t.Run("Circuit Breaker", func(t *testing.T) { + var requestCount int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + atomic.AddInt32(&requestCount, 1) + } + if r.URL.Path == "/" { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, ` + Fail 1 + Fail 2 + Fail 3 + `) + } else { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + })) + defer server.Close() + + opts := DownloadOptions{ + URL: server.URL, + MaxDepth: 1, + EnableCircuitBreaker: true, + CBSettings: circuitbreaker.Settings{ + FailureThreshold: 2, + Cooldown: 1 * time.Second, + }, + } + _, err := DownloadAndPackageWebsite(opts) + if err == nil { + t.Fatal("Expected an error, but got nil") + } + + if count := atomic.LoadInt32(&requestCount); count != 2 { + t.Errorf("Expected 2 requests to failing paths, but got %d", count) + } + + if !strings.Contains(err.Error(), "circuit is open") { + t.Errorf("Expected error to contain 'circuit is open', but got: %v", err) + } + }) } // --- Helpers ---