feat: Circuit breaker for failing domains
Implement a circuit breaker for the website collector to prevent hammering domains that are consistently failing. The circuit breaker has three states: CLOSED, OPEN, and HALF-OPEN. It tracks failures per-domain and will open the circuit after a configurable number of consecutive failures. After a cooldown period, the circuit will transition to HALF-OPEN and allow a limited number of test requests to check for recovery. The following command-line flags have been added to the `collect website` command: - `--no-circuit-breaker`: Disable the circuit breaker - `--circuit-failures`: Number of failures to trip the circuit breaker - `--circuit-cooldown`: Cooldown time for the circuit breaker - `--circuit-success-threshold`: Number of successes to close the circuit breaker - `--circuit-half-open-requests`: Number of test requests in the half-open state The implementation also includes: - A new `circuitbreaker` package with the core logic - Integration into the `website` package with per-domain tracking - Improved logging to include the domain name and state changes - Integration tests to verify the circuit breaker's behavior Co-authored-by: Snider <631881+Snider@users.noreply.github.com>
This commit is contained in:
parent
cf2af53ed3
commit
2ff65938ca
8 changed files with 455 additions and 45 deletions
|
|
@ -3,8 +3,11 @@ package cmd
|
|||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/Snider/Borg/pkg/circuitbreaker"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
"golang.org/x/exp/slog"
|
||||
"github.com/Snider/Borg/pkg/compress"
|
||||
"github.com/Snider/Borg/pkg/tim"
|
||||
"github.com/Snider/Borg/pkg/trix"
|
||||
|
|
@ -38,6 +41,11 @@ func NewCollectWebsiteCmd() *cobra.Command {
|
|||
format, _ := cmd.Flags().GetString("format")
|
||||
compression, _ := cmd.Flags().GetString("compression")
|
||||
password, _ := cmd.Flags().GetString("password")
|
||||
noCircuitBreaker, _ := cmd.Flags().GetBool("no-circuit-breaker")
|
||||
circuitFailures, _ := cmd.Flags().GetInt("circuit-failures")
|
||||
circuitCooldown, _ := cmd.Flags().GetDuration("circuit-cooldown")
|
||||
circuitSuccessThreshold, _ := cmd.Flags().GetInt("circuit-success-threshold")
|
||||
circuitHalfOpenRequests, _ := cmd.Flags().GetInt("circuit-half-open-requests")
|
||||
|
||||
if format != "datanode" && format != "tim" && format != "trix" {
|
||||
return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', or 'trix')", format)
|
||||
|
|
@ -51,7 +59,25 @@ func NewCollectWebsiteCmd() *cobra.Command {
|
|||
bar = ui.NewProgressBar(-1, "Crawling website")
|
||||
}
|
||||
|
||||
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
}))
|
||||
|
||||
opts := website.DownloadOptions{
|
||||
URL: websiteURL,
|
||||
MaxDepth: depth,
|
||||
ProgressBar: bar,
|
||||
EnableCircuitBreaker: !noCircuitBreaker,
|
||||
CBSettings: circuitbreaker.Settings{
|
||||
FailureThreshold: circuitFailures,
|
||||
SuccessThreshold: circuitSuccessThreshold,
|
||||
Cooldown: circuitCooldown,
|
||||
HalfOpenRequests: circuitHalfOpenRequests,
|
||||
Logger: logger,
|
||||
},
|
||||
}
|
||||
|
||||
dn, err := website.DownloadAndPackageWebsite(opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error downloading and packaging website: %w", err)
|
||||
}
|
||||
|
|
@ -104,5 +130,10 @@ func NewCollectWebsiteCmd() *cobra.Command {
|
|||
collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
|
||||
collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
|
||||
collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
|
||||
collectWebsiteCmd.Flags().Bool("no-circuit-breaker", false, "Disable the circuit breaker")
|
||||
collectWebsiteCmd.Flags().Int("circuit-failures", 5, "Number of failures to trip the circuit breaker")
|
||||
collectWebsiteCmd.Flags().Duration("circuit-cooldown", 30*time.Second, "Cooldown time for the circuit breaker")
|
||||
collectWebsiteCmd.Flags().Int("circuit-success-threshold", 2, "Number of successes to close the circuit breaker")
|
||||
collectWebsiteCmd.Flags().Int("circuit-half-open-requests", 1, "Number of test requests in half-open state")
|
||||
return collectWebsiteCmd
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,13 +8,12 @@ import (
|
|||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Borg/pkg/website"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
)
|
||||
|
||||
func TestCollectWebsiteCmd_Good(t *testing.T) {
|
||||
// Mock the website downloader
|
||||
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
website.DownloadAndPackageWebsite = func(opts website.DownloadOptions) (*datanode.DataNode, error) {
|
||||
return datanode.New(), nil
|
||||
}
|
||||
defer func() {
|
||||
|
|
@ -35,7 +34,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
|
|||
func TestCollectWebsiteCmd_Bad(t *testing.T) {
|
||||
// Mock the website downloader to return an error
|
||||
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
website.DownloadAndPackageWebsite = func(opts website.DownloadOptions) (*datanode.DataNode, error) {
|
||||
return nil, fmt.Errorf("website error")
|
||||
}
|
||||
defer func() {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,12 @@ func main() {
|
|||
log.Println("Collecting website...")
|
||||
|
||||
// Download and package the website.
|
||||
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
|
||||
opts := website.DownloadOptions{
|
||||
URL: "https://example.com",
|
||||
MaxDepth: 2,
|
||||
EnableCircuitBreaker: true,
|
||||
}
|
||||
dn, err := website.DownloadAndPackageWebsite(opts)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to collect website: %v", err)
|
||||
}
|
||||
|
|
|
|||
1
go.mod
1
go.mod
|
|
@ -61,6 +61,7 @@ require (
|
|||
github.com/wailsapp/mimetype v1.4.1 // indirect
|
||||
github.com/xanzy/ssh-agent v0.3.3 // indirect
|
||||
golang.org/x/crypto v0.44.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/term v0.37.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
|
|
|
|||
160
pkg/circuitbreaker/circuitbreaker.go
Normal file
160
pkg/circuitbreaker/circuitbreaker.go
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
|
||||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
// State represents the state of the circuit breaker.
|
||||
type State int
|
||||
|
||||
const (
|
||||
// ClosedState is the initial state of the circuit breaker.
|
||||
ClosedState State = iota
|
||||
// OpenState is the state when the circuit breaker is tripped.
|
||||
OpenState
|
||||
// HalfOpenState is the state when the circuit breaker is testing for recovery.
|
||||
HalfOpenState
|
||||
)
|
||||
|
||||
// String returns the string representation of the state.
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case ClosedState:
|
||||
return "CLOSED"
|
||||
case OpenState:
|
||||
return "OPEN"
|
||||
case HalfOpenState:
|
||||
return "HALF-OPEN"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// Settings configures the circuit breaker.
|
||||
type Settings struct {
|
||||
// FailureThreshold is the number of consecutive failures before opening the circuit.
|
||||
FailureThreshold int
|
||||
// SuccessThreshold is the number of consecutive successes to close the circuit.
|
||||
SuccessThreshold int
|
||||
// Cooldown is the time to wait in the open state before transitioning to half-open.
|
||||
Cooldown time.Duration
|
||||
// HalfOpenRequests is the number of test requests to allow in the half-open state.
|
||||
HalfOpenRequests int
|
||||
// Logger is the logger to use for state changes.
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
// CircuitBreaker is a state machine that prevents repeated calls to a failing service.
|
||||
type CircuitBreaker struct {
|
||||
settings Settings
|
||||
domain string
|
||||
mu sync.Mutex
|
||||
state State
|
||||
failures int
|
||||
successes int
|
||||
halfOpenRequests int
|
||||
lastError error
|
||||
expiry time.Time
|
||||
}
|
||||
|
||||
// New creates a new CircuitBreaker.
|
||||
func New(domain string, settings Settings) *CircuitBreaker {
|
||||
if settings.Logger == nil {
|
||||
settings.Logger = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{
|
||||
Level: slog.LevelError,
|
||||
}))
|
||||
}
|
||||
return &CircuitBreaker{
|
||||
settings: settings,
|
||||
domain: domain,
|
||||
state: ClosedState,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs the given function, protected by the circuit breaker.
|
||||
func (cb *CircuitBreaker) Execute(fn func() (interface{}, error)) (interface{}, error) {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
switch cb.state {
|
||||
case OpenState:
|
||||
if time.Now().After(cb.expiry) {
|
||||
cb.setState(HalfOpenState)
|
||||
return cb.executeHalfOpen(fn)
|
||||
}
|
||||
return nil, fmt.Errorf("circuit is open for: %w", cb.lastError)
|
||||
case HalfOpenState:
|
||||
return cb.executeHalfOpen(fn)
|
||||
default: // ClosedState
|
||||
return cb.executeClosed(fn)
|
||||
}
|
||||
}
|
||||
|
||||
func (cb *CircuitBreaker) executeClosed(fn func() (interface{}, error)) (interface{}, error) {
|
||||
res, err := fn()
|
||||
if err != nil {
|
||||
cb.failures++
|
||||
if cb.failures >= cb.settings.FailureThreshold {
|
||||
cb.lastError = err
|
||||
cb.setState(OpenState)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cb.failures = 0
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (cb *CircuitBreaker) executeHalfOpen(fn func() (interface{}, error)) (interface{}, error) {
|
||||
if cb.halfOpenRequests >= cb.settings.HalfOpenRequests {
|
||||
return nil, fmt.Errorf("circuit is half-open, test requests exhausted: %w", cb.lastError)
|
||||
}
|
||||
cb.halfOpenRequests++
|
||||
|
||||
res, err := fn()
|
||||
if err != nil {
|
||||
cb.lastError = err
|
||||
cb.setState(OpenState)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cb.successes++
|
||||
// If enough test requests succeed, close the circuit.
|
||||
if cb.successes >= cb.settings.SuccessThreshold {
|
||||
cb.setState(ClosedState)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (cb *CircuitBreaker) setState(state State) {
|
||||
if cb.state == state {
|
||||
return
|
||||
}
|
||||
|
||||
cb.state = state
|
||||
logMessage := fmt.Sprintf("Circuit %s for %s", state, cb.domain)
|
||||
if state == OpenState {
|
||||
cb.settings.Logger.Warn(logMessage)
|
||||
} else {
|
||||
cb.settings.Logger.Info(logMessage)
|
||||
}
|
||||
|
||||
switch state {
|
||||
case OpenState:
|
||||
cb.expiry = time.Now().Add(cb.settings.Cooldown)
|
||||
cb.successes = 0
|
||||
case ClosedState:
|
||||
cb.failures = 0
|
||||
cb.successes = 0
|
||||
case HalfOpenState:
|
||||
cb.failures = 0
|
||||
cb.halfOpenRequests = 0
|
||||
}
|
||||
}
|
||||
101
pkg/circuitbreaker/circuitbreaker_test.go
Normal file
101
pkg/circuitbreaker/circuitbreaker_test.go
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
|
||||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCircuitBreaker(t *testing.T) {
|
||||
settings := Settings{
|
||||
FailureThreshold: 2,
|
||||
SuccessThreshold: 2,
|
||||
Cooldown: 100 * time.Millisecond,
|
||||
HalfOpenRequests: 2,
|
||||
}
|
||||
cb := New("test.com", settings)
|
||||
|
||||
// Initially closed
|
||||
_, err := cb.Execute(func() (interface{}, error) {
|
||||
return "success", nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Expected success, got %v", err)
|
||||
}
|
||||
|
||||
// Trip the breaker
|
||||
_, err = cb.Execute(func() (interface{}, error) {
|
||||
return nil, errors.New("failure 1")
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("Expected failure, got nil")
|
||||
}
|
||||
_, err = cb.Execute(func() (interface{}, error) {
|
||||
return nil, errors.New("failure 2")
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("Expected failure, got nil")
|
||||
}
|
||||
|
||||
// Now open
|
||||
_, err = cb.Execute(func() (interface{}, error) {
|
||||
return "should not be called", nil
|
||||
})
|
||||
if err == nil || err.Error() != "circuit is open for: failure 2" {
|
||||
t.Errorf("Expected open circuit error, got %v", err)
|
||||
}
|
||||
|
||||
// Wait for cooldown
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Half-open, should succeed
|
||||
_, err = cb.Execute(func() (interface{}, error) {
|
||||
return "success", nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Expected success in half-open, got %v", err)
|
||||
}
|
||||
|
||||
// Still half-open, need another success
|
||||
_, err = cb.Execute(func() (interface{}, error) {
|
||||
return "success", nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Expected success in half-open, got %v", err)
|
||||
}
|
||||
|
||||
// Now closed again
|
||||
_, err = cb.Execute(func() (interface{}, error) {
|
||||
return "success", nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Expected success in closed state, got %v", err)
|
||||
}
|
||||
|
||||
// Trip again to test half-open failure
|
||||
cb.Execute(func() (interface{}, error) {
|
||||
return nil, errors.New("failure 1")
|
||||
})
|
||||
cb.Execute(func() (interface{}, error) {
|
||||
return nil, errors.New("failure 2")
|
||||
})
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Half-open, but fail
|
||||
_, err = cb.Execute(func() (interface{}, error) {
|
||||
return nil, errors.New("half-open failure")
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("Expected failure in half-open, got nil")
|
||||
}
|
||||
|
||||
// Should be open again
|
||||
_, err = cb.Execute(func() (interface{}, error) {
|
||||
return "should not be called", nil
|
||||
})
|
||||
if err == nil || err.Error() != "circuit is open for: half-open failure" {
|
||||
t.Errorf("Expected open circuit error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -6,53 +6,70 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Snider/Borg/pkg/circuitbreaker"
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
|
||||
"golang.org/x/net/html"
|
||||
)
|
||||
|
||||
var DownloadAndPackageWebsite = downloadAndPackageWebsite
|
||||
|
||||
// DownloadOptions configures the website downloader.
|
||||
type DownloadOptions struct {
|
||||
URL string
|
||||
MaxDepth int
|
||||
ProgressBar *progressbar.ProgressBar
|
||||
EnableCircuitBreaker bool
|
||||
CBSettings circuitbreaker.Settings
|
||||
}
|
||||
|
||||
// Downloader is a recursive website downloader.
|
||||
type Downloader struct {
|
||||
baseURL *url.URL
|
||||
dn *datanode.DataNode
|
||||
visited map[string]bool
|
||||
maxDepth int
|
||||
progressBar *progressbar.ProgressBar
|
||||
client *http.Client
|
||||
errors []error
|
||||
baseURL *url.URL
|
||||
dn *datanode.DataNode
|
||||
visited map[string]bool
|
||||
maxDepth int
|
||||
progressBar *progressbar.ProgressBar
|
||||
client *http.Client
|
||||
errors []error
|
||||
cbEnabled bool
|
||||
cbSettings circuitbreaker.Settings
|
||||
circuitBreakers map[string]*circuitbreaker.CircuitBreaker
|
||||
cbMutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewDownloader creates a new Downloader.
|
||||
func NewDownloader(maxDepth int) *Downloader {
|
||||
return NewDownloaderWithClient(maxDepth, http.DefaultClient)
|
||||
func NewDownloader(opts DownloadOptions) *Downloader {
|
||||
return NewDownloaderWithClient(opts, http.DefaultClient)
|
||||
}
|
||||
|
||||
// NewDownloaderWithClient creates a new Downloader with a custom http.Client.
|
||||
func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
|
||||
func NewDownloaderWithClient(opts DownloadOptions, client *http.Client) *Downloader {
|
||||
return &Downloader{
|
||||
dn: datanode.New(),
|
||||
visited: make(map[string]bool),
|
||||
maxDepth: maxDepth,
|
||||
client: client,
|
||||
errors: make([]error, 0),
|
||||
dn: datanode.New(),
|
||||
visited: make(map[string]bool),
|
||||
maxDepth: opts.MaxDepth,
|
||||
client: client,
|
||||
errors: make([]error, 0),
|
||||
cbEnabled: opts.EnableCircuitBreaker,
|
||||
cbSettings: opts.CBSettings,
|
||||
circuitBreakers: make(map[string]*circuitbreaker.CircuitBreaker),
|
||||
}
|
||||
}
|
||||
|
||||
// downloadAndPackageWebsite downloads a website and packages it into a DataNode.
|
||||
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
baseURL, err := url.Parse(startURL)
|
||||
func downloadAndPackageWebsite(opts DownloadOptions) (*datanode.DataNode, error) {
|
||||
baseURL, err := url.Parse(opts.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := NewDownloader(maxDepth)
|
||||
d := NewDownloader(opts)
|
||||
d.baseURL = baseURL
|
||||
d.progressBar = bar
|
||||
d.crawl(startURL, 0)
|
||||
d.progressBar = opts.ProgressBar
|
||||
d.crawl(opts.URL, 0)
|
||||
|
||||
if len(d.errors) > 0 {
|
||||
var errs []string
|
||||
|
|
@ -65,6 +82,57 @@ func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.P
|
|||
return d.dn, nil
|
||||
}
|
||||
|
||||
func (d *Downloader) getCircuitBreaker(host string) *circuitbreaker.CircuitBreaker {
|
||||
d.cbMutex.Lock()
|
||||
defer d.cbMutex.Unlock()
|
||||
|
||||
if cb, ok := d.circuitBreakers[host]; ok {
|
||||
return cb
|
||||
}
|
||||
|
||||
cb := circuitbreaker.New(host, d.cbSettings)
|
||||
d.circuitBreakers[host] = cb
|
||||
return cb
|
||||
}
|
||||
|
||||
func (d *Downloader) fetchURL(pageURL string) (*http.Response, error) {
|
||||
if !d.cbEnabled {
|
||||
resp, err := d.client.Get(pageURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
u, err := url.Parse(pageURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
cb := d.getCircuitBreaker(u.Hostname())
|
||||
result, err := cb.Execute(func() (interface{}, error) {
|
||||
resp, err := d.client.Get(pageURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status)
|
||||
}
|
||||
return resp, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result.(*http.Response), nil
|
||||
}
|
||||
|
||||
func (d *Downloader) crawl(pageURL string, depth int) {
|
||||
if depth > d.maxDepth || d.visited[pageURL] {
|
||||
return
|
||||
|
|
@ -74,18 +142,13 @@ func (d *Downloader) crawl(pageURL string, depth int) {
|
|||
d.progressBar.Add(1)
|
||||
}
|
||||
|
||||
resp, err := d.client.Get(pageURL)
|
||||
resp, err := d.fetchURL(pageURL)
|
||||
if err != nil {
|
||||
d.errors = append(d.errors, fmt.Errorf("Error getting %s: %w", pageURL, err))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
d.errors = append(d.errors, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status))
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
d.errors = append(d.errors, fmt.Errorf("Error reading body of %s: %w", pageURL, err))
|
||||
|
|
@ -141,18 +204,13 @@ func (d *Downloader) downloadAsset(assetURL string) {
|
|||
d.progressBar.Add(1)
|
||||
}
|
||||
|
||||
resp, err := d.client.Get(assetURL)
|
||||
resp, err := d.fetchURL(assetURL)
|
||||
if err != nil {
|
||||
d.errors = append(d.errors, fmt.Errorf("Error getting asset %s: %w", assetURL, err))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
d.errors = append(d.errors, fmt.Errorf("bad status for asset %s: %s", assetURL, resp.Status))
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
d.errors = append(d.errors, fmt.Errorf("Error reading body of asset %s: %w", assetURL, err))
|
||||
|
|
|
|||
|
|
@ -7,9 +7,11 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Snider/Borg/pkg/circuitbreaker"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
)
|
||||
|
||||
|
|
@ -20,7 +22,12 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
|
|||
defer server.Close()
|
||||
|
||||
bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 2, bar)
|
||||
opts := DownloadOptions{
|
||||
URL: server.URL,
|
||||
MaxDepth: 2,
|
||||
ProgressBar: bar,
|
||||
}
|
||||
dn, err := DownloadAndPackageWebsite(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
|
|
@ -52,7 +59,8 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
|
|||
|
||||
func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
|
||||
t.Run("Invalid Start URL", func(t *testing.T) {
|
||||
_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil)
|
||||
opts := DownloadOptions{URL: "http://invalid-url", MaxDepth: 1}
|
||||
_, err := DownloadAndPackageWebsite(opts)
|
||||
if err == nil {
|
||||
t.Fatal("Expected an error for an invalid start URL, but got nil")
|
||||
}
|
||||
|
|
@ -63,7 +71,8 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
|
|||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
opts := DownloadOptions{URL: server.URL, MaxDepth: 1}
|
||||
_, err := DownloadAndPackageWebsite(opts)
|
||||
if err == nil {
|
||||
t.Fatal("Expected an error for a server error on the start URL, but got nil")
|
||||
}
|
||||
|
|
@ -80,7 +89,8 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
|
|||
}))
|
||||
defer server.Close()
|
||||
// We expect an error because the link is broken.
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
opts := DownloadOptions{URL: server.URL, MaxDepth: 1}
|
||||
dn, err := DownloadAndPackageWebsite(opts)
|
||||
if err == nil {
|
||||
t.Fatal("Expected an error for a broken link, but got nil")
|
||||
}
|
||||
|
|
@ -99,7 +109,8 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
defer server.Close()
|
||||
|
||||
bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1
|
||||
opts := DownloadOptions{URL: server.URL, MaxDepth: 1, ProgressBar: bar}
|
||||
dn, err := DownloadAndPackageWebsite(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
|
|
@ -122,7 +133,8 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
fmt.Fprint(w, `<a href="http://externalsite.com/page.html">External</a>`)
|
||||
}))
|
||||
defer server.Close()
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
opts := DownloadOptions{URL: server.URL, MaxDepth: 1}
|
||||
dn, err := DownloadAndPackageWebsite(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
|
|
@ -156,7 +168,8 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
// For now, we'll just test that it doesn't hang forever.
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
opts := DownloadOptions{URL: server.URL, MaxDepth: 1}
|
||||
_, err := DownloadAndPackageWebsite(opts)
|
||||
if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
|
||||
// We expect a timeout error, but other errors are failures.
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
|
|
@ -170,6 +183,48 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
t.Fatal("Test timed out")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Circuit Breaker", func(t *testing.T) {
|
||||
var requestCount int32
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
atomic.AddInt32(&requestCount, 1)
|
||||
}
|
||||
if r.URL.Path == "/" {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprint(w, `
|
||||
<a href="/fail1">Fail 1</a>
|
||||
<a href="/fail2">Fail 2</a>
|
||||
<a href="/fail3">Fail 3</a>
|
||||
`)
|
||||
} else {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
opts := DownloadOptions{
|
||||
URL: server.URL,
|
||||
MaxDepth: 1,
|
||||
EnableCircuitBreaker: true,
|
||||
CBSettings: circuitbreaker.Settings{
|
||||
FailureThreshold: 2,
|
||||
Cooldown: 1 * time.Second,
|
||||
},
|
||||
}
|
||||
_, err := DownloadAndPackageWebsite(opts)
|
||||
if err == nil {
|
||||
t.Fatal("Expected an error, but got nil")
|
||||
}
|
||||
|
||||
if count := atomic.LoadInt32(&requestCount); count != 2 {
|
||||
t.Errorf("Expected 2 requests to failing paths, but got %d", count)
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "circuit is open") {
|
||||
t.Errorf("Expected error to contain 'circuit is open', but got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue