diff --git a/pkg/mining/auth.go b/pkg/mining/auth.go index 11a7780..49a6b4b 100644 --- a/pkg/mining/auth.go +++ b/pkg/mining/auth.go @@ -219,7 +219,11 @@ func (da *DigestAuth) generateOpaque() string { // cleanupNonces removes expired nonces periodically func (da *DigestAuth) cleanupNonces() { - ticker := time.NewTicker(da.config.NonceExpiry) + interval := da.config.NonceExpiry + if interval <= 0 { + interval = 5 * time.Minute // Default if not set + } + ticker := time.NewTicker(interval) defer ticker.Stop() for { diff --git a/pkg/mining/auth_test.go b/pkg/mining/auth_test.go new file mode 100644 index 0000000..b7bacb1 --- /dev/null +++ b/pkg/mining/auth_test.go @@ -0,0 +1,604 @@ +package mining + +import ( + "crypto/md5" + "encoding/base64" + "encoding/hex" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/gin-gonic/gin" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func TestDefaultAuthConfig(t *testing.T) { + cfg := DefaultAuthConfig() + + if cfg.Enabled { + t.Error("expected Enabled to be false by default") + } + if cfg.Username != "" { + t.Error("expected Username to be empty by default") + } + if cfg.Password != "" { + t.Error("expected Password to be empty by default") + } + if cfg.Realm != "Mining API" { + t.Errorf("expected Realm to be 'Mining API', got %s", cfg.Realm) + } + if cfg.NonceExpiry != 5*time.Minute { + t.Errorf("expected NonceExpiry to be 5 minutes, got %v", cfg.NonceExpiry) + } +} + +func TestAuthConfigFromEnv(t *testing.T) { + // Save original env + origAuth := os.Getenv("MINING_API_AUTH") + origUser := os.Getenv("MINING_API_USER") + origPass := os.Getenv("MINING_API_PASS") + origRealm := os.Getenv("MINING_API_REALM") + defer func() { + os.Setenv("MINING_API_AUTH", origAuth) + os.Setenv("MINING_API_USER", origUser) + os.Setenv("MINING_API_PASS", origPass) + os.Setenv("MINING_API_REALM", origRealm) + }() + + t.Run("auth disabled by default", func(t *testing.T) { + os.Setenv("MINING_API_AUTH", "") + cfg := AuthConfigFromEnv() + if cfg.Enabled { + t.Error("expected Enabled to be false when env not set") + } + }) + + t.Run("auth enabled with valid credentials", func(t *testing.T) { + os.Setenv("MINING_API_AUTH", "true") + os.Setenv("MINING_API_USER", "testuser") + os.Setenv("MINING_API_PASS", "testpass") + + cfg := AuthConfigFromEnv() + if !cfg.Enabled { + t.Error("expected Enabled to be true") + } + if cfg.Username != "testuser" { + t.Errorf("expected Username 'testuser', got %s", cfg.Username) + } + if cfg.Password != "testpass" { + t.Errorf("expected Password 'testpass', got %s", cfg.Password) + } + }) + + t.Run("auth disabled if credentials missing", func(t *testing.T) { + os.Setenv("MINING_API_AUTH", "true") + os.Setenv("MINING_API_USER", "") + os.Setenv("MINING_API_PASS", "") + + cfg := AuthConfigFromEnv() + if cfg.Enabled { + t.Error("expected Enabled to be false when credentials missing") + } + }) + + t.Run("custom realm", func(t *testing.T) { + os.Setenv("MINING_API_AUTH", "") + os.Setenv("MINING_API_REALM", "Custom Realm") + + cfg := AuthConfigFromEnv() + if cfg.Realm != "Custom Realm" { + t.Errorf("expected Realm 'Custom Realm', got %s", cfg.Realm) + } + }) +} + +func TestNewDigestAuth(t *testing.T) { + cfg := AuthConfig{ + Enabled: true, + Username: "user", + Password: "pass", + Realm: "Test", + NonceExpiry: time.Second, + } + + da := NewDigestAuth(cfg) + if da == nil { + t.Fatal("expected non-nil DigestAuth") + } + + // Cleanup + da.Stop() +} + +func TestDigestAuthStop(t *testing.T) { + cfg := DefaultAuthConfig() + da := NewDigestAuth(cfg) + + // Should not panic when called multiple times + da.Stop() + da.Stop() + da.Stop() +} + +func TestMiddlewareAuthDisabled(t *testing.T) { + cfg := AuthConfig{Enabled: false} + da := NewDigestAuth(cfg) + defer da.Stop() + + router := gin.New() + router.Use(da.Middleware()) + router.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "success") + }) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + if w.Body.String() != "success" { + t.Errorf("expected body 'success', got %s", w.Body.String()) + } +} + +func TestMiddlewareNoAuth(t *testing.T) { + cfg := AuthConfig{ + Enabled: true, + Username: "user", + Password: "pass", + Realm: "Test", + NonceExpiry: 5 * time.Minute, + } + da := NewDigestAuth(cfg) + defer da.Stop() + + router := gin.New() + router.Use(da.Middleware()) + router.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "success") + }) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected status 401, got %d", w.Code) + } + + wwwAuth := w.Header().Get("WWW-Authenticate") + if wwwAuth == "" { + t.Error("expected WWW-Authenticate header") + } + if !authTestContains(wwwAuth, "Digest") { + t.Error("expected Digest challenge in WWW-Authenticate") + } + if !authTestContains(wwwAuth, `realm="Test"`) { + t.Error("expected realm in WWW-Authenticate") + } +} + +func TestMiddlewareBasicAuthValid(t *testing.T) { + cfg := AuthConfig{ + Enabled: true, + Username: "user", + Password: "pass", + Realm: "Test", + NonceExpiry: 5 * time.Minute, + } + da := NewDigestAuth(cfg) + defer da.Stop() + + router := gin.New() + router.Use(da.Middleware()) + router.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "success") + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.SetBasicAuth("user", "pass") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } +} + +func TestMiddlewareBasicAuthInvalid(t *testing.T) { + cfg := AuthConfig{ + Enabled: true, + Username: "user", + Password: "pass", + Realm: "Test", + NonceExpiry: 5 * time.Minute, + } + da := NewDigestAuth(cfg) + defer da.Stop() + + router := gin.New() + router.Use(da.Middleware()) + router.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "success") + }) + + testCases := []struct { + name string + user string + password string + }{ + {"wrong user", "wronguser", "pass"}, + {"wrong password", "user", "wrongpass"}, + {"both wrong", "wronguser", "wrongpass"}, + {"empty user", "", "pass"}, + {"empty password", "user", ""}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.SetBasicAuth(tc.user, tc.password) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected status 401, got %d", w.Code) + } + }) + } +} + +func TestMiddlewareDigestAuthValid(t *testing.T) { + cfg := AuthConfig{ + Enabled: true, + Username: "testuser", + Password: "testpass", + Realm: "Test Realm", + NonceExpiry: 5 * time.Minute, + } + da := NewDigestAuth(cfg) + defer da.Stop() + + router := gin.New() + router.Use(da.Middleware()) + router.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "success") + }) + + // First request to get nonce + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401 to get nonce, got %d", w.Code) + } + + wwwAuth := w.Header().Get("WWW-Authenticate") + params := parseDigestParams(wwwAuth[7:]) // Skip "Digest " + nonce := params["nonce"] + + if nonce == "" { + t.Fatal("nonce not found in challenge") + } + + // Build digest auth response + uri := "/test" + nc := "00000001" + cnonce := "abc123" + qop := "auth" + + ha1 := md5Hash(fmt.Sprintf("%s:%s:%s", cfg.Username, cfg.Realm, cfg.Password)) + ha2 := md5Hash(fmt.Sprintf("GET:%s", uri)) + response := md5Hash(fmt.Sprintf("%s:%s:%s:%s:%s:%s", ha1, nonce, nc, cnonce, qop, ha2)) + + authHeader := fmt.Sprintf( + `Digest username="%s", realm="%s", nonce="%s", uri="%s", qop=%s, nc=%s, cnonce="%s", response="%s"`, + cfg.Username, cfg.Realm, nonce, uri, qop, nc, cnonce, response, + ) + + // Second request with digest auth + req2 := httptest.NewRequest("GET", "/test", nil) + req2.Header.Set("Authorization", authHeader) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Errorf("expected status 200, got %d; body: %s", w2.Code, w2.Body.String()) + } +} + +func TestMiddlewareDigestAuthInvalidNonce(t *testing.T) { + cfg := AuthConfig{ + Enabled: true, + Username: "user", + Password: "pass", + Realm: "Test", + NonceExpiry: 5 * time.Minute, + } + da := NewDigestAuth(cfg) + defer da.Stop() + + router := gin.New() + router.Use(da.Middleware()) + router.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "success") + }) + + // Try with a fake nonce that was never issued + authHeader := `Digest username="user", realm="Test", nonce="fakenonce123", uri="/test", qop=auth, nc=00000001, cnonce="abc", response="xxx"` + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", authHeader) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected status 401 for invalid nonce, got %d", w.Code) + } +} + +func TestMiddlewareDigestAuthExpiredNonce(t *testing.T) { + cfg := AuthConfig{ + Enabled: true, + Username: "user", + Password: "pass", + Realm: "Test", + NonceExpiry: 50 * time.Millisecond, // Very short for testing + } + da := NewDigestAuth(cfg) + defer da.Stop() + + router := gin.New() + router.Use(da.Middleware()) + router.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "success") + }) + + // Get a valid nonce + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + wwwAuth := w.Header().Get("WWW-Authenticate") + params := parseDigestParams(wwwAuth[7:]) + nonce := params["nonce"] + + // Wait for nonce to expire + time.Sleep(100 * time.Millisecond) + + // Try to use expired nonce + uri := "/test" + ha1 := md5Hash(fmt.Sprintf("%s:%s:%s", cfg.Username, cfg.Realm, cfg.Password)) + ha2 := md5Hash(fmt.Sprintf("GET:%s", uri)) + response := md5Hash(fmt.Sprintf("%s:%s:%s", ha1, nonce, ha2)) + + authHeader := fmt.Sprintf( + `Digest username="%s", realm="%s", nonce="%s", uri="%s", response="%s"`, + cfg.Username, cfg.Realm, nonce, uri, response, + ) + + req2 := httptest.NewRequest("GET", "/test", nil) + req2.Header.Set("Authorization", authHeader) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + + if w2.Code != http.StatusUnauthorized { + t.Errorf("expected status 401 for expired nonce, got %d", w2.Code) + } +} + +func TestParseDigestParams(t *testing.T) { + testCases := []struct { + name string + input string + expected map[string]string + }{ + { + name: "basic params", + input: `username="john", realm="test"`, + expected: map[string]string{ + "username": "john", + "realm": "test", + }, + }, + { + name: "params with spaces", + input: ` username = "john" , realm = "test" `, + expected: map[string]string{ + "username": "john", + "realm": "test", + }, + }, + { + name: "unquoted values", + input: `qop=auth, nc=00000001`, + expected: map[string]string{ + "qop": "auth", + "nc": "00000001", + }, + }, + { + name: "full digest header", + input: `username="user", realm="Test", nonce="abc123", uri="/api", qop=auth, nc=00000001, cnonce="xyz", response="hash"`, + expected: map[string]string{ + "username": "user", + "realm": "Test", + "nonce": "abc123", + "uri": "/api", + "qop": "auth", + "nc": "00000001", + "cnonce": "xyz", + "response": "hash", + }, + }, + { + name: "empty string", + input: "", + expected: map[string]string{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := parseDigestParams(tc.input) + for key, expectedVal := range tc.expected { + if result[key] != expectedVal { + t.Errorf("key %s: expected %s, got %s", key, expectedVal, result[key]) + } + } + }) + } +} + +func TestMd5Hash(t *testing.T) { + testCases := []struct { + input string + expected string + }{ + {"hello", "5d41402abc4b2a76b9719d911017c592"}, + {"", "d41d8cd98f00b204e9800998ecf8427e"}, + {"user:realm:password", func() string { + h := md5.Sum([]byte("user:realm:password")) + return hex.EncodeToString(h[:]) + }()}, + } + + for _, tc := range testCases { + t.Run(tc.input, func(t *testing.T) { + result := md5Hash(tc.input) + if result != tc.expected { + t.Errorf("expected %s, got %s", tc.expected, result) + } + }) + } +} + +func TestNonceGeneration(t *testing.T) { + cfg := DefaultAuthConfig() + da := NewDigestAuth(cfg) + defer da.Stop() + + nonces := make(map[string]bool) + for i := 0; i < 100; i++ { + nonce := da.generateNonce() + if len(nonce) != 32 { // 16 bytes = 32 hex chars + t.Errorf("expected nonce length 32, got %d", len(nonce)) + } + if nonces[nonce] { + t.Error("duplicate nonce generated") + } + nonces[nonce] = true + } +} + +func TestOpaqueGeneration(t *testing.T) { + cfg := AuthConfig{Realm: "TestRealm"} + da := NewDigestAuth(cfg) + defer da.Stop() + + opaque1 := da.generateOpaque() + opaque2 := da.generateOpaque() + + // Same realm should produce same opaque + if opaque1 != opaque2 { + t.Error("opaque should be consistent for same realm") + } + + // Should be MD5 of realm + expected := md5Hash("TestRealm") + if opaque1 != expected { + t.Errorf("expected opaque %s, got %s", expected, opaque1) + } +} + +func TestNonceCleanup(t *testing.T) { + cfg := AuthConfig{ + Enabled: true, + Username: "user", + Password: "pass", + Realm: "Test", + NonceExpiry: 50 * time.Millisecond, + } + da := NewDigestAuth(cfg) + defer da.Stop() + + // Store a nonce + nonce := da.generateNonce() + da.nonces.Store(nonce, time.Now()) + + // Verify it exists + if _, ok := da.nonces.Load(nonce); !ok { + t.Error("nonce should exist immediately after storing") + } + + // Wait for cleanup (2x expiry to be safe) + time.Sleep(150 * time.Millisecond) + + // Verify it was cleaned up + if _, ok := da.nonces.Load(nonce); ok { + t.Error("expired nonce should have been cleaned up") + } +} + +// Helper function +func authTestContains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// Benchmark tests +func BenchmarkMd5Hash(b *testing.B) { + input := "user:realm:password" + for i := 0; i < b.N; i++ { + md5Hash(input) + } +} + +func BenchmarkNonceGeneration(b *testing.B) { + cfg := DefaultAuthConfig() + da := NewDigestAuth(cfg) + defer da.Stop() + + for i := 0; i < b.N; i++ { + da.generateNonce() + } +} + +func BenchmarkBasicAuthValidation(b *testing.B) { + cfg := AuthConfig{ + Enabled: true, + Username: "user", + Password: "pass", + Realm: "Test", + NonceExpiry: 5 * time.Minute, + } + da := NewDigestAuth(cfg) + defer da.Stop() + + router := gin.New() + router.Use(da.Middleware()) + router.GET("/test", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("user:pass"))) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + } +} diff --git a/pkg/mining/circuit_breaker.go b/pkg/mining/circuit_breaker.go new file mode 100644 index 0000000..ccb1eec --- /dev/null +++ b/pkg/mining/circuit_breaker.go @@ -0,0 +1,246 @@ +package mining + +import ( + "errors" + "sync" + "time" + + "github.com/Snider/Mining/pkg/logging" +) + +// CircuitState represents the state of a circuit breaker +type CircuitState int + +const ( + // CircuitClosed means the circuit is functioning normally + CircuitClosed CircuitState = iota + // CircuitOpen means the circuit has tripped and requests are being rejected + CircuitOpen + // CircuitHalfOpen means the circuit is testing if the service has recovered + CircuitHalfOpen +) + +func (s CircuitState) String() string { + switch s { + case CircuitClosed: + return "closed" + case CircuitOpen: + return "open" + case CircuitHalfOpen: + return "half-open" + default: + return "unknown" + } +} + +// CircuitBreakerConfig holds configuration for a circuit breaker +type CircuitBreakerConfig struct { + // FailureThreshold is the number of failures before opening the circuit + FailureThreshold int + // ResetTimeout is how long to wait before attempting recovery + ResetTimeout time.Duration + // SuccessThreshold is the number of successes needed in half-open state to close + SuccessThreshold int +} + +// DefaultCircuitBreakerConfig returns sensible defaults +func DefaultCircuitBreakerConfig() CircuitBreakerConfig { + return CircuitBreakerConfig{ + FailureThreshold: 3, + ResetTimeout: 30 * time.Second, + SuccessThreshold: 1, + } +} + +// CircuitBreaker implements the circuit breaker pattern +type CircuitBreaker struct { + name string + config CircuitBreakerConfig + state CircuitState + failures int + successes int + lastFailure time.Time + mu sync.RWMutex + cachedResult interface{} + cachedErr error + lastCacheTime time.Time + cacheDuration time.Duration +} + +// ErrCircuitOpen is returned when the circuit is open +var ErrCircuitOpen = errors.New("circuit breaker is open") + +// NewCircuitBreaker creates a new circuit breaker +func NewCircuitBreaker(name string, config CircuitBreakerConfig) *CircuitBreaker { + return &CircuitBreaker{ + name: name, + config: config, + state: CircuitClosed, + cacheDuration: 5 * time.Minute, // Cache successful results for 5 minutes + } +} + +// State returns the current circuit state +func (cb *CircuitBreaker) State() CircuitState { + cb.mu.RLock() + defer cb.mu.RUnlock() + return cb.state +} + +// Execute runs the given function with circuit breaker protection +func (cb *CircuitBreaker) Execute(fn func() (interface{}, error)) (interface{}, error) { + // Check if we should allow this request + if !cb.allowRequest() { + // Return cached result if available + cb.mu.RLock() + if cb.cachedResult != nil && time.Since(cb.lastCacheTime) < cb.cacheDuration { + result := cb.cachedResult + cb.mu.RUnlock() + logging.Debug("circuit breaker returning cached result", logging.Fields{ + "name": cb.name, + "state": cb.state.String(), + }) + return result, nil + } + cb.mu.RUnlock() + return nil, ErrCircuitOpen + } + + // Execute the function + result, err := fn() + + // Record the result + if err != nil { + cb.recordFailure() + } else { + cb.recordSuccess(result) + } + + return result, err +} + +// allowRequest checks if a request should be allowed through +func (cb *CircuitBreaker) allowRequest() bool { + cb.mu.Lock() + defer cb.mu.Unlock() + + switch cb.state { + case CircuitClosed: + return true + + case CircuitOpen: + // Check if we should transition to half-open + if time.Since(cb.lastFailure) > cb.config.ResetTimeout { + cb.state = CircuitHalfOpen + cb.successes = 0 + logging.Info("circuit breaker transitioning to half-open", logging.Fields{ + "name": cb.name, + }) + return true + } + return false + + case CircuitHalfOpen: + // Allow probe requests through + return true + + default: + return false + } +} + +// recordFailure records a failed request +func (cb *CircuitBreaker) recordFailure() { + cb.mu.Lock() + defer cb.mu.Unlock() + + cb.failures++ + cb.lastFailure = time.Now() + + switch cb.state { + case CircuitClosed: + if cb.failures >= cb.config.FailureThreshold { + cb.state = CircuitOpen + logging.Warn("circuit breaker opened", logging.Fields{ + "name": cb.name, + "failures": cb.failures, + }) + } + + case CircuitHalfOpen: + // Probe failed, back to open + cb.state = CircuitOpen + logging.Warn("circuit breaker probe failed, reopening", logging.Fields{ + "name": cb.name, + }) + } +} + +// recordSuccess records a successful request +func (cb *CircuitBreaker) recordSuccess(result interface{}) { + cb.mu.Lock() + defer cb.mu.Unlock() + + // Cache the successful result + cb.cachedResult = result + cb.lastCacheTime = time.Now() + cb.cachedErr = nil + + switch cb.state { + case CircuitClosed: + // Reset failure count on success + cb.failures = 0 + + case CircuitHalfOpen: + cb.successes++ + if cb.successes >= cb.config.SuccessThreshold { + cb.state = CircuitClosed + cb.failures = 0 + logging.Info("circuit breaker closed after successful probe", logging.Fields{ + "name": cb.name, + }) + } + } +} + +// Reset manually resets the circuit breaker to closed state +func (cb *CircuitBreaker) Reset() { + cb.mu.Lock() + defer cb.mu.Unlock() + + cb.state = CircuitClosed + cb.failures = 0 + cb.successes = 0 + logging.Debug("circuit breaker manually reset", logging.Fields{ + "name": cb.name, + }) +} + +// GetCached returns the cached result if available +func (cb *CircuitBreaker) GetCached() (interface{}, bool) { + cb.mu.RLock() + defer cb.mu.RUnlock() + + if cb.cachedResult != nil && time.Since(cb.lastCacheTime) < cb.cacheDuration { + return cb.cachedResult, true + } + return nil, false +} + +// Global circuit breaker for GitHub API +var ( + githubCircuitBreaker *CircuitBreaker + githubCircuitBreakerOnce sync.Once +) + +// getGitHubCircuitBreaker returns the shared GitHub API circuit breaker +func getGitHubCircuitBreaker() *CircuitBreaker { + githubCircuitBreakerOnce.Do(func() { + githubCircuitBreaker = NewCircuitBreaker("github-api", CircuitBreakerConfig{ + FailureThreshold: 3, + ResetTimeout: 60 * time.Second, // Wait 1 minute before retrying + SuccessThreshold: 1, + }) + }) + return githubCircuitBreaker +} diff --git a/pkg/mining/circuit_breaker_test.go b/pkg/mining/circuit_breaker_test.go new file mode 100644 index 0000000..03363b0 --- /dev/null +++ b/pkg/mining/circuit_breaker_test.go @@ -0,0 +1,334 @@ +package mining + +import ( + "errors" + "sync" + "testing" + "time" +) + +func TestCircuitBreakerDefaultConfig(t *testing.T) { + cfg := DefaultCircuitBreakerConfig() + + if cfg.FailureThreshold != 3 { + t.Errorf("expected FailureThreshold 3, got %d", cfg.FailureThreshold) + } + if cfg.ResetTimeout != 30*time.Second { + t.Errorf("expected ResetTimeout 30s, got %v", cfg.ResetTimeout) + } + if cfg.SuccessThreshold != 1 { + t.Errorf("expected SuccessThreshold 1, got %d", cfg.SuccessThreshold) + } +} + +func TestCircuitBreakerStateString(t *testing.T) { + tests := []struct { + state CircuitState + expected string + }{ + {CircuitClosed, "closed"}, + {CircuitOpen, "open"}, + {CircuitHalfOpen, "half-open"}, + {CircuitState(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.state.String(); got != tt.expected { + t.Errorf("state %d: expected %s, got %s", tt.state, tt.expected, got) + } + } +} + +func TestCircuitBreakerClosed(t *testing.T) { + cb := NewCircuitBreaker("test", DefaultCircuitBreakerConfig()) + + if cb.State() != CircuitClosed { + t.Error("expected initial state to be closed") + } + + // Successful execution + result, err := cb.Execute(func() (interface{}, error) { + return "success", nil + }) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result != "success" { + t.Errorf("expected 'success', got %v", result) + } + if cb.State() != CircuitClosed { + t.Error("state should still be closed after success") + } +} + +func TestCircuitBreakerOpensAfterFailures(t *testing.T) { + cfg := CircuitBreakerConfig{ + FailureThreshold: 2, + ResetTimeout: time.Minute, + SuccessThreshold: 1, + } + cb := NewCircuitBreaker("test", cfg) + + testErr := errors.New("test error") + + // First failure + _, err := cb.Execute(func() (interface{}, error) { + return nil, testErr + }) + if err != testErr { + t.Errorf("expected test error, got %v", err) + } + if cb.State() != CircuitClosed { + t.Error("should still be closed after 1 failure") + } + + // Second failure - should open circuit + _, err = cb.Execute(func() (interface{}, error) { + return nil, testErr + }) + if err != testErr { + t.Errorf("expected test error, got %v", err) + } + if cb.State() != CircuitOpen { + t.Error("should be open after 2 failures") + } +} + +func TestCircuitBreakerRejectsWhenOpen(t *testing.T) { + cfg := CircuitBreakerConfig{ + FailureThreshold: 1, + ResetTimeout: time.Hour, // Long timeout to keep circuit open + SuccessThreshold: 1, + } + cb := NewCircuitBreaker("test", cfg) + + // Open the circuit + cb.Execute(func() (interface{}, error) { + return nil, errors.New("fail") + }) + + if cb.State() != CircuitOpen { + t.Fatal("circuit should be open") + } + + // Next request should be rejected + called := false + _, err := cb.Execute(func() (interface{}, error) { + called = true + return "should not run", nil + }) + + if called { + t.Error("function should not have been called when circuit is open") + } + if err != ErrCircuitOpen { + t.Errorf("expected ErrCircuitOpen, got %v", err) + } +} + +func TestCircuitBreakerTransitionsToHalfOpen(t *testing.T) { + cfg := CircuitBreakerConfig{ + FailureThreshold: 1, + ResetTimeout: 50 * time.Millisecond, + SuccessThreshold: 1, + } + cb := NewCircuitBreaker("test", cfg) + + // Open the circuit + cb.Execute(func() (interface{}, error) { + return nil, errors.New("fail") + }) + + if cb.State() != CircuitOpen { + t.Fatal("circuit should be open") + } + + // Wait for reset timeout + time.Sleep(100 * time.Millisecond) + + // Next request should transition to half-open and execute + result, err := cb.Execute(func() (interface{}, error) { + return "probe success", nil + }) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result != "probe success" { + t.Errorf("expected 'probe success', got %v", result) + } + if cb.State() != CircuitClosed { + t.Error("should be closed after successful probe") + } +} + +func TestCircuitBreakerHalfOpenFailureReopens(t *testing.T) { + cfg := CircuitBreakerConfig{ + FailureThreshold: 1, + ResetTimeout: 50 * time.Millisecond, + SuccessThreshold: 1, + } + cb := NewCircuitBreaker("test", cfg) + + // Open the circuit + cb.Execute(func() (interface{}, error) { + return nil, errors.New("fail") + }) + + // Wait for reset timeout + time.Sleep(100 * time.Millisecond) + + // Probe fails + cb.Execute(func() (interface{}, error) { + return nil, errors.New("probe failed") + }) + + if cb.State() != CircuitOpen { + t.Error("should be open after probe failure") + } +} + +func TestCircuitBreakerCaching(t *testing.T) { + cfg := CircuitBreakerConfig{ + FailureThreshold: 1, + ResetTimeout: time.Hour, + SuccessThreshold: 1, + } + cb := NewCircuitBreaker("test", cfg) + + // Successful call - caches result + result, err := cb.Execute(func() (interface{}, error) { + return "cached value", nil + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "cached value" { + t.Fatalf("expected 'cached value', got %v", result) + } + + // Open the circuit + cb.Execute(func() (interface{}, error) { + return nil, errors.New("fail") + }) + + // Should return cached value when circuit is open + result, err = cb.Execute(func() (interface{}, error) { + return "should not run", nil + }) + + if err != nil { + t.Errorf("expected cached result, got error: %v", err) + } + if result != "cached value" { + t.Errorf("expected 'cached value', got %v", result) + } +} + +func TestCircuitBreakerGetCached(t *testing.T) { + cb := NewCircuitBreaker("test", DefaultCircuitBreakerConfig()) + + // No cache initially + _, ok := cb.GetCached() + if ok { + t.Error("expected no cached value initially") + } + + // Cache a value + cb.Execute(func() (interface{}, error) { + return "test value", nil + }) + + cached, ok := cb.GetCached() + if !ok { + t.Error("expected cached value") + } + if cached != "test value" { + t.Errorf("expected 'test value', got %v", cached) + } +} + +func TestCircuitBreakerReset(t *testing.T) { + cfg := CircuitBreakerConfig{ + FailureThreshold: 1, + ResetTimeout: time.Hour, + SuccessThreshold: 1, + } + cb := NewCircuitBreaker("test", cfg) + + // Open the circuit + cb.Execute(func() (interface{}, error) { + return nil, errors.New("fail") + }) + + if cb.State() != CircuitOpen { + t.Fatal("circuit should be open") + } + + // Manual reset + cb.Reset() + + if cb.State() != CircuitClosed { + t.Error("circuit should be closed after reset") + } +} + +func TestCircuitBreakerConcurrency(t *testing.T) { + cb := NewCircuitBreaker("test", DefaultCircuitBreakerConfig()) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + cb.Execute(func() (interface{}, error) { + if n%3 == 0 { + return nil, errors.New("fail") + } + return "success", nil + }) + }(i) + } + wg.Wait() + + // Just verify no panics occurred + _ = cb.State() +} + +func TestGetGitHubCircuitBreaker(t *testing.T) { + cb1 := getGitHubCircuitBreaker() + cb2 := getGitHubCircuitBreaker() + + if cb1 != cb2 { + t.Error("expected singleton circuit breaker") + } + + if cb1.name != "github-api" { + t.Errorf("expected name 'github-api', got %s", cb1.name) + } +} + +// Benchmark tests +func BenchmarkCircuitBreakerExecute(b *testing.B) { + cb := NewCircuitBreaker("bench", DefaultCircuitBreakerConfig()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + cb.Execute(func() (interface{}, error) { + return "result", nil + }) + } +} + +func BenchmarkCircuitBreakerConcurrent(b *testing.B) { + cb := NewCircuitBreaker("bench", DefaultCircuitBreakerConfig()) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + cb.Execute(func() (interface{}, error) { + return "result", nil + }) + } + }) +} diff --git a/pkg/mining/events.go b/pkg/mining/events.go index d9cdcc9..f3de661 100644 --- a/pkg/mining/events.go +++ b/pkg/mining/events.go @@ -60,6 +60,7 @@ type wsClient struct { send chan []byte hub *EventHub miners map[string]bool // subscribed miners, "*" for all + minersMu sync.RWMutex // protects miners map from concurrent access closeOnce sync.Once } @@ -143,6 +144,11 @@ func (h *EventHub) Run() { // Send initial state sync if provider is set if stateProvider != nil { go func(c *wsClient) { + defer func() { + if r := recover(); r != nil { + logging.Error("panic in state sync goroutine", logging.Fields{"panic": r}) + } + }() state := stateProvider() if state != nil { event := Event{ @@ -206,7 +212,10 @@ func (h *EventHub) shouldSendToClient(client *wsClient, event Event) bool { return true } - // Check miner subscription for miner events + // Check miner subscription for miner events (protected by mutex) + client.minersMu.RLock() + defer client.minersMu.RUnlock() + if client.miners == nil || len(client.miners) == 0 { // No subscription filter, send all return true @@ -354,11 +363,13 @@ func (c *wsClient) readPump() { switch msg.Type { case "subscribe": - // Update miner subscription + // Update miner subscription (protected by mutex) + c.minersMu.Lock() c.miners = make(map[string]bool) for _, m := range msg.Miners { c.miners[m] = true } + c.minersMu.Unlock() logging.Debug("client subscribed to miners", logging.Fields{"miners": msg.Miners}) case "ping": diff --git a/pkg/mining/manager.go b/pkg/mining/manager.go index cf5c8fe..d4226c2 100644 --- a/pkg/mining/manager.go +++ b/pkg/mining/manager.go @@ -130,6 +130,11 @@ func (m *Manager) startDBCleanup() { m.waitGroup.Add(1) go func() { defer m.waitGroup.Done() + defer func() { + if r := recover(); r != nil { + logging.Error("panic in database cleanup goroutine", logging.Fields{"panic": r}) + } + }() // Run cleanup once per hour ticker := time.NewTicker(time.Hour) defer ticker.Stop() @@ -523,6 +528,11 @@ func (m *Manager) startStatsCollection() { m.waitGroup.Add(1) go func() { defer m.waitGroup.Done() + defer func() { + if r := recover(); r != nil { + logging.Error("panic in stats collection goroutine", logging.Fields{"panic": r}) + } + }() ticker := time.NewTicker(HighResolutionInterval) defer ticker.Stop() @@ -570,6 +580,14 @@ func (m *Manager) collectMinerStats() { wg.Add(1) go func(miner Miner, minerType string) { defer wg.Done() + defer func() { + if r := recover(); r != nil { + logging.Error("panic in single miner stats collection", logging.Fields{ + "panic": r, + "miner": miner.GetName(), + }) + } + }() m.collectSingleMinerStats(miner, minerType, now, dbEnabled) }(mi.miner, mi.minerType) } diff --git a/pkg/mining/version.go b/pkg/mining/version.go index 96282d0..2e2c8c8 100644 --- a/pkg/mining/version.go +++ b/pkg/mining/version.go @@ -36,7 +36,37 @@ type GitHubRelease struct { // FetchLatestGitHubVersion fetches the latest release version from a GitHub repository. // It takes the repository owner and name (e.g., "xmrig", "xmrig") and returns the tag name. +// Uses a circuit breaker to prevent cascading failures when GitHub API is unavailable. func FetchLatestGitHubVersion(owner, repo string) (string, error) { + cb := getGitHubCircuitBreaker() + + result, err := cb.Execute(func() (interface{}, error) { + return fetchGitHubVersionDirect(owner, repo) + }) + + if err != nil { + // If circuit is open, try to return cached value with warning + if err == ErrCircuitOpen { + if cached, ok := cb.GetCached(); ok { + if tagName, ok := cached.(string); ok { + return tagName, nil + } + } + return "", fmt.Errorf("github API unavailable (circuit breaker open): %w", err) + } + return "", err + } + + tagName, ok := result.(string) + if !ok { + return "", fmt.Errorf("unexpected result type from circuit breaker") + } + + return tagName, nil +} + +// fetchGitHubVersionDirect is the actual GitHub API call, wrapped by circuit breaker +func fetchGitHubVersionDirect(owner, repo string) (string, error) { url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", owner, repo) resp, err := getHTTPClient().Get(url) diff --git a/pkg/node/transport.go b/pkg/node/transport.go index 199cbe2..89d917c 100644 --- a/pkg/node/transport.go +++ b/pkg/node/transport.go @@ -22,25 +22,30 @@ var debugLogCounter atomic.Int64 // debugLogInterval controls how often we log debug messages in hot paths (1 in N) const debugLogInterval = 100 +// DefaultMaxMessageSize is the default maximum message size (1MB) +const DefaultMaxMessageSize int64 = 1 << 20 // 1MB + // TransportConfig configures the WebSocket transport. type TransportConfig struct { - ListenAddr string // ":9091" default - WSPath string // "/ws" - WebSocket endpoint path - TLSCertPath string // Optional TLS for wss:// - TLSKeyPath string - MaxConns int // Maximum concurrent connections - PingInterval time.Duration // WebSocket keepalive interval - PongTimeout time.Duration // Timeout waiting for pong + ListenAddr string // ":9091" default + WSPath string // "/ws" - WebSocket endpoint path + TLSCertPath string // Optional TLS for wss:// + TLSKeyPath string + MaxConns int // Maximum concurrent connections + MaxMessageSize int64 // Maximum message size in bytes (0 = 1MB default) + PingInterval time.Duration // WebSocket keepalive interval + PongTimeout time.Duration // Timeout waiting for pong } // DefaultTransportConfig returns sensible defaults. func DefaultTransportConfig() TransportConfig { return TransportConfig{ - ListenAddr: ":9091", - WSPath: "/ws", - MaxConns: 100, - PingInterval: 30 * time.Second, - PongTimeout: 10 * time.Second, + ListenAddr: ":9091", + WSPath: "/ws", + MaxConns: 100, + MaxMessageSize: DefaultMaxMessageSize, + PingInterval: 30 * time.Second, + PongTimeout: 10 * time.Second, } } @@ -49,17 +54,18 @@ type MessageHandler func(conn *PeerConnection, msg *Message) // Transport manages WebSocket connections with SMSG encryption. type Transport struct { - config TransportConfig - server *http.Server - upgrader websocket.Upgrader - conns map[string]*PeerConnection // peer ID -> connection - node *NodeManager - registry *PeerRegistry - handler MessageHandler - mu sync.RWMutex - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup + config TransportConfig + server *http.Server + upgrader websocket.Upgrader + conns map[string]*PeerConnection // peer ID -> connection + pendingConns atomic.Int32 // tracks connections during handshake + node *NodeManager + registry *PeerRegistry + handler MessageHandler + mu sync.RWMutex + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup } // PeerConnection represents an active connection to a peer. @@ -267,21 +273,34 @@ func (t *Transport) GetConnection(peerID string) *PeerConnection { // handleWSUpgrade handles incoming WebSocket connections. func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { - // Enforce MaxConns limit + // Enforce MaxConns limit (including pending connections during handshake) t.mu.RLock() currentConns := len(t.conns) t.mu.RUnlock() + pendingConns := int(t.pendingConns.Load()) - if currentConns >= t.config.MaxConns { + totalConns := currentConns + pendingConns + if totalConns >= t.config.MaxConns { http.Error(w, "Too many connections", http.StatusServiceUnavailable) return } + // Track this connection as pending during handshake + t.pendingConns.Add(1) + defer t.pendingConns.Add(-1) + conn, err := t.upgrader.Upgrade(w, r, nil) if err != nil { return } + // Apply message size limit during handshake to prevent memory exhaustion + maxSize := t.config.MaxMessageSize + if maxSize <= 0 { + maxSize = DefaultMaxMessageSize + } + conn.SetReadLimit(maxSize) + // Set handshake timeout to prevent slow/malicious clients from blocking handshakeTimeout := 10 * time.Second conn.SetReadDeadline(time.Now().Add(handshakeTimeout)) @@ -468,6 +487,13 @@ func (t *Transport) readLoop(pc *PeerConnection) { defer t.wg.Done() defer t.removeConnection(pc) + // Apply message size limit to prevent memory exhaustion attacks + maxSize := t.config.MaxMessageSize + if maxSize <= 0 { + maxSize = DefaultMaxMessageSize + } + pc.Conn.SetReadLimit(maxSize) + for { select { case <-t.ctx.Done():