feat: Circuit breaker for failing domains

Implement a circuit breaker for the website collector to prevent hammering domains that are consistently failing.

The circuit breaker has three states: CLOSED, OPEN, and HALF-OPEN. It tracks failures per-domain and will open the circuit after a configurable number of consecutive failures. After a cooldown period, the circuit will transition to HALF-OPEN and allow a limited number of test requests to check for recovery.

The following command-line flags have been added to the `collect website` command:
- `--no-circuit-breaker`: Disable the circuit breaker
- `--circuit-failures`: Number of consecutive failures needed to trip the circuit breaker (default 5)
- `--circuit-cooldown`: Cooldown time for the circuit breaker
- `--circuit-success-threshold`: Number of consecutive successful test requests needed to close the circuit breaker (default 2)
- `--circuit-half-open-requests`: Number of test requests in the half-open state

The implementation also includes:
- A new `circuitbreaker` package with the core logic
- Integration into the `website` package with per-domain tracking
- Improved logging to include the domain name and state changes
- Integration tests to verify the circuit breaker's behavior

Co-authored-by: Snider <631881+Snider@users.noreply.github.com>
This commit is contained in:
google-labs-jules[bot] 2026-02-02 00:59:38 +00:00
parent cf2af53ed3
commit 2ff65938ca
8 changed files with 455 additions and 45 deletions

View file

@ -3,8 +3,11 @@ package cmd
import (
"fmt"
"os"
"time"
"github.com/Snider/Borg/pkg/circuitbreaker"
"github.com/schollz/progressbar/v3"
"golang.org/x/exp/slog"
"github.com/Snider/Borg/pkg/compress"
"github.com/Snider/Borg/pkg/tim"
"github.com/Snider/Borg/pkg/trix"
@ -38,6 +41,11 @@ func NewCollectWebsiteCmd() *cobra.Command {
format, _ := cmd.Flags().GetString("format")
compression, _ := cmd.Flags().GetString("compression")
password, _ := cmd.Flags().GetString("password")
noCircuitBreaker, _ := cmd.Flags().GetBool("no-circuit-breaker")
circuitFailures, _ := cmd.Flags().GetInt("circuit-failures")
circuitCooldown, _ := cmd.Flags().GetDuration("circuit-cooldown")
circuitSuccessThreshold, _ := cmd.Flags().GetInt("circuit-success-threshold")
circuitHalfOpenRequests, _ := cmd.Flags().GetInt("circuit-half-open-requests")
if format != "datanode" && format != "tim" && format != "trix" {
return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', or 'trix')", format)
@ -51,7 +59,25 @@ func NewCollectWebsiteCmd() *cobra.Command {
bar = ui.NewProgressBar(-1, "Crawling website")
}
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelInfo,
}))
opts := website.DownloadOptions{
URL: websiteURL,
MaxDepth: depth,
ProgressBar: bar,
EnableCircuitBreaker: !noCircuitBreaker,
CBSettings: circuitbreaker.Settings{
FailureThreshold: circuitFailures,
SuccessThreshold: circuitSuccessThreshold,
Cooldown: circuitCooldown,
HalfOpenRequests: circuitHalfOpenRequests,
Logger: logger,
},
}
dn, err := website.DownloadAndPackageWebsite(opts)
if err != nil {
return fmt.Errorf("error downloading and packaging website: %w", err)
}
@ -104,5 +130,10 @@ func NewCollectWebsiteCmd() *cobra.Command {
collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
collectWebsiteCmd.Flags().Bool("no-circuit-breaker", false, "Disable the circuit breaker")
collectWebsiteCmd.Flags().Int("circuit-failures", 5, "Number of failures to trip the circuit breaker")
collectWebsiteCmd.Flags().Duration("circuit-cooldown", 30*time.Second, "Cooldown time for the circuit breaker")
collectWebsiteCmd.Flags().Int("circuit-success-threshold", 2, "Number of successes to close the circuit breaker")
collectWebsiteCmd.Flags().Int("circuit-half-open-requests", 1, "Number of test requests in half-open state")
return collectWebsiteCmd
}

View file

@ -8,13 +8,12 @@ import (
"github.com/Snider/Borg/pkg/datanode"
"github.com/Snider/Borg/pkg/website"
"github.com/schollz/progressbar/v3"
)
func TestCollectWebsiteCmd_Good(t *testing.T) {
// Mock the website downloader
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
website.DownloadAndPackageWebsite = func(opts website.DownloadOptions) (*datanode.DataNode, error) {
return datanode.New(), nil
}
defer func() {
@ -35,7 +34,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
func TestCollectWebsiteCmd_Bad(t *testing.T) {
// Mock the website downloader to return an error
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
website.DownloadAndPackageWebsite = func(opts website.DownloadOptions) (*datanode.DataNode, error) {
return nil, fmt.Errorf("website error")
}
defer func() {

View file

@ -11,7 +11,12 @@ func main() {
log.Println("Collecting website...")
// Download and package the website.
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
opts := website.DownloadOptions{
URL: "https://example.com",
MaxDepth: 2,
EnableCircuitBreaker: true,
}
dn, err := website.DownloadAndPackageWebsite(opts)
if err != nil {
log.Fatalf("Failed to collect website: %v", err)
}

1
go.mod
View file

@ -61,6 +61,7 @@ require (
github.com/wailsapp/mimetype v1.4.1 // indirect
github.com/xanzy/ssh-agent v0.3.3 // indirect
golang.org/x/crypto v0.44.0 // indirect
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/term v0.37.0 // indirect
golang.org/x/text v0.31.0 // indirect

View file

@ -0,0 +1,160 @@
package circuitbreaker
import (
"fmt"
"io"
"sync"
"time"
"golang.org/x/exp/slog"
)
// State enumerates the phases a circuit breaker can be in.
type State int

const (
	// ClosedState is the zero value and therefore the state a breaker starts in.
	ClosedState State = iota
	// OpenState marks a breaker that has been tripped.
	OpenState
	// HalfOpenState marks a breaker that is probing for recovery.
	HalfOpenState
)

// String returns the upper-case label of s ("CLOSED", "OPEN" or "HALF-OPEN").
// Any value outside the declared constants yields "UNKNOWN".
func (s State) String() string {
	// Labels are indexed by the constant they describe.
	labels := [...]string{
		ClosedState:   "CLOSED",
		OpenState:     "OPEN",
		HalfOpenState: "HALF-OPEN",
	}
	if s < 0 || int(s) >= len(labels) {
		return "UNKNOWN"
	}
	return labels[s]
}
// Settings configures the circuit breaker. It is passed to New by value, so
// the breaker keeps its own copy.
type Settings struct {
// FailureThreshold is the number of consecutive failures before opening the circuit.
// A successful call made while the circuit is closed resets the failure count.
FailureThreshold int
// SuccessThreshold is the number of consecutive successes to close the circuit.
// Successes are only counted while the circuit is half-open.
SuccessThreshold int
// Cooldown is the time to wait in the open state before transitioning to half-open.
// It is measured from the moment the circuit opens.
Cooldown time.Duration
// HalfOpenRequests is the number of test requests to allow in the half-open state.
HalfOpenRequests int
// Logger is the logger to use for state changes. When it is nil, New
// substitutes a logger that discards its output.
Logger *slog.Logger
}
// CircuitBreaker is a state machine that prevents repeated calls to a failing service.
// Create one with New; it contains a mutex and must not be copied after first use.
type CircuitBreaker struct {
// settings is the configuration supplied to New.
settings Settings
// domain is the label included in state-change log messages.
domain string
// mu is locked by Execute for the whole call, including the wrapped function.
mu sync.Mutex
// state is the current phase; it is changed only through setState.
state State
// failures counts consecutive failed calls observed while closed.
failures int
// successes counts successful test requests in the current half-open period.
successes int
// halfOpenRequests counts test requests started in the current half-open period.
halfOpenRequests int
// lastError is the error that most recently opened the circuit; it is
// wrapped into the error returned while calls are being rejected.
lastError error
// expiry is the time after which an open circuit moves to half-open.
expiry time.Time
}
// New builds a CircuitBreaker for domain that starts in the closed state.
// If settings.Logger is nil, a logger writing to io.Discard is installed so
// that state changes can always be logged without a nil check.
func New(domain string, settings Settings) *CircuitBreaker {
	breaker := &CircuitBreaker{
		domain: domain,
		state:  ClosedState,
	}
	if settings.Logger == nil {
		silent := slog.NewTextHandler(io.Discard, &slog.HandlerOptions{
			Level: slog.LevelError,
		})
		settings.Logger = slog.New(silent)
	}
	breaker.settings = settings
	return breaker
}
// Execute invokes fn under the protection of the circuit breaker and returns
// its result. The breaker's mutex is held for the whole call, so calls on the
// same breaker run one at a time.
//
// While the circuit is open and the cooldown has not elapsed, fn is not
// called and an error wrapping the failure that opened the circuit is
// returned. Once the cooldown has elapsed the breaker switches to half-open
// and the call is treated as a test request.
func (cb *CircuitBreaker) Execute(fn func() (interface{}, error)) (interface{}, error) {
	cb.mu.Lock()
	defer cb.mu.Unlock()

	if cb.state == OpenState {
		if !time.Now().After(cb.expiry) {
			return nil, fmt.Errorf("circuit is open for: %w", cb.lastError)
		}
		// Cooldown is over: start probing with this very call.
		cb.setState(HalfOpenState)
	}

	if cb.state == HalfOpenState {
		return cb.executeHalfOpen(fn)
	}
	return cb.executeClosed(fn)
}
// executeClosed runs fn while the circuit is closed. A success resets the
// consecutive-failure counter; a failure increments it and opens the circuit
// once the counter reaches settings.FailureThreshold. The caller must hold cb.mu.
func (cb *CircuitBreaker) executeClosed(fn func() (interface{}, error)) (interface{}, error) {
	result, callErr := fn()
	if callErr == nil {
		cb.failures = 0
		return result, nil
	}

	cb.failures++
	if tripped := cb.failures >= cb.settings.FailureThreshold; tripped {
		// Remember the error that tripped the breaker so rejections can report it.
		cb.lastError = callErr
		cb.setState(OpenState)
	}
	return nil, callErr
}
// executeHalfOpen runs fn as a test request while the circuit is half-open.
// The caller must hold cb.mu.
//
// A failing test request re-opens the circuit immediately. A succeeding one
// closes the circuit once either settings.SuccessThreshold successes have been
// seen or the whole test-request budget has been used up without a failure.
// Without the second condition a budget smaller than the success threshold
// (or a zero budget) would leave the breaker half-open forever, rejecting
// every later call with "test requests exhausted".
func (cb *CircuitBreaker) executeHalfOpen(fn func() (interface{}, error)) (interface{}, error) {
	// A half-open circuit needs at least one probe to ever learn about
	// recovery, so a non-positive setting is treated as a budget of one.
	budget := cb.settings.HalfOpenRequests
	if budget < 1 {
		budget = 1
	}

	// Safety net: every path below leaves the half-open state once the budget
	// is spent, so this only triggers if that invariant is broken elsewhere.
	if cb.halfOpenRequests >= budget {
		return nil, fmt.Errorf("circuit is half-open, test requests exhausted: %w", cb.lastError)
	}

	cb.halfOpenRequests++
	res, err := fn()
	if err != nil {
		cb.lastError = err
		cb.setState(OpenState)
		return nil, err
	}

	cb.successes++
	// Any failure would have re-opened the circuit above, so reaching the end
	// of the budget means every permitted probe succeeded.
	if cb.successes >= cb.settings.SuccessThreshold || cb.halfOpenRequests >= budget {
		cb.setState(ClosedState)
	}
	return res, nil
}
// setState moves the breaker to state, logs the transition and resets the
// counters that belong to the new state. It does nothing when the breaker is
// already in that state. The caller must hold cb.mu.
func (cb *CircuitBreaker) setState(state State) {
	if cb.state == state {
		return
	}
	cb.state = state

	msg := fmt.Sprintf("Circuit %s for %s", state, cb.domain)
	log := cb.settings.Logger

	switch state {
	case OpenState:
		// Opening is the only transition logged at warning level.
		log.Warn(msg)
		cb.expiry = time.Now().Add(cb.settings.Cooldown)
		cb.successes = 0
	case ClosedState:
		log.Info(msg)
		cb.failures, cb.successes = 0, 0
	case HalfOpenState:
		log.Info(msg)
		cb.failures, cb.halfOpenRequests = 0, 0
	default:
		log.Info(msg)
	}
}

View file

@ -0,0 +1,101 @@
package circuitbreaker
import (
"errors"
"testing"
"time"
)
func TestCircuitBreaker(t *testing.T) {
settings := Settings{
FailureThreshold: 2,
SuccessThreshold: 2,
Cooldown: 100 * time.Millisecond,
HalfOpenRequests: 2,
}
cb := New("test.com", settings)
// Initially closed
_, err := cb.Execute(func() (interface{}, error) {
return "success", nil
})
if err != nil {
t.Errorf("Expected success, got %v", err)
}
// Trip the breaker
_, err = cb.Execute(func() (interface{}, error) {
return nil, errors.New("failure 1")
})
if err == nil {
t.Error("Expected failure, got nil")
}
_, err = cb.Execute(func() (interface{}, error) {
return nil, errors.New("failure 2")
})
if err == nil {
t.Error("Expected failure, got nil")
}
// Now open
_, err = cb.Execute(func() (interface{}, error) {
return "should not be called", nil
})
if err == nil || err.Error() != "circuit is open for: failure 2" {
t.Errorf("Expected open circuit error, got %v", err)
}
// Wait for cooldown
time.Sleep(150 * time.Millisecond)
// Half-open, should succeed
_, err = cb.Execute(func() (interface{}, error) {
return "success", nil
})
if err != nil {
t.Errorf("Expected success in half-open, got %v", err)
}
// Still half-open, need another success
_, err = cb.Execute(func() (interface{}, error) {
return "success", nil
})
if err != nil {
t.Errorf("Expected success in half-open, got %v", err)
}
// Now closed again
_, err = cb.Execute(func() (interface{}, error) {
return "success", nil
})
if err != nil {
t.Errorf("Expected success in closed state, got %v", err)
}
// Trip again to test half-open failure
cb.Execute(func() (interface{}, error) {
return nil, errors.New("failure 1")
})
cb.Execute(func() (interface{}, error) {
return nil, errors.New("failure 2")
})
time.Sleep(150 * time.Millisecond)
// Half-open, but fail
_, err = cb.Execute(func() (interface{}, error) {
return nil, errors.New("half-open failure")
})
if err == nil {
t.Error("Expected failure in half-open, got nil")
}
// Should be open again
_, err = cb.Execute(func() (interface{}, error) {
return "should not be called", nil
})
if err == nil || err.Error() != "circuit is open for: half-open failure" {
t.Errorf("Expected open circuit error, got %v", err)
}
}

View file

@ -6,53 +6,70 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"github.com/Snider/Borg/pkg/circuitbreaker"
"github.com/Snider/Borg/pkg/datanode"
"github.com/schollz/progressbar/v3"
"golang.org/x/net/html"
)
var DownloadAndPackageWebsite = downloadAndPackageWebsite
// DownloadOptions configures the website downloader.
type DownloadOptions struct {
// URL is the start URL; it is parsed as the base URL and crawled at depth 0.
URL string
// MaxDepth is the greatest crawl depth; pages deeper than this are not fetched.
MaxDepth int
// ProgressBar, if set, is advanced as the crawl proceeds.
// NOTE(review): callers in this change pass nil — confirm the crawler guards against a nil bar.
ProgressBar *progressbar.ProgressBar
// EnableCircuitBreaker routes every fetch through a per-host circuit breaker.
// When false, requests are sent directly.
EnableCircuitBreaker bool
// CBSettings is the configuration given to each per-host circuit breaker.
CBSettings circuitbreaker.Settings
}
// Downloader is a recursive website downloader.
type Downloader struct {
baseURL *url.URL
dn *datanode.DataNode
visited map[string]bool
maxDepth int
progressBar *progressbar.ProgressBar
client *http.Client
errors []error
baseURL *url.URL
dn *datanode.DataNode
visited map[string]bool
maxDepth int
progressBar *progressbar.ProgressBar
client *http.Client
errors []error
cbEnabled bool
cbSettings circuitbreaker.Settings
circuitBreakers map[string]*circuitbreaker.CircuitBreaker
cbMutex sync.Mutex
}
// NewDownloader creates a new Downloader.
func NewDownloader(maxDepth int) *Downloader {
return NewDownloaderWithClient(maxDepth, http.DefaultClient)
func NewDownloader(opts DownloadOptions) *Downloader {
return NewDownloaderWithClient(opts, http.DefaultClient)
}
// NewDownloaderWithClient creates a new Downloader with a custom http.Client.
func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
func NewDownloaderWithClient(opts DownloadOptions, client *http.Client) *Downloader {
return &Downloader{
dn: datanode.New(),
visited: make(map[string]bool),
maxDepth: maxDepth,
client: client,
errors: make([]error, 0),
dn: datanode.New(),
visited: make(map[string]bool),
maxDepth: opts.MaxDepth,
client: client,
errors: make([]error, 0),
cbEnabled: opts.EnableCircuitBreaker,
cbSettings: opts.CBSettings,
circuitBreakers: make(map[string]*circuitbreaker.CircuitBreaker),
}
}
// downloadAndPackageWebsite downloads a website and packages it into a DataNode.
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
baseURL, err := url.Parse(startURL)
func downloadAndPackageWebsite(opts DownloadOptions) (*datanode.DataNode, error) {
baseURL, err := url.Parse(opts.URL)
if err != nil {
return nil, err
}
d := NewDownloader(maxDepth)
d := NewDownloader(opts)
d.baseURL = baseURL
d.progressBar = bar
d.crawl(startURL, 0)
d.progressBar = opts.ProgressBar
d.crawl(opts.URL, 0)
if len(d.errors) > 0 {
var errs []string
@ -65,6 +82,57 @@ func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.P
return d.dn, nil
}
// getCircuitBreaker returns the circuit breaker for host, creating and
// caching one configured with d.cbSettings on first use. The breaker map is
// guarded by d.cbMutex.
func (d *Downloader) getCircuitBreaker(host string) *circuitbreaker.CircuitBreaker {
	d.cbMutex.Lock()
	defer d.cbMutex.Unlock()

	breaker, known := d.circuitBreakers[host]
	if !known {
		breaker = circuitbreaker.New(host, d.cbSettings)
		d.circuitBreakers[host] = breaker
	}
	return breaker
}
// fetchURL issues a GET for pageURL and returns the response. On success the
// caller owns the response body and must close it.
//
// Any response with a status of 400 or above is turned into a
// "bad status for ..." error and its body is closed here.
//
// When the circuit breaker is enabled the request goes through the breaker
// for the URL's hostname. Only signs that the host itself is in trouble count
// as breaker failures: transport errors, 5xx responses and 429 Too Many
// Requests. Other 4xx responses (for example a 404 from a dead link) are
// still returned as errors, but they show the host is answering, so they do
// not push the breaker towards opening and cannot block the rest of a crawl
// on a healthy site.
func (d *Downloader) fetchURL(pageURL string) (*http.Response, error) {
	// rejectBadStatus is the single place that maps an HTTP error status to a
	// Go error, closing the body so the connection is not leaked.
	rejectBadStatus := func(resp *http.Response) (*http.Response, error) {
		if resp.StatusCode >= 400 {
			resp.Body.Close()
			return nil, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status)
		}
		return resp, nil
	}

	if !d.cbEnabled {
		resp, err := d.client.Get(pageURL)
		if err != nil {
			return nil, err
		}
		return rejectBadStatus(resp)
	}

	u, err := url.Parse(pageURL)
	if err != nil {
		return nil, fmt.Errorf("invalid URL: %w", err)
	}

	cb := d.getCircuitBreaker(u.Hostname())
	result, err := cb.Execute(func() (interface{}, error) {
		resp, err := d.client.Get(pageURL)
		if err != nil {
			return nil, err
		}
		hostUnhealthy := resp.StatusCode >= http.StatusInternalServerError ||
			resp.StatusCode == http.StatusTooManyRequests
		if hostUnhealthy {
			_, statusErr := rejectBadStatus(resp)
			return nil, statusErr
		}
		// Remaining statuses, including other 4xx, are judged outside the
		// breaker so they are not recorded as failures of the host.
		return resp, nil
	})
	if err != nil {
		return nil, err
	}

	resp, ok := result.(*http.Response)
	if !ok {
		return nil, fmt.Errorf("unexpected result type %T for %s", result, pageURL)
	}
	return rejectBadStatus(resp)
}
func (d *Downloader) crawl(pageURL string, depth int) {
if depth > d.maxDepth || d.visited[pageURL] {
return
@ -74,18 +142,13 @@ func (d *Downloader) crawl(pageURL string, depth int) {
d.progressBar.Add(1)
}
resp, err := d.client.Get(pageURL)
resp, err := d.fetchURL(pageURL)
if err != nil {
d.errors = append(d.errors, fmt.Errorf("Error getting %s: %w", pageURL, err))
return
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
d.errors = append(d.errors, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status))
return
}
body, err := io.ReadAll(resp.Body)
if err != nil {
d.errors = append(d.errors, fmt.Errorf("Error reading body of %s: %w", pageURL, err))
@ -141,18 +204,13 @@ func (d *Downloader) downloadAsset(assetURL string) {
d.progressBar.Add(1)
}
resp, err := d.client.Get(assetURL)
resp, err := d.fetchURL(assetURL)
if err != nil {
d.errors = append(d.errors, fmt.Errorf("Error getting asset %s: %w", assetURL, err))
return
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
d.errors = append(d.errors, fmt.Errorf("bad status for asset %s: %s", assetURL, resp.Status))
return
}
body, err := io.ReadAll(resp.Body)
if err != nil {
d.errors = append(d.errors, fmt.Errorf("Error reading body of asset %s: %w", assetURL, err))

View file

@ -7,9 +7,11 @@ import (
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Snider/Borg/pkg/circuitbreaker"
"github.com/schollz/progressbar/v3"
)
@ -20,7 +22,12 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
defer server.Close()
bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
dn, err := DownloadAndPackageWebsite(server.URL, 2, bar)
opts := DownloadOptions{
URL: server.URL,
MaxDepth: 2,
ProgressBar: bar,
}
dn, err := DownloadAndPackageWebsite(opts)
if err != nil {
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
}
@ -52,7 +59,8 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
t.Run("Invalid Start URL", func(t *testing.T) {
_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil)
opts := DownloadOptions{URL: "http://invalid-url", MaxDepth: 1}
_, err := DownloadAndPackageWebsite(opts)
if err == nil {
t.Fatal("Expected an error for an invalid start URL, but got nil")
}
@ -63,7 +71,8 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}))
defer server.Close()
_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
opts := DownloadOptions{URL: server.URL, MaxDepth: 1}
_, err := DownloadAndPackageWebsite(opts)
if err == nil {
t.Fatal("Expected an error for a server error on the start URL, but got nil")
}
@ -80,7 +89,8 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
}))
defer server.Close()
// We expect an error because the link is broken.
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
opts := DownloadOptions{URL: server.URL, MaxDepth: 1}
dn, err := DownloadAndPackageWebsite(opts)
if err == nil {
t.Fatal("Expected an error for a broken link, but got nil")
}
@ -99,7 +109,8 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
defer server.Close()
bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1
opts := DownloadOptions{URL: server.URL, MaxDepth: 1, ProgressBar: bar}
dn, err := DownloadAndPackageWebsite(opts)
if err != nil {
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
}
@ -122,7 +133,8 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
fmt.Fprint(w, `<a href="http://externalsite.com/page.html">External</a>`)
}))
defer server.Close()
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
opts := DownloadOptions{URL: server.URL, MaxDepth: 1}
dn, err := DownloadAndPackageWebsite(opts)
if err != nil {
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
}
@ -156,7 +168,8 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
// For now, we'll just test that it doesn't hang forever.
done := make(chan bool)
go func() {
_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
opts := DownloadOptions{URL: server.URL, MaxDepth: 1}
_, err := DownloadAndPackageWebsite(opts)
if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
// We expect a timeout error, but other errors are failures.
t.Errorf("unexpected error: %v", err)
@ -170,6 +183,48 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
t.Fatal("Test timed out")
}
})
t.Run("Circuit Breaker", func(t *testing.T) {
var requestCount int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
atomic.AddInt32(&requestCount, 1)
}
if r.URL.Path == "/" {
w.Header().Set("Content-Type", "text/html")
fmt.Fprint(w, `
<a href="/fail1">Fail 1</a>
<a href="/fail2">Fail 2</a>
<a href="/fail3">Fail 3</a>
`)
} else {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}))
defer server.Close()
opts := DownloadOptions{
URL: server.URL,
MaxDepth: 1,
EnableCircuitBreaker: true,
CBSettings: circuitbreaker.Settings{
FailureThreshold: 2,
Cooldown: 1 * time.Second,
},
}
_, err := DownloadAndPackageWebsite(opts)
if err == nil {
t.Fatal("Expected an error, but got nil")
}
if count := atomic.LoadInt32(&requestCount); count != 2 {
t.Errorf("Expected 2 requests to failing paths, but got %d", count)
}
if !strings.Contains(err.Error(), "circuit is open") {
t.Errorf("Expected error to contain 'circuit is open', but got: %v", err)
}
})
}
// --- Helpers ---