diff --git a/cmd/collect_website.go b/cmd/collect_website.go
index 3811f32..4affbec 100644
--- a/cmd/collect_website.go
+++ b/cmd/collect_website.go
@@ -3,8 +3,11 @@ package cmd
import (
"fmt"
"os"
+ "time"
+ "github.com/Snider/Borg/pkg/circuitbreaker"
"github.com/schollz/progressbar/v3"
+ "golang.org/x/exp/slog"
"github.com/Snider/Borg/pkg/compress"
"github.com/Snider/Borg/pkg/tim"
"github.com/Snider/Borg/pkg/trix"
@@ -38,6 +41,11 @@ func NewCollectWebsiteCmd() *cobra.Command {
format, _ := cmd.Flags().GetString("format")
compression, _ := cmd.Flags().GetString("compression")
password, _ := cmd.Flags().GetString("password")
+ noCircuitBreaker, _ := cmd.Flags().GetBool("no-circuit-breaker")
+ circuitFailures, _ := cmd.Flags().GetInt("circuit-failures")
+ circuitCooldown, _ := cmd.Flags().GetDuration("circuit-cooldown")
+ circuitSuccessThreshold, _ := cmd.Flags().GetInt("circuit-success-threshold")
+ circuitHalfOpenRequests, _ := cmd.Flags().GetInt("circuit-half-open-requests")
if format != "datanode" && format != "tim" && format != "trix" {
return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', or 'trix')", format)
@@ -51,7 +59,25 @@ func NewCollectWebsiteCmd() *cobra.Command {
bar = ui.NewProgressBar(-1, "Crawling website")
}
- dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
+ Level: slog.LevelInfo,
+ }))
+
+ opts := website.DownloadOptions{
+ URL: websiteURL,
+ MaxDepth: depth,
+ ProgressBar: bar,
+ EnableCircuitBreaker: !noCircuitBreaker,
+ CBSettings: circuitbreaker.Settings{
+ FailureThreshold: circuitFailures,
+ SuccessThreshold: circuitSuccessThreshold,
+ Cooldown: circuitCooldown,
+ HalfOpenRequests: circuitHalfOpenRequests,
+ Logger: logger,
+ },
+ }
+
+ dn, err := website.DownloadAndPackageWebsite(opts)
if err != nil {
return fmt.Errorf("error downloading and packaging website: %w", err)
}
@@ -104,5 +130,10 @@ func NewCollectWebsiteCmd() *cobra.Command {
collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
+ collectWebsiteCmd.Flags().Bool("no-circuit-breaker", false, "Disable the circuit breaker")
+ collectWebsiteCmd.Flags().Int("circuit-failures", 5, "Number of failures to trip the circuit breaker")
+ collectWebsiteCmd.Flags().Duration("circuit-cooldown", 30*time.Second, "Cooldown time for the circuit breaker")
+ collectWebsiteCmd.Flags().Int("circuit-success-threshold", 2, "Number of successes to close the circuit breaker")
+ collectWebsiteCmd.Flags().Int("circuit-half-open-requests", 1, "Number of test requests in half-open state")
return collectWebsiteCmd
}
diff --git a/cmd/collect_website_test.go b/cmd/collect_website_test.go
index 2c39674..e75fb0d 100644
--- a/cmd/collect_website_test.go
+++ b/cmd/collect_website_test.go
@@ -8,13 +8,12 @@ import (
"github.com/Snider/Borg/pkg/datanode"
"github.com/Snider/Borg/pkg/website"
- "github.com/schollz/progressbar/v3"
)
func TestCollectWebsiteCmd_Good(t *testing.T) {
// Mock the website downloader
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
- website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+ website.DownloadAndPackageWebsite = func(opts website.DownloadOptions) (*datanode.DataNode, error) {
return datanode.New(), nil
}
defer func() {
@@ -35,7 +34,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
func TestCollectWebsiteCmd_Bad(t *testing.T) {
// Mock the website downloader to return an error
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
- website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+ website.DownloadAndPackageWebsite = func(opts website.DownloadOptions) (*datanode.DataNode, error) {
return nil, fmt.Errorf("website error")
}
defer func() {
diff --git a/examples/collect_website/main.go b/examples/collect_website/main.go
index 2e2f606..66fe70d 100644
--- a/examples/collect_website/main.go
+++ b/examples/collect_website/main.go
@@ -11,7 +11,12 @@ func main() {
log.Println("Collecting website...")
// Download and package the website.
- dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
+ opts := website.DownloadOptions{
+ URL: "https://example.com",
+ MaxDepth: 2,
+ EnableCircuitBreaker: true,
+ }
+ dn, err := website.DownloadAndPackageWebsite(opts)
if err != nil {
log.Fatalf("Failed to collect website: %v", err)
}
diff --git a/go.mod b/go.mod
index d1c5f08..2e6e7cf 100644
--- a/go.mod
+++ b/go.mod
@@ -61,6 +61,7 @@ require (
github.com/wailsapp/mimetype v1.4.1 // indirect
github.com/xanzy/ssh-agent v0.3.3 // indirect
golang.org/x/crypto v0.44.0 // indirect
+ golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/term v0.37.0 // indirect
golang.org/x/text v0.31.0 // indirect
diff --git a/pkg/circuitbreaker/circuitbreaker.go b/pkg/circuitbreaker/circuitbreaker.go
new file mode 100644
index 0000000..bbf3cf7
--- /dev/null
+++ b/pkg/circuitbreaker/circuitbreaker.go
@@ -0,0 +1,160 @@
+
+package circuitbreaker
+
+import (
+ "fmt"
+ "io"
+ "sync"
+ "time"
+
+ "golang.org/x/exp/slog"
+)
+
+// State is one of the three positions of the circuit breaker state machine.
+// The zero value is ClosedState.
+type State int
+
+const (
+ // ClosedState is the initial state of a breaker returned by New: Execute
+ // runs every call and counts consecutive failures.
+ ClosedState State = iota
+ // OpenState is the tripped state: Execute rejects calls without running
+ // them until the cooldown deadline has passed.
+ OpenState
+ // HalfOpenState is the recovery-testing state entered once the cooldown
+ // has passed: a limited number of trial calls are run to decide whether
+ // the circuit closes or opens again.
+ HalfOpenState
+)
+
+// String returns the upper-case label for the state ("CLOSED", "OPEN" or
+// "HALF-OPEN"), or "UNKNOWN" for any value outside the declared constants.
+func (s State) String() string {
+	labels := [...]string{
+		ClosedState:   "CLOSED",
+		OpenState:     "OPEN",
+		HalfOpenState: "HALF-OPEN",
+	}
+	if s < 0 || int(s) >= len(labels) {
+		return "UNKNOWN"
+	}
+	return labels[s]
+}
+
+// Settings configures the thresholds, timing and logging of a CircuitBreaker.
+type Settings struct {
+ // FailureThreshold is the number of consecutive failures, counted while the
+ // circuit is closed, at which the circuit opens.
+ FailureThreshold int
+ // SuccessThreshold is the number of successful trial calls in the half-open
+ // state at which the circuit closes again.
+ SuccessThreshold int
+ // Cooldown is how long the circuit stays open, measured from the moment it
+ // opens, before the next call is let through as a half-open trial.
+ Cooldown time.Duration
+ // HalfOpenRequests is the maximum number of trial calls run during one
+ // half-open period.
+ HalfOpenRequests int
+ // Logger receives state-change messages: Warn when the circuit opens, Info
+ // for every other transition. When nil, New installs a logger that
+ // discards all output.
+ Logger *slog.Logger
+}
+
+// CircuitBreaker is a state machine that prevents repeated calls to a failing service.
+// Its mutable fields are accessed with mu held; Execute acquires the lock.
+// Because it contains a sync.Mutex, a CircuitBreaker must not be copied.
+type CircuitBreaker struct {
+ settings Settings
+ // domain is the label included in state-change log messages.
+ domain string
+ // mu is held by Execute for the full duration of each call.
+ mu sync.Mutex
+ state State
+ // failures counts consecutive failed calls in the closed state.
+ failures int
+ // successes counts successful trial calls in the current half-open period.
+ successes int
+ // halfOpenRequests counts trial calls started in the current half-open period.
+ halfOpenRequests int
+ // lastError is the error that most recently opened the circuit; it is
+ // wrapped into the error returned for rejected calls.
+ lastError error
+ // expiry is the time after which an open circuit may become half-open.
+ expiry time.Time
+}
+
+// Defaults substituted by New for non-positive Settings fields. The numeric
+// values mirror the defaults of the collect-website command-line flags.
+const (
+	defaultFailureThreshold = 5
+	defaultSuccessThreshold = 2
+	defaultCooldown         = 30 * time.Second
+)
+
+// New creates a CircuitBreaker labelled with domain, starting in ClosedState.
+//
+// Non-positive settings are replaced with defaults, because a zero-value
+// Settings would otherwise produce a breaker that opens on the first failure
+// and then rejects every later call:
+//   - FailureThreshold <= 0 becomes 5
+//   - SuccessThreshold <= 0 becomes 2
+//   - Cooldown <= 0 becomes 30s
+//   - HalfOpenRequests <= 0 becomes SuccessThreshold, so one half-open period
+//     allows enough trial calls to close the circuit
+//
+// A nil Logger is replaced with one that discards all output.
+func New(domain string, settings Settings) *CircuitBreaker {
+	if settings.FailureThreshold <= 0 {
+		settings.FailureThreshold = defaultFailureThreshold
+	}
+	if settings.SuccessThreshold <= 0 {
+		settings.SuccessThreshold = defaultSuccessThreshold
+	}
+	if settings.Cooldown <= 0 {
+		settings.Cooldown = defaultCooldown
+	}
+	if settings.HalfOpenRequests <= 0 {
+		settings.HalfOpenRequests = settings.SuccessThreshold
+	}
+	if settings.Logger == nil {
+		settings.Logger = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{
+			Level: slog.LevelError,
+		}))
+	}
+	return &CircuitBreaker{
+		settings: settings,
+		domain:   domain,
+		state:    ClosedState,
+	}
+}
+
+// Execute runs fn under the protection of the circuit breaker. It returns
+// fn's result, or an error wrapping the last recorded failure when the call
+// is rejected without running fn.
+//
+// The breaker's mutex is held for the whole call, including while fn runs, so
+// calls through one CircuitBreaker are serialized.
+func (cb *CircuitBreaker) Execute(fn func() (interface{}, error)) (interface{}, error) {
+	cb.mu.Lock()
+	defer cb.mu.Unlock()
+
+	if cb.state == OpenState {
+		// Still cooling down: reject without invoking fn.
+		if !time.Now().After(cb.expiry) {
+			return nil, fmt.Errorf("circuit is open for: %w", cb.lastError)
+		}
+		// Cooldown is over: this call becomes the first half-open trial.
+		cb.setState(HalfOpenState)
+	}
+
+	if cb.state == HalfOpenState {
+		return cb.executeHalfOpen(fn)
+	}
+	return cb.executeClosed(fn)
+}
+
+// executeClosed runs fn while the circuit is closed. A success clears the
+// consecutive-failure count; a failure increments it and, once the count
+// reaches FailureThreshold, records the error and opens the circuit.
+// The caller must hold cb.mu.
+func (cb *CircuitBreaker) executeClosed(fn func() (interface{}, error)) (interface{}, error) {
+	result, callErr := fn()
+	if callErr == nil {
+		cb.failures = 0
+		return result, nil
+	}
+
+	cb.failures++
+	if tripped := cb.failures >= cb.settings.FailureThreshold; tripped {
+		cb.lastError = callErr
+		cb.setState(OpenState)
+	}
+	return nil, callErr
+}
+
+// executeHalfOpen runs fn as a trial call while the circuit is half-open.
+// The caller must hold cb.mu.
+//
+// A failing trial records the error and re-opens the circuit immediately.
+// The circuit closes once SuccessThreshold trials have succeeded, or once the
+// whole trial budget (HalfOpenRequests, treated as at least 1) has been spent
+// without a failure. The second condition matters when HalfOpenRequests is
+// smaller than SuccessThreshold: without it the breaker would use up its
+// budget, remain half-open, and reject every later call forever, because no
+// other path leaves the half-open state.
+func (cb *CircuitBreaker) executeHalfOpen(fn func() (interface{}, error)) (interface{}, error) {
+	budget := cb.settings.HalfOpenRequests
+	if budget < 1 {
+		// A non-positive budget would never allow a trial call at all.
+		budget = 1
+	}
+	cb.halfOpenRequests++
+
+	res, err := fn()
+	if err != nil {
+		cb.lastError = err
+		cb.setState(OpenState)
+		return nil, err
+	}
+
+	cb.successes++
+	if cb.successes >= cb.settings.SuccessThreshold || cb.halfOpenRequests >= budget {
+		cb.setState(ClosedState)
+	}
+
+	return res, nil
+}
+
+// setState moves the breaker to next, logs the transition and resets the
+// counters that belong to the new state. It does nothing when the breaker is
+// already in that state. The caller must hold cb.mu.
+func (cb *CircuitBreaker) setState(next State) {
+	if next == cb.state {
+		return
+	}
+	cb.state = next
+
+	msg := fmt.Sprintf("Circuit %s for %s", next, cb.domain)
+	switch next {
+	case OpenState:
+		cb.settings.Logger.Warn(msg)
+		// Start the cooldown and discard any partial half-open progress.
+		cb.expiry = time.Now().Add(cb.settings.Cooldown)
+		cb.successes = 0
+	case ClosedState:
+		cb.settings.Logger.Info(msg)
+		cb.failures, cb.successes = 0, 0
+	case HalfOpenState:
+		cb.settings.Logger.Info(msg)
+		cb.failures, cb.halfOpenRequests = 0, 0
+	default:
+		cb.settings.Logger.Info(msg)
+	}
+}
diff --git a/pkg/circuitbreaker/circuitbreaker_test.go b/pkg/circuitbreaker/circuitbreaker_test.go
new file mode 100644
index 0000000..3b5643b
--- /dev/null
+++ b/pkg/circuitbreaker/circuitbreaker_test.go
@@ -0,0 +1,101 @@
+
+package circuitbreaker
+
+import (
+ "errors"
+ "testing"
+ "time"
+)
+
+// TestCircuitBreaker walks one breaker through closed -> open -> half-open ->
+// closed, then trips it again and checks that a failing half-open trial
+// re-opens the circuit.
+func TestCircuitBreaker(t *testing.T) {
+	cb := New("test.com", Settings{
+		FailureThreshold: 2,
+		SuccessThreshold: 2,
+		Cooldown:         100 * time.Millisecond,
+		HalfOpenRequests: 2,
+	})
+
+	// succeed runs a call that always succeeds and returns the breaker's error.
+	succeed := func() error {
+		_, err := cb.Execute(func() (interface{}, error) {
+			return "success", nil
+		})
+		return err
+	}
+	// failWith runs a call that fails with msg and returns the breaker's error.
+	failWith := func(msg string) error {
+		_, err := cb.Execute(func() (interface{}, error) {
+			return nil, errors.New(msg)
+		})
+		return err
+	}
+	// expectOpen checks that the breaker rejects the call with exactly want.
+	expectOpen := func(want string) {
+		_, err := cb.Execute(func() (interface{}, error) {
+			return "should not be called", nil
+		})
+		if err == nil || err.Error() != want {
+			t.Errorf("Expected open circuit error, got %v", err)
+		}
+	}
+	// waitCooldown sleeps past the configured cooldown.
+	waitCooldown := func() {
+		time.Sleep(150 * time.Millisecond)
+	}
+
+	// Initially closed
+	if err := succeed(); err != nil {
+		t.Errorf("Expected success, got %v", err)
+	}
+
+	// Trip the breaker
+	for _, msg := range []string{"failure 1", "failure 2"} {
+		if err := failWith(msg); err == nil {
+			t.Error("Expected failure, got nil")
+		}
+	}
+
+	// Now open
+	expectOpen("circuit is open for: failure 2")
+
+	waitCooldown()
+
+	// Half-open: two successful trials are needed to close the circuit.
+	for i := 0; i < 2; i++ {
+		if err := succeed(); err != nil {
+			t.Errorf("Expected success in half-open, got %v", err)
+		}
+	}
+
+	// Now closed again
+	if err := succeed(); err != nil {
+		t.Errorf("Expected success in closed state, got %v", err)
+	}
+
+	// Trip again to test half-open failure; the errors themselves are not
+	// under test here.
+	_ = failWith("failure 1")
+	_ = failWith("failure 2")
+
+	waitCooldown()
+
+	// Half-open, but fail
+	if err := failWith("half-open failure"); err == nil {
+		t.Error("Expected failure in half-open, got nil")
+	}
+
+	// Should be open again
+	expectOpen("circuit is open for: half-open failure")
+}
diff --git a/pkg/website/website.go b/pkg/website/website.go
index b2bd517..caefe5a 100644
--- a/pkg/website/website.go
+++ b/pkg/website/website.go
@@ -6,53 +6,70 @@ import (
"net/http"
"net/url"
"strings"
+ "sync"
+ "github.com/Snider/Borg/pkg/circuitbreaker"
"github.com/Snider/Borg/pkg/datanode"
"github.com/schollz/progressbar/v3"
-
"golang.org/x/net/html"
)
var DownloadAndPackageWebsite = downloadAndPackageWebsite
+// DownloadOptions configures the website downloader.
+type DownloadOptions struct {
+ // URL is the page the crawl starts from; it is also parsed as the
+ // downloader's base URL.
+ URL string
+ // MaxDepth is the greatest crawl depth visited; pages deeper than this
+ // are skipped.
+ MaxDepth int
+ // ProgressBar is the optional progress indicator; callers in this change
+ // also leave it nil.
+ ProgressBar *progressbar.ProgressBar
+ // EnableCircuitBreaker routes every fetch through a circuit breaker keyed
+ // by the URL's hostname.
+ EnableCircuitBreaker bool
+ // CBSettings is the configuration passed to each host's circuit breaker.
+ // It is only consulted when EnableCircuitBreaker is true.
+ CBSettings circuitbreaker.Settings
+}
+
// Downloader is a recursive website downloader.
type Downloader struct {
- baseURL *url.URL
- dn *datanode.DataNode
- visited map[string]bool
- maxDepth int
- progressBar *progressbar.ProgressBar
- client *http.Client
- errors []error
+ baseURL *url.URL
+ dn *datanode.DataNode
+ visited map[string]bool
+ maxDepth int
+ progressBar *progressbar.ProgressBar
+ client *http.Client
+ errors []error
+ // cbEnabled routes every fetch through a per-host circuit breaker.
+ cbEnabled bool
+ // cbSettings is the configuration handed to each newly created breaker.
+ cbSettings circuitbreaker.Settings
+ // circuitBreakers holds one breaker per hostname, created on first use.
+ circuitBreakers map[string]*circuitbreaker.CircuitBreaker
+ // cbMutex guards circuitBreakers.
+ cbMutex sync.Mutex
}
// NewDownloader creates a new Downloader.
+// It delegates to NewDownloaderWithClient, passing http.DefaultClient as the
+// HTTP client.
-func NewDownloader(maxDepth int) *Downloader {
- return NewDownloaderWithClient(maxDepth, http.DefaultClient)
+func NewDownloader(opts DownloadOptions) *Downloader {
+ return NewDownloaderWithClient(opts, http.DefaultClient)
}
// NewDownloaderWithClient creates a new Downloader with a custom http.Client.
+// Only MaxDepth, EnableCircuitBreaker and CBSettings are read from opts here.
+// opts.URL and opts.ProgressBar are not applied by this constructor;
+// downloadAndPackageWebsite assigns baseURL and progressBar after construction.
-func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
+func NewDownloaderWithClient(opts DownloadOptions, client *http.Client) *Downloader {
return &Downloader{
- dn: datanode.New(),
- visited: make(map[string]bool),
- maxDepth: maxDepth,
- client: client,
- errors: make([]error, 0),
+ dn: datanode.New(),
+ visited: make(map[string]bool),
+ maxDepth: opts.MaxDepth,
+ client: client,
+ errors: make([]error, 0),
+ cbEnabled: opts.EnableCircuitBreaker,
+ cbSettings: opts.CBSettings,
+ circuitBreakers: make(map[string]*circuitbreaker.CircuitBreaker),
}
}
// downloadAndPackageWebsite downloads a website and packages it into a DataNode.
-func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
- baseURL, err := url.Parse(startURL)
+func downloadAndPackageWebsite(opts DownloadOptions) (*datanode.DataNode, error) {
+ baseURL, err := url.Parse(opts.URL)
if err != nil {
return nil, err
}
- d := NewDownloader(maxDepth)
+ d := NewDownloader(opts)
d.baseURL = baseURL
- d.progressBar = bar
- d.crawl(startURL, 0)
+ d.progressBar = opts.ProgressBar
+ d.crawl(opts.URL, 0)
if len(d.errors) > 0 {
var errs []string
@@ -65,6 +82,57 @@ func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.P
return d.dn, nil
}
+// getCircuitBreaker returns the circuit breaker for host, creating it from
+// the downloader's breaker settings and storing it the first time the host is
+// seen. Access to the map is serialized by cbMutex.
+func (d *Downloader) getCircuitBreaker(host string) *circuitbreaker.CircuitBreaker {
+	d.cbMutex.Lock()
+	defer d.cbMutex.Unlock()
+
+	breaker, known := d.circuitBreakers[host]
+	if !known {
+		breaker = circuitbreaker.New(host, d.cbSettings)
+		d.circuitBreakers[host] = breaker
+	}
+	return breaker
+}
+
+// fetchURL issues a GET for pageURL and returns the response only when its
+// status is below 400. For any other status the body is closed and a
+// "bad status" error is returned; on success the caller must close the body.
+//
+// With the circuit breaker enabled, the request goes through the breaker for
+// the URL's hostname (the port is not part of the key). Only transport
+// errors, 5xx responses and 429 Too Many Requests are reported to the breaker
+// as failures. Other 4xx responses, such as a 404 from a broken link, show
+// that the host is answering: they are returned to the caller as errors but
+// do not count towards tripping the circuit.
+func (d *Downloader) fetchURL(pageURL string) (*http.Response, error) {
+	// get performs the request. For a status >= 400 it closes the body and
+	// returns a nil response together with the status code and the error.
+	// A transport error is reported with status 0.
+	get := func() (*http.Response, int, error) {
+		resp, err := d.client.Get(pageURL)
+		if err != nil {
+			return nil, 0, err
+		}
+		if resp.StatusCode >= 400 {
+			resp.Body.Close()
+			return nil, resp.StatusCode, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status)
+		}
+		return resp, resp.StatusCode, nil
+	}
+
+	if !d.cbEnabled {
+		resp, _, err := get()
+		return resp, err
+	}
+
+	u, err := url.Parse(pageURL)
+	if err != nil {
+		return nil, fmt.Errorf("invalid URL: %w", err)
+	}
+
+	var (
+		resp      *http.Response
+		clientErr error // 4xx error that is kept away from the breaker
+	)
+	_, err = d.getCircuitBreaker(u.Hostname()).Execute(func() (interface{}, error) {
+		r, status, getErr := get()
+		switch {
+		case getErr == nil:
+			resp = r
+		case status >= 400 && status < 500 && status != http.StatusTooManyRequests:
+			// The host responded; the fault lies with this particular URL.
+			clientErr = getErr
+		default:
+			return nil, getErr
+		}
+		return nil, nil
+	})
+	if err != nil {
+		return nil, err
+	}
+	if clientErr != nil {
+		return nil, clientErr
+	}
+	return resp, nil
+}
+
func (d *Downloader) crawl(pageURL string, depth int) {
if depth > d.maxDepth || d.visited[pageURL] {
return
@@ -74,18 +142,13 @@ func (d *Downloader) crawl(pageURL string, depth int) {
d.progressBar.Add(1)
}
- resp, err := d.client.Get(pageURL)
+ resp, err := d.fetchURL(pageURL)
if err != nil {
d.errors = append(d.errors, fmt.Errorf("Error getting %s: %w", pageURL, err))
return
}
defer resp.Body.Close()
- if resp.StatusCode >= 400 {
- d.errors = append(d.errors, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status))
- return
- }
-
body, err := io.ReadAll(resp.Body)
if err != nil {
d.errors = append(d.errors, fmt.Errorf("Error reading body of %s: %w", pageURL, err))
@@ -141,18 +204,13 @@ func (d *Downloader) downloadAsset(assetURL string) {
d.progressBar.Add(1)
}
- resp, err := d.client.Get(assetURL)
+ resp, err := d.fetchURL(assetURL)
if err != nil {
d.errors = append(d.errors, fmt.Errorf("Error getting asset %s: %w", assetURL, err))
return
}
defer resp.Body.Close()
- if resp.StatusCode >= 400 {
- d.errors = append(d.errors, fmt.Errorf("bad status for asset %s: %s", assetURL, resp.Status))
- return
- }
-
body, err := io.ReadAll(resp.Body)
if err != nil {
d.errors = append(d.errors, fmt.Errorf("Error reading body of asset %s: %w", assetURL, err))
diff --git a/pkg/website/website_test.go b/pkg/website/website_test.go
index d3685e5..437bc01 100644
--- a/pkg/website/website_test.go
+++ b/pkg/website/website_test.go
@@ -7,9 +7,11 @@ import (
"net/http"
"net/http/httptest"
"strings"
+ "sync/atomic"
"testing"
"time"
+ "github.com/Snider/Borg/pkg/circuitbreaker"
"github.com/schollz/progressbar/v3"
)
@@ -20,7 +22,12 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
defer server.Close()
bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
- dn, err := DownloadAndPackageWebsite(server.URL, 2, bar)
+ opts := DownloadOptions{
+ URL: server.URL,
+ MaxDepth: 2,
+ ProgressBar: bar,
+ }
+ dn, err := DownloadAndPackageWebsite(opts)
if err != nil {
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
}
@@ -52,7 +59,8 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
t.Run("Invalid Start URL", func(t *testing.T) {
- _, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil)
+ opts := DownloadOptions{URL: "http://invalid-url", MaxDepth: 1}
+ _, err := DownloadAndPackageWebsite(opts)
if err == nil {
t.Fatal("Expected an error for an invalid start URL, but got nil")
}
@@ -63,7 +71,8 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}))
defer server.Close()
- _, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+ opts := DownloadOptions{URL: server.URL, MaxDepth: 1}
+ _, err := DownloadAndPackageWebsite(opts)
if err == nil {
t.Fatal("Expected an error for a server error on the start URL, but got nil")
}
@@ -80,7 +89,8 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
}))
defer server.Close()
// We expect an error because the link is broken.
- dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+ opts := DownloadOptions{URL: server.URL, MaxDepth: 1}
+ dn, err := DownloadAndPackageWebsite(opts)
if err == nil {
t.Fatal("Expected an error for a broken link, but got nil")
}
@@ -99,7 +109,8 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
defer server.Close()
bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
- dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1
+ opts := DownloadOptions{URL: server.URL, MaxDepth: 1, ProgressBar: bar}
+ dn, err := DownloadAndPackageWebsite(opts)
if err != nil {
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
}
@@ -122,7 +133,8 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
fmt.Fprint(w, `External`)
}))
defer server.Close()
- dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+ opts := DownloadOptions{URL: server.URL, MaxDepth: 1}
+ dn, err := DownloadAndPackageWebsite(opts)
if err != nil {
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
}
@@ -156,7 +168,8 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
// For now, we'll just test that it doesn't hang forever.
done := make(chan bool)
go func() {
- _, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+ opts := DownloadOptions{URL: server.URL, MaxDepth: 1}
+ _, err := DownloadAndPackageWebsite(opts)
if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
// We expect a timeout error, but other errors are failures.
t.Errorf("unexpected error: %v", err)
@@ -170,6 +183,48 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
t.Fatal("Test timed out")
}
})
+
+ t.Run("Circuit Breaker", func(t *testing.T) {
+ var requestCount int32
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/" {
+ atomic.AddInt32(&requestCount, 1)
+ }
+ if r.URL.Path == "/" {
+ w.Header().Set("Content-Type", "text/html")
+ fmt.Fprint(w, `
+ Fail 1
+ Fail 2
+ Fail 3
+ `)
+ } else {
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ }
+ }))
+ defer server.Close()
+
+ opts := DownloadOptions{
+ URL: server.URL,
+ MaxDepth: 1,
+ EnableCircuitBreaker: true,
+ CBSettings: circuitbreaker.Settings{
+ FailureThreshold: 2,
+ Cooldown: 1 * time.Second,
+ },
+ }
+ _, err := DownloadAndPackageWebsite(opts)
+ if err == nil {
+ t.Fatal("Expected an error, but got nil")
+ }
+
+ if count := atomic.LoadInt32(&requestCount); count != 2 {
+ t.Errorf("Expected 2 requests to failing paths, but got %d", count)
+ }
+
+ if !strings.Contains(err.Error(), "circuit is open") {
+ t.Errorf("Expected error to contain 'circuit is open', but got: %v", err)
+ }
+ })
}
// --- Helpers ---