fix: Implement 6 quick wins from 109-finding code review
CONC-HIGH-1: Add mutex to wsClient.miners map to prevent race condition P2P-CRIT-2: Add MaxMessageSize config (1MB default) to prevent memory exhaustion P2P-CRIT-3: Track pending connections during handshake to enforce connection limits RESIL-HIGH-1: Add recover() to 4 background goroutines to prevent service crashes TEST-CRIT-1: Create auth_test.go with 16 tests covering Basic/Digest auth RESIL-HIGH-3: Implement circuit breaker for GitHub API with caching fallback Also fixed: NonceExpiry validation in auth.go to prevent panic on zero interval 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
d7b38195ac
commit
87b426480b
8 changed files with 1301 additions and 28 deletions
|
|
@ -219,7 +219,11 @@ func (da *DigestAuth) generateOpaque() string {
|
|||
|
||||
// cleanupNonces removes expired nonces periodically
|
||||
func (da *DigestAuth) cleanupNonces() {
|
||||
ticker := time.NewTicker(da.config.NonceExpiry)
|
||||
interval := da.config.NonceExpiry
|
||||
if interval <= 0 {
|
||||
interval = 5 * time.Minute // Default if not set
|
||||
}
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
|
|
|
|||
604
pkg/mining/auth_test.go
Normal file
604
pkg/mining/auth_test.go
Normal file
|
|
@ -0,0 +1,604 @@
|
|||
package mining
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func TestDefaultAuthConfig(t *testing.T) {
|
||||
cfg := DefaultAuthConfig()
|
||||
|
||||
if cfg.Enabled {
|
||||
t.Error("expected Enabled to be false by default")
|
||||
}
|
||||
if cfg.Username != "" {
|
||||
t.Error("expected Username to be empty by default")
|
||||
}
|
||||
if cfg.Password != "" {
|
||||
t.Error("expected Password to be empty by default")
|
||||
}
|
||||
if cfg.Realm != "Mining API" {
|
||||
t.Errorf("expected Realm to be 'Mining API', got %s", cfg.Realm)
|
||||
}
|
||||
if cfg.NonceExpiry != 5*time.Minute {
|
||||
t.Errorf("expected NonceExpiry to be 5 minutes, got %v", cfg.NonceExpiry)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthConfigFromEnv(t *testing.T) {
|
||||
// Save original env
|
||||
origAuth := os.Getenv("MINING_API_AUTH")
|
||||
origUser := os.Getenv("MINING_API_USER")
|
||||
origPass := os.Getenv("MINING_API_PASS")
|
||||
origRealm := os.Getenv("MINING_API_REALM")
|
||||
defer func() {
|
||||
os.Setenv("MINING_API_AUTH", origAuth)
|
||||
os.Setenv("MINING_API_USER", origUser)
|
||||
os.Setenv("MINING_API_PASS", origPass)
|
||||
os.Setenv("MINING_API_REALM", origRealm)
|
||||
}()
|
||||
|
||||
t.Run("auth disabled by default", func(t *testing.T) {
|
||||
os.Setenv("MINING_API_AUTH", "")
|
||||
cfg := AuthConfigFromEnv()
|
||||
if cfg.Enabled {
|
||||
t.Error("expected Enabled to be false when env not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("auth enabled with valid credentials", func(t *testing.T) {
|
||||
os.Setenv("MINING_API_AUTH", "true")
|
||||
os.Setenv("MINING_API_USER", "testuser")
|
||||
os.Setenv("MINING_API_PASS", "testpass")
|
||||
|
||||
cfg := AuthConfigFromEnv()
|
||||
if !cfg.Enabled {
|
||||
t.Error("expected Enabled to be true")
|
||||
}
|
||||
if cfg.Username != "testuser" {
|
||||
t.Errorf("expected Username 'testuser', got %s", cfg.Username)
|
||||
}
|
||||
if cfg.Password != "testpass" {
|
||||
t.Errorf("expected Password 'testpass', got %s", cfg.Password)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("auth disabled if credentials missing", func(t *testing.T) {
|
||||
os.Setenv("MINING_API_AUTH", "true")
|
||||
os.Setenv("MINING_API_USER", "")
|
||||
os.Setenv("MINING_API_PASS", "")
|
||||
|
||||
cfg := AuthConfigFromEnv()
|
||||
if cfg.Enabled {
|
||||
t.Error("expected Enabled to be false when credentials missing")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom realm", func(t *testing.T) {
|
||||
os.Setenv("MINING_API_AUTH", "")
|
||||
os.Setenv("MINING_API_REALM", "Custom Realm")
|
||||
|
||||
cfg := AuthConfigFromEnv()
|
||||
if cfg.Realm != "Custom Realm" {
|
||||
t.Errorf("expected Realm 'Custom Realm', got %s", cfg.Realm)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewDigestAuth(t *testing.T) {
|
||||
cfg := AuthConfig{
|
||||
Enabled: true,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Realm: "Test",
|
||||
NonceExpiry: time.Second,
|
||||
}
|
||||
|
||||
da := NewDigestAuth(cfg)
|
||||
if da == nil {
|
||||
t.Fatal("expected non-nil DigestAuth")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
da.Stop()
|
||||
}
|
||||
|
||||
func TestDigestAuthStop(t *testing.T) {
|
||||
cfg := DefaultAuthConfig()
|
||||
da := NewDigestAuth(cfg)
|
||||
|
||||
// Should not panic when called multiple times
|
||||
da.Stop()
|
||||
da.Stop()
|
||||
da.Stop()
|
||||
}
|
||||
|
||||
func TestMiddlewareAuthDisabled(t *testing.T) {
|
||||
cfg := AuthConfig{Enabled: false}
|
||||
da := NewDigestAuth(cfg)
|
||||
defer da.Stop()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(da.Middleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
if w.Body.String() != "success" {
|
||||
t.Errorf("expected body 'success', got %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareNoAuth(t *testing.T) {
|
||||
cfg := AuthConfig{
|
||||
Enabled: true,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Realm: "Test",
|
||||
NonceExpiry: 5 * time.Minute,
|
||||
}
|
||||
da := NewDigestAuth(cfg)
|
||||
defer da.Stop()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(da.Middleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected status 401, got %d", w.Code)
|
||||
}
|
||||
|
||||
wwwAuth := w.Header().Get("WWW-Authenticate")
|
||||
if wwwAuth == "" {
|
||||
t.Error("expected WWW-Authenticate header")
|
||||
}
|
||||
if !authTestContains(wwwAuth, "Digest") {
|
||||
t.Error("expected Digest challenge in WWW-Authenticate")
|
||||
}
|
||||
if !authTestContains(wwwAuth, `realm="Test"`) {
|
||||
t.Error("expected realm in WWW-Authenticate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareBasicAuthValid(t *testing.T) {
|
||||
cfg := AuthConfig{
|
||||
Enabled: true,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Realm: "Test",
|
||||
NonceExpiry: 5 * time.Minute,
|
||||
}
|
||||
da := NewDigestAuth(cfg)
|
||||
defer da.Stop()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(da.Middleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.SetBasicAuth("user", "pass")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareBasicAuthInvalid(t *testing.T) {
|
||||
cfg := AuthConfig{
|
||||
Enabled: true,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Realm: "Test",
|
||||
NonceExpiry: 5 * time.Minute,
|
||||
}
|
||||
da := NewDigestAuth(cfg)
|
||||
defer da.Stop()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(da.Middleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
user string
|
||||
password string
|
||||
}{
|
||||
{"wrong user", "wronguser", "pass"},
|
||||
{"wrong password", "user", "wrongpass"},
|
||||
{"both wrong", "wronguser", "wrongpass"},
|
||||
{"empty user", "", "pass"},
|
||||
{"empty password", "user", ""},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.SetBasicAuth(tc.user, tc.password)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected status 401, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareDigestAuthValid(t *testing.T) {
|
||||
cfg := AuthConfig{
|
||||
Enabled: true,
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
Realm: "Test Realm",
|
||||
NonceExpiry: 5 * time.Minute,
|
||||
}
|
||||
da := NewDigestAuth(cfg)
|
||||
defer da.Stop()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(da.Middleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
})
|
||||
|
||||
// First request to get nonce
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401 to get nonce, got %d", w.Code)
|
||||
}
|
||||
|
||||
wwwAuth := w.Header().Get("WWW-Authenticate")
|
||||
params := parseDigestParams(wwwAuth[7:]) // Skip "Digest "
|
||||
nonce := params["nonce"]
|
||||
|
||||
if nonce == "" {
|
||||
t.Fatal("nonce not found in challenge")
|
||||
}
|
||||
|
||||
// Build digest auth response
|
||||
uri := "/test"
|
||||
nc := "00000001"
|
||||
cnonce := "abc123"
|
||||
qop := "auth"
|
||||
|
||||
ha1 := md5Hash(fmt.Sprintf("%s:%s:%s", cfg.Username, cfg.Realm, cfg.Password))
|
||||
ha2 := md5Hash(fmt.Sprintf("GET:%s", uri))
|
||||
response := md5Hash(fmt.Sprintf("%s:%s:%s:%s:%s:%s", ha1, nonce, nc, cnonce, qop, ha2))
|
||||
|
||||
authHeader := fmt.Sprintf(
|
||||
`Digest username="%s", realm="%s", nonce="%s", uri="%s", qop=%s, nc=%s, cnonce="%s", response="%s"`,
|
||||
cfg.Username, cfg.Realm, nonce, uri, qop, nc, cnonce, response,
|
||||
)
|
||||
|
||||
// Second request with digest auth
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.Header.Set("Authorization", authHeader)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d; body: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareDigestAuthInvalidNonce(t *testing.T) {
|
||||
cfg := AuthConfig{
|
||||
Enabled: true,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Realm: "Test",
|
||||
NonceExpiry: 5 * time.Minute,
|
||||
}
|
||||
da := NewDigestAuth(cfg)
|
||||
defer da.Stop()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(da.Middleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
})
|
||||
|
||||
// Try with a fake nonce that was never issued
|
||||
authHeader := `Digest username="user", realm="Test", nonce="fakenonce123", uri="/test", qop=auth, nc=00000001, cnonce="abc", response="xxx"`
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", authHeader)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected status 401 for invalid nonce, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareDigestAuthExpiredNonce(t *testing.T) {
|
||||
cfg := AuthConfig{
|
||||
Enabled: true,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Realm: "Test",
|
||||
NonceExpiry: 50 * time.Millisecond, // Very short for testing
|
||||
}
|
||||
da := NewDigestAuth(cfg)
|
||||
defer da.Stop()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(da.Middleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
})
|
||||
|
||||
// Get a valid nonce
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
wwwAuth := w.Header().Get("WWW-Authenticate")
|
||||
params := parseDigestParams(wwwAuth[7:])
|
||||
nonce := params["nonce"]
|
||||
|
||||
// Wait for nonce to expire
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Try to use expired nonce
|
||||
uri := "/test"
|
||||
ha1 := md5Hash(fmt.Sprintf("%s:%s:%s", cfg.Username, cfg.Realm, cfg.Password))
|
||||
ha2 := md5Hash(fmt.Sprintf("GET:%s", uri))
|
||||
response := md5Hash(fmt.Sprintf("%s:%s:%s", ha1, nonce, ha2))
|
||||
|
||||
authHeader := fmt.Sprintf(
|
||||
`Digest username="%s", realm="%s", nonce="%s", uri="%s", response="%s"`,
|
||||
cfg.Username, cfg.Realm, nonce, uri, response,
|
||||
)
|
||||
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.Header.Set("Authorization", authHeader)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
if w2.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected status 401 for expired nonce, got %d", w2.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDigestParams(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected map[string]string
|
||||
}{
|
||||
{
|
||||
name: "basic params",
|
||||
input: `username="john", realm="test"`,
|
||||
expected: map[string]string{
|
||||
"username": "john",
|
||||
"realm": "test",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "params with spaces",
|
||||
input: ` username = "john" , realm = "test" `,
|
||||
expected: map[string]string{
|
||||
"username": "john",
|
||||
"realm": "test",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unquoted values",
|
||||
input: `qop=auth, nc=00000001`,
|
||||
expected: map[string]string{
|
||||
"qop": "auth",
|
||||
"nc": "00000001",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "full digest header",
|
||||
input: `username="user", realm="Test", nonce="abc123", uri="/api", qop=auth, nc=00000001, cnonce="xyz", response="hash"`,
|
||||
expected: map[string]string{
|
||||
"username": "user",
|
||||
"realm": "Test",
|
||||
"nonce": "abc123",
|
||||
"uri": "/api",
|
||||
"qop": "auth",
|
||||
"nc": "00000001",
|
||||
"cnonce": "xyz",
|
||||
"response": "hash",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := parseDigestParams(tc.input)
|
||||
for key, expectedVal := range tc.expected {
|
||||
if result[key] != expectedVal {
|
||||
t.Errorf("key %s: expected %s, got %s", key, expectedVal, result[key])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMd5Hash(t *testing.T) {
|
||||
testCases := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"hello", "5d41402abc4b2a76b9719d911017c592"},
|
||||
{"", "d41d8cd98f00b204e9800998ecf8427e"},
|
||||
{"user:realm:password", func() string {
|
||||
h := md5.Sum([]byte("user:realm:password"))
|
||||
return hex.EncodeToString(h[:])
|
||||
}()},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.input, func(t *testing.T) {
|
||||
result := md5Hash(tc.input)
|
||||
if result != tc.expected {
|
||||
t.Errorf("expected %s, got %s", tc.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNonceGeneration(t *testing.T) {
|
||||
cfg := DefaultAuthConfig()
|
||||
da := NewDigestAuth(cfg)
|
||||
defer da.Stop()
|
||||
|
||||
nonces := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
nonce := da.generateNonce()
|
||||
if len(nonce) != 32 { // 16 bytes = 32 hex chars
|
||||
t.Errorf("expected nonce length 32, got %d", len(nonce))
|
||||
}
|
||||
if nonces[nonce] {
|
||||
t.Error("duplicate nonce generated")
|
||||
}
|
||||
nonces[nonce] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpaqueGeneration(t *testing.T) {
|
||||
cfg := AuthConfig{Realm: "TestRealm"}
|
||||
da := NewDigestAuth(cfg)
|
||||
defer da.Stop()
|
||||
|
||||
opaque1 := da.generateOpaque()
|
||||
opaque2 := da.generateOpaque()
|
||||
|
||||
// Same realm should produce same opaque
|
||||
if opaque1 != opaque2 {
|
||||
t.Error("opaque should be consistent for same realm")
|
||||
}
|
||||
|
||||
// Should be MD5 of realm
|
||||
expected := md5Hash("TestRealm")
|
||||
if opaque1 != expected {
|
||||
t.Errorf("expected opaque %s, got %s", expected, opaque1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNonceCleanup(t *testing.T) {
|
||||
cfg := AuthConfig{
|
||||
Enabled: true,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Realm: "Test",
|
||||
NonceExpiry: 50 * time.Millisecond,
|
||||
}
|
||||
da := NewDigestAuth(cfg)
|
||||
defer da.Stop()
|
||||
|
||||
// Store a nonce
|
||||
nonce := da.generateNonce()
|
||||
da.nonces.Store(nonce, time.Now())
|
||||
|
||||
// Verify it exists
|
||||
if _, ok := da.nonces.Load(nonce); !ok {
|
||||
t.Error("nonce should exist immediately after storing")
|
||||
}
|
||||
|
||||
// Wait for cleanup (2x expiry to be safe)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Verify it was cleaned up
|
||||
if _, ok := da.nonces.Load(nonce); ok {
|
||||
t.Error("expired nonce should have been cleaned up")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function
|
||||
func authTestContains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkMd5Hash(b *testing.B) {
|
||||
input := "user:realm:password"
|
||||
for i := 0; i < b.N; i++ {
|
||||
md5Hash(input)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNonceGeneration(b *testing.B) {
|
||||
cfg := DefaultAuthConfig()
|
||||
da := NewDigestAuth(cfg)
|
||||
defer da.Stop()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
da.generateNonce()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBasicAuthValidation(b *testing.B) {
|
||||
cfg := AuthConfig{
|
||||
Enabled: true,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Realm: "Test",
|
||||
NonceExpiry: 5 * time.Minute,
|
||||
}
|
||||
da := NewDigestAuth(cfg)
|
||||
defer da.Stop()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(da.Middleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("user:pass")))
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
}
|
||||
}
|
||||
246
pkg/mining/circuit_breaker.go
Normal file
246
pkg/mining/circuit_breaker.go
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
package mining
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Snider/Mining/pkg/logging"
|
||||
)
|
||||
|
||||
// CircuitState represents the state of a circuit breaker
|
||||
type CircuitState int
|
||||
|
||||
const (
|
||||
// CircuitClosed means the circuit is functioning normally
|
||||
CircuitClosed CircuitState = iota
|
||||
// CircuitOpen means the circuit has tripped and requests are being rejected
|
||||
CircuitOpen
|
||||
// CircuitHalfOpen means the circuit is testing if the service has recovered
|
||||
CircuitHalfOpen
|
||||
)
|
||||
|
||||
func (s CircuitState) String() string {
|
||||
switch s {
|
||||
case CircuitClosed:
|
||||
return "closed"
|
||||
case CircuitOpen:
|
||||
return "open"
|
||||
case CircuitHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig holds configuration for a circuit breaker
|
||||
type CircuitBreakerConfig struct {
|
||||
// FailureThreshold is the number of failures before opening the circuit
|
||||
FailureThreshold int
|
||||
// ResetTimeout is how long to wait before attempting recovery
|
||||
ResetTimeout time.Duration
|
||||
// SuccessThreshold is the number of successes needed in half-open state to close
|
||||
SuccessThreshold int
|
||||
}
|
||||
|
||||
// DefaultCircuitBreakerConfig returns sensible defaults
|
||||
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
|
||||
return CircuitBreakerConfig{
|
||||
FailureThreshold: 3,
|
||||
ResetTimeout: 30 * time.Second,
|
||||
SuccessThreshold: 1,
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern
|
||||
type CircuitBreaker struct {
|
||||
name string
|
||||
config CircuitBreakerConfig
|
||||
state CircuitState
|
||||
failures int
|
||||
successes int
|
||||
lastFailure time.Time
|
||||
mu sync.RWMutex
|
||||
cachedResult interface{}
|
||||
cachedErr error
|
||||
lastCacheTime time.Time
|
||||
cacheDuration time.Duration
|
||||
}
|
||||
|
||||
// ErrCircuitOpen is returned when the circuit is open
|
||||
var ErrCircuitOpen = errors.New("circuit breaker is open")
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker
|
||||
func NewCircuitBreaker(name string, config CircuitBreakerConfig) *CircuitBreaker {
|
||||
return &CircuitBreaker{
|
||||
name: name,
|
||||
config: config,
|
||||
state: CircuitClosed,
|
||||
cacheDuration: 5 * time.Minute, // Cache successful results for 5 minutes
|
||||
}
|
||||
}
|
||||
|
||||
// State returns the current circuit state
|
||||
func (cb *CircuitBreaker) State() CircuitState {
|
||||
cb.mu.RLock()
|
||||
defer cb.mu.RUnlock()
|
||||
return cb.state
|
||||
}
|
||||
|
||||
// Execute runs the given function with circuit breaker protection
|
||||
func (cb *CircuitBreaker) Execute(fn func() (interface{}, error)) (interface{}, error) {
|
||||
// Check if we should allow this request
|
||||
if !cb.allowRequest() {
|
||||
// Return cached result if available
|
||||
cb.mu.RLock()
|
||||
if cb.cachedResult != nil && time.Since(cb.lastCacheTime) < cb.cacheDuration {
|
||||
result := cb.cachedResult
|
||||
cb.mu.RUnlock()
|
||||
logging.Debug("circuit breaker returning cached result", logging.Fields{
|
||||
"name": cb.name,
|
||||
"state": cb.state.String(),
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
cb.mu.RUnlock()
|
||||
return nil, ErrCircuitOpen
|
||||
}
|
||||
|
||||
// Execute the function
|
||||
result, err := fn()
|
||||
|
||||
// Record the result
|
||||
if err != nil {
|
||||
cb.recordFailure()
|
||||
} else {
|
||||
cb.recordSuccess(result)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
// allowRequest checks if a request should be allowed through
|
||||
func (cb *CircuitBreaker) allowRequest() bool {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitClosed:
|
||||
return true
|
||||
|
||||
case CircuitOpen:
|
||||
// Check if we should transition to half-open
|
||||
if time.Since(cb.lastFailure) > cb.config.ResetTimeout {
|
||||
cb.state = CircuitHalfOpen
|
||||
cb.successes = 0
|
||||
logging.Info("circuit breaker transitioning to half-open", logging.Fields{
|
||||
"name": cb.name,
|
||||
})
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
case CircuitHalfOpen:
|
||||
// Allow probe requests through
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failed request
|
||||
func (cb *CircuitBreaker) recordFailure() {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
cb.failures++
|
||||
cb.lastFailure = time.Now()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitClosed:
|
||||
if cb.failures >= cb.config.FailureThreshold {
|
||||
cb.state = CircuitOpen
|
||||
logging.Warn("circuit breaker opened", logging.Fields{
|
||||
"name": cb.name,
|
||||
"failures": cb.failures,
|
||||
})
|
||||
}
|
||||
|
||||
case CircuitHalfOpen:
|
||||
// Probe failed, back to open
|
||||
cb.state = CircuitOpen
|
||||
logging.Warn("circuit breaker probe failed, reopening", logging.Fields{
|
||||
"name": cb.name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// recordSuccess records a successful request
|
||||
func (cb *CircuitBreaker) recordSuccess(result interface{}) {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
// Cache the successful result
|
||||
cb.cachedResult = result
|
||||
cb.lastCacheTime = time.Now()
|
||||
cb.cachedErr = nil
|
||||
|
||||
switch cb.state {
|
||||
case CircuitClosed:
|
||||
// Reset failure count on success
|
||||
cb.failures = 0
|
||||
|
||||
case CircuitHalfOpen:
|
||||
cb.successes++
|
||||
if cb.successes >= cb.config.SuccessThreshold {
|
||||
cb.state = CircuitClosed
|
||||
cb.failures = 0
|
||||
logging.Info("circuit breaker closed after successful probe", logging.Fields{
|
||||
"name": cb.name,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset manually resets the circuit breaker to closed state
|
||||
func (cb *CircuitBreaker) Reset() {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
cb.state = CircuitClosed
|
||||
cb.failures = 0
|
||||
cb.successes = 0
|
||||
logging.Debug("circuit breaker manually reset", logging.Fields{
|
||||
"name": cb.name,
|
||||
})
|
||||
}
|
||||
|
||||
// GetCached returns the cached result if available
|
||||
func (cb *CircuitBreaker) GetCached() (interface{}, bool) {
|
||||
cb.mu.RLock()
|
||||
defer cb.mu.RUnlock()
|
||||
|
||||
if cb.cachedResult != nil && time.Since(cb.lastCacheTime) < cb.cacheDuration {
|
||||
return cb.cachedResult, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Global circuit breaker for GitHub API
|
||||
var (
|
||||
githubCircuitBreaker *CircuitBreaker
|
||||
githubCircuitBreakerOnce sync.Once
|
||||
)
|
||||
|
||||
// getGitHubCircuitBreaker returns the shared GitHub API circuit breaker
|
||||
func getGitHubCircuitBreaker() *CircuitBreaker {
|
||||
githubCircuitBreakerOnce.Do(func() {
|
||||
githubCircuitBreaker = NewCircuitBreaker("github-api", CircuitBreakerConfig{
|
||||
FailureThreshold: 3,
|
||||
ResetTimeout: 60 * time.Second, // Wait 1 minute before retrying
|
||||
SuccessThreshold: 1,
|
||||
})
|
||||
})
|
||||
return githubCircuitBreaker
|
||||
}
|
||||
334
pkg/mining/circuit_breaker_test.go
Normal file
334
pkg/mining/circuit_breaker_test.go
Normal file
|
|
@ -0,0 +1,334 @@
|
|||
package mining
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCircuitBreakerDefaultConfig(t *testing.T) {
|
||||
cfg := DefaultCircuitBreakerConfig()
|
||||
|
||||
if cfg.FailureThreshold != 3 {
|
||||
t.Errorf("expected FailureThreshold 3, got %d", cfg.FailureThreshold)
|
||||
}
|
||||
if cfg.ResetTimeout != 30*time.Second {
|
||||
t.Errorf("expected ResetTimeout 30s, got %v", cfg.ResetTimeout)
|
||||
}
|
||||
if cfg.SuccessThreshold != 1 {
|
||||
t.Errorf("expected SuccessThreshold 1, got %d", cfg.SuccessThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerStateString(t *testing.T) {
|
||||
tests := []struct {
|
||||
state CircuitState
|
||||
expected string
|
||||
}{
|
||||
{CircuitClosed, "closed"},
|
||||
{CircuitOpen, "open"},
|
||||
{CircuitHalfOpen, "half-open"},
|
||||
{CircuitState(99), "unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := tt.state.String(); got != tt.expected {
|
||||
t.Errorf("state %d: expected %s, got %s", tt.state, tt.expected, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerClosed(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test", DefaultCircuitBreakerConfig())
|
||||
|
||||
if cb.State() != CircuitClosed {
|
||||
t.Error("expected initial state to be closed")
|
||||
}
|
||||
|
||||
// Successful execution
|
||||
result, err := cb.Execute(func() (interface{}, error) {
|
||||
return "success", nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if result != "success" {
|
||||
t.Errorf("expected 'success', got %v", result)
|
||||
}
|
||||
if cb.State() != CircuitClosed {
|
||||
t.Error("state should still be closed after success")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerOpensAfterFailures(t *testing.T) {
|
||||
cfg := CircuitBreakerConfig{
|
||||
FailureThreshold: 2,
|
||||
ResetTimeout: time.Minute,
|
||||
SuccessThreshold: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker("test", cfg)
|
||||
|
||||
testErr := errors.New("test error")
|
||||
|
||||
// First failure
|
||||
_, err := cb.Execute(func() (interface{}, error) {
|
||||
return nil, testErr
|
||||
})
|
||||
if err != testErr {
|
||||
t.Errorf("expected test error, got %v", err)
|
||||
}
|
||||
if cb.State() != CircuitClosed {
|
||||
t.Error("should still be closed after 1 failure")
|
||||
}
|
||||
|
||||
// Second failure - should open circuit
|
||||
_, err = cb.Execute(func() (interface{}, error) {
|
||||
return nil, testErr
|
||||
})
|
||||
if err != testErr {
|
||||
t.Errorf("expected test error, got %v", err)
|
||||
}
|
||||
if cb.State() != CircuitOpen {
|
||||
t.Error("should be open after 2 failures")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerRejectsWhenOpen(t *testing.T) {
|
||||
cfg := CircuitBreakerConfig{
|
||||
FailureThreshold: 1,
|
||||
ResetTimeout: time.Hour, // Long timeout to keep circuit open
|
||||
SuccessThreshold: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker("test", cfg)
|
||||
|
||||
// Open the circuit
|
||||
cb.Execute(func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
|
||||
if cb.State() != CircuitOpen {
|
||||
t.Fatal("circuit should be open")
|
||||
}
|
||||
|
||||
// Next request should be rejected
|
||||
called := false
|
||||
_, err := cb.Execute(func() (interface{}, error) {
|
||||
called = true
|
||||
return "should not run", nil
|
||||
})
|
||||
|
||||
if called {
|
||||
t.Error("function should not have been called when circuit is open")
|
||||
}
|
||||
if err != ErrCircuitOpen {
|
||||
t.Errorf("expected ErrCircuitOpen, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerTransitionsToHalfOpen(t *testing.T) {
|
||||
cfg := CircuitBreakerConfig{
|
||||
FailureThreshold: 1,
|
||||
ResetTimeout: 50 * time.Millisecond,
|
||||
SuccessThreshold: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker("test", cfg)
|
||||
|
||||
// Open the circuit
|
||||
cb.Execute(func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
|
||||
if cb.State() != CircuitOpen {
|
||||
t.Fatal("circuit should be open")
|
||||
}
|
||||
|
||||
// Wait for reset timeout
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Next request should transition to half-open and execute
|
||||
result, err := cb.Execute(func() (interface{}, error) {
|
||||
return "probe success", nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if result != "probe success" {
|
||||
t.Errorf("expected 'probe success', got %v", result)
|
||||
}
|
||||
if cb.State() != CircuitClosed {
|
||||
t.Error("should be closed after successful probe")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerHalfOpenFailureReopens(t *testing.T) {
|
||||
cfg := CircuitBreakerConfig{
|
||||
FailureThreshold: 1,
|
||||
ResetTimeout: 50 * time.Millisecond,
|
||||
SuccessThreshold: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker("test", cfg)
|
||||
|
||||
// Open the circuit
|
||||
cb.Execute(func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
|
||||
// Wait for reset timeout
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Probe fails
|
||||
cb.Execute(func() (interface{}, error) {
|
||||
return nil, errors.New("probe failed")
|
||||
})
|
||||
|
||||
if cb.State() != CircuitOpen {
|
||||
t.Error("should be open after probe failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerCaching(t *testing.T) {
|
||||
cfg := CircuitBreakerConfig{
|
||||
FailureThreshold: 1,
|
||||
ResetTimeout: time.Hour,
|
||||
SuccessThreshold: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker("test", cfg)
|
||||
|
||||
// Successful call - caches result
|
||||
result, err := cb.Execute(func() (interface{}, error) {
|
||||
return "cached value", nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != "cached value" {
|
||||
t.Fatalf("expected 'cached value', got %v", result)
|
||||
}
|
||||
|
||||
// Open the circuit
|
||||
cb.Execute(func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
|
||||
// Should return cached value when circuit is open
|
||||
result, err = cb.Execute(func() (interface{}, error) {
|
||||
return "should not run", nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expected cached result, got error: %v", err)
|
||||
}
|
||||
if result != "cached value" {
|
||||
t.Errorf("expected 'cached value', got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerGetCached(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test", DefaultCircuitBreakerConfig())
|
||||
|
||||
// No cache initially
|
||||
_, ok := cb.GetCached()
|
||||
if ok {
|
||||
t.Error("expected no cached value initially")
|
||||
}
|
||||
|
||||
// Cache a value
|
||||
cb.Execute(func() (interface{}, error) {
|
||||
return "test value", nil
|
||||
})
|
||||
|
||||
cached, ok := cb.GetCached()
|
||||
if !ok {
|
||||
t.Error("expected cached value")
|
||||
}
|
||||
if cached != "test value" {
|
||||
t.Errorf("expected 'test value', got %v", cached)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerReset(t *testing.T) {
|
||||
cfg := CircuitBreakerConfig{
|
||||
FailureThreshold: 1,
|
||||
ResetTimeout: time.Hour,
|
||||
SuccessThreshold: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker("test", cfg)
|
||||
|
||||
// Open the circuit
|
||||
cb.Execute(func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
|
||||
if cb.State() != CircuitOpen {
|
||||
t.Fatal("circuit should be open")
|
||||
}
|
||||
|
||||
// Manual reset
|
||||
cb.Reset()
|
||||
|
||||
if cb.State() != CircuitClosed {
|
||||
t.Error("circuit should be closed after reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerConcurrency(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test", DefaultCircuitBreakerConfig())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
cb.Execute(func() (interface{}, error) {
|
||||
if n%3 == 0 {
|
||||
return nil, errors.New("fail")
|
||||
}
|
||||
return "success", nil
|
||||
})
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Just verify no panics occurred
|
||||
_ = cb.State()
|
||||
}
|
||||
|
||||
func TestGetGitHubCircuitBreaker(t *testing.T) {
|
||||
cb1 := getGitHubCircuitBreaker()
|
||||
cb2 := getGitHubCircuitBreaker()
|
||||
|
||||
if cb1 != cb2 {
|
||||
t.Error("expected singleton circuit breaker")
|
||||
}
|
||||
|
||||
if cb1.name != "github-api" {
|
||||
t.Errorf("expected name 'github-api', got %s", cb1.name)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkCircuitBreakerExecute(b *testing.B) {
|
||||
cb := NewCircuitBreaker("bench", DefaultCircuitBreakerConfig())
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.Execute(func() (interface{}, error) {
|
||||
return "result", nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCircuitBreakerConcurrent(b *testing.B) {
|
||||
cb := NewCircuitBreaker("bench", DefaultCircuitBreakerConfig())
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cb.Execute(func() (interface{}, error) {
|
||||
return "result", nil
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -60,6 +60,7 @@ type wsClient struct {
|
|||
send chan []byte
|
||||
hub *EventHub
|
||||
miners map[string]bool // subscribed miners, "*" for all
|
||||
minersMu sync.RWMutex // protects miners map from concurrent access
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
|
|
@ -143,6 +144,11 @@ func (h *EventHub) Run() {
|
|||
// Send initial state sync if provider is set
|
||||
if stateProvider != nil {
|
||||
go func(c *wsClient) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logging.Error("panic in state sync goroutine", logging.Fields{"panic": r})
|
||||
}
|
||||
}()
|
||||
state := stateProvider()
|
||||
if state != nil {
|
||||
event := Event{
|
||||
|
|
@ -206,7 +212,10 @@ func (h *EventHub) shouldSendToClient(client *wsClient, event Event) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// Check miner subscription for miner events
|
||||
// Check miner subscription for miner events (protected by mutex)
|
||||
client.minersMu.RLock()
|
||||
defer client.minersMu.RUnlock()
|
||||
|
||||
if client.miners == nil || len(client.miners) == 0 {
|
||||
// No subscription filter, send all
|
||||
return true
|
||||
|
|
@ -354,11 +363,13 @@ func (c *wsClient) readPump() {
|
|||
|
||||
switch msg.Type {
|
||||
case "subscribe":
|
||||
// Update miner subscription
|
||||
// Update miner subscription (protected by mutex)
|
||||
c.minersMu.Lock()
|
||||
c.miners = make(map[string]bool)
|
||||
for _, m := range msg.Miners {
|
||||
c.miners[m] = true
|
||||
}
|
||||
c.minersMu.Unlock()
|
||||
logging.Debug("client subscribed to miners", logging.Fields{"miners": msg.Miners})
|
||||
|
||||
case "ping":
|
||||
|
|
|
|||
|
|
@ -130,6 +130,11 @@ func (m *Manager) startDBCleanup() {
|
|||
m.waitGroup.Add(1)
|
||||
go func() {
|
||||
defer m.waitGroup.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logging.Error("panic in database cleanup goroutine", logging.Fields{"panic": r})
|
||||
}
|
||||
}()
|
||||
// Run cleanup once per hour
|
||||
ticker := time.NewTicker(time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
|
@ -523,6 +528,11 @@ func (m *Manager) startStatsCollection() {
|
|||
m.waitGroup.Add(1)
|
||||
go func() {
|
||||
defer m.waitGroup.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logging.Error("panic in stats collection goroutine", logging.Fields{"panic": r})
|
||||
}
|
||||
}()
|
||||
ticker := time.NewTicker(HighResolutionInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
|
|
@ -570,6 +580,14 @@ func (m *Manager) collectMinerStats() {
|
|||
wg.Add(1)
|
||||
go func(miner Miner, minerType string) {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logging.Error("panic in single miner stats collection", logging.Fields{
|
||||
"panic": r,
|
||||
"miner": miner.GetName(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
m.collectSingleMinerStats(miner, minerType, now, dbEnabled)
|
||||
}(mi.miner, mi.minerType)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,7 +36,37 @@ type GitHubRelease struct {
|
|||
|
||||
// FetchLatestGitHubVersion fetches the latest release version from a GitHub repository.
|
||||
// It takes the repository owner and name (e.g., "xmrig", "xmrig") and returns the tag name.
|
||||
// Uses a circuit breaker to prevent cascading failures when GitHub API is unavailable.
|
||||
func FetchLatestGitHubVersion(owner, repo string) (string, error) {
|
||||
cb := getGitHubCircuitBreaker()
|
||||
|
||||
result, err := cb.Execute(func() (interface{}, error) {
|
||||
return fetchGitHubVersionDirect(owner, repo)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// If circuit is open, try to return cached value with warning
|
||||
if err == ErrCircuitOpen {
|
||||
if cached, ok := cb.GetCached(); ok {
|
||||
if tagName, ok := cached.(string); ok {
|
||||
return tagName, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("github API unavailable (circuit breaker open): %w", err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
tagName, ok := result.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unexpected result type from circuit breaker")
|
||||
}
|
||||
|
||||
return tagName, nil
|
||||
}
|
||||
|
||||
// fetchGitHubVersionDirect is the actual GitHub API call, wrapped by circuit breaker
|
||||
func fetchGitHubVersionDirect(owner, repo string) (string, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", owner, repo)
|
||||
|
||||
resp, err := getHTTPClient().Get(url)
|
||||
|
|
|
|||
|
|
@ -22,25 +22,30 @@ var debugLogCounter atomic.Int64
|
|||
// debugLogInterval controls how often we log debug messages in hot paths (1 in N)
|
||||
const debugLogInterval = 100
|
||||
|
||||
// DefaultMaxMessageSize is the default maximum message size (1MB)
|
||||
const DefaultMaxMessageSize int64 = 1 << 20 // 1MB
|
||||
|
||||
// TransportConfig configures the WebSocket transport.
|
||||
type TransportConfig struct {
|
||||
ListenAddr string // ":9091" default
|
||||
WSPath string // "/ws" - WebSocket endpoint path
|
||||
TLSCertPath string // Optional TLS for wss://
|
||||
TLSKeyPath string
|
||||
MaxConns int // Maximum concurrent connections
|
||||
PingInterval time.Duration // WebSocket keepalive interval
|
||||
PongTimeout time.Duration // Timeout waiting for pong
|
||||
ListenAddr string // ":9091" default
|
||||
WSPath string // "/ws" - WebSocket endpoint path
|
||||
TLSCertPath string // Optional TLS for wss://
|
||||
TLSKeyPath string
|
||||
MaxConns int // Maximum concurrent connections
|
||||
MaxMessageSize int64 // Maximum message size in bytes (0 = 1MB default)
|
||||
PingInterval time.Duration // WebSocket keepalive interval
|
||||
PongTimeout time.Duration // Timeout waiting for pong
|
||||
}
|
||||
|
||||
// DefaultTransportConfig returns sensible defaults.
|
||||
func DefaultTransportConfig() TransportConfig {
|
||||
return TransportConfig{
|
||||
ListenAddr: ":9091",
|
||||
WSPath: "/ws",
|
||||
MaxConns: 100,
|
||||
PingInterval: 30 * time.Second,
|
||||
PongTimeout: 10 * time.Second,
|
||||
ListenAddr: ":9091",
|
||||
WSPath: "/ws",
|
||||
MaxConns: 100,
|
||||
MaxMessageSize: DefaultMaxMessageSize,
|
||||
PingInterval: 30 * time.Second,
|
||||
PongTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -49,17 +54,18 @@ type MessageHandler func(conn *PeerConnection, msg *Message)
|
|||
|
||||
// Transport manages WebSocket connections with SMSG encryption.
|
||||
type Transport struct {
|
||||
config TransportConfig
|
||||
server *http.Server
|
||||
upgrader websocket.Upgrader
|
||||
conns map[string]*PeerConnection // peer ID -> connection
|
||||
node *NodeManager
|
||||
registry *PeerRegistry
|
||||
handler MessageHandler
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
config TransportConfig
|
||||
server *http.Server
|
||||
upgrader websocket.Upgrader
|
||||
conns map[string]*PeerConnection // peer ID -> connection
|
||||
pendingConns atomic.Int32 // tracks connections during handshake
|
||||
node *NodeManager
|
||||
registry *PeerRegistry
|
||||
handler MessageHandler
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// PeerConnection represents an active connection to a peer.
|
||||
|
|
@ -267,21 +273,34 @@ func (t *Transport) GetConnection(peerID string) *PeerConnection {
|
|||
|
||||
// handleWSUpgrade handles incoming WebSocket connections.
|
||||
func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
|
||||
// Enforce MaxConns limit
|
||||
// Enforce MaxConns limit (including pending connections during handshake)
|
||||
t.mu.RLock()
|
||||
currentConns := len(t.conns)
|
||||
t.mu.RUnlock()
|
||||
pendingConns := int(t.pendingConns.Load())
|
||||
|
||||
if currentConns >= t.config.MaxConns {
|
||||
totalConns := currentConns + pendingConns
|
||||
if totalConns >= t.config.MaxConns {
|
||||
http.Error(w, "Too many connections", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Track this connection as pending during handshake
|
||||
t.pendingConns.Add(1)
|
||||
defer t.pendingConns.Add(-1)
|
||||
|
||||
conn, err := t.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Apply message size limit during handshake to prevent memory exhaustion
|
||||
maxSize := t.config.MaxMessageSize
|
||||
if maxSize <= 0 {
|
||||
maxSize = DefaultMaxMessageSize
|
||||
}
|
||||
conn.SetReadLimit(maxSize)
|
||||
|
||||
// Set handshake timeout to prevent slow/malicious clients from blocking
|
||||
handshakeTimeout := 10 * time.Second
|
||||
conn.SetReadDeadline(time.Now().Add(handshakeTimeout))
|
||||
|
|
@ -468,6 +487,13 @@ func (t *Transport) readLoop(pc *PeerConnection) {
|
|||
defer t.wg.Done()
|
||||
defer t.removeConnection(pc)
|
||||
|
||||
// Apply message size limit to prevent memory exhaustion attacks
|
||||
maxSize := t.config.MaxMessageSize
|
||||
if maxSize <= 0 {
|
||||
maxSize = DefaultMaxMessageSize
|
||||
}
|
||||
pc.Conn.SetReadLimit(maxSize)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue