feat: Add rate limiting, race condition fix, and shutdown improvements

- Add rate limiting middleware (10 req/s with burst of 20)
- Add atomic UpdateMinersConfig to fix config race conditions
- Add WebSocket connection limits (max 100 connections)
- Add graceful shutdown timeout (10s max wait for goroutines)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
snider 2025-12-31 09:51:48 +00:00
parent 9e98f58795
commit 0c8b2d999b
4 changed files with 235 additions and 45 deletions

View file

@ -143,3 +143,90 @@ func SaveMinersConfig(cfg *MinersConfig) error {
success = true
return nil
}
// UpdateMinersConfig atomically loads, modifies, and saves the miners config.
// This prevents race conditions in read-modify-write operations.
//
// The entire load → fn → save sequence runs while holding configMu, so
// concurrent callers serialize and no update is lost. If fn returns an
// error the on-disk config is left untouched. The save uses a
// write-temp-then-rename pattern so readers never observe a partially
// written file.
func UpdateMinersConfig(fn func(*MinersConfig) error) error {
	configMu.Lock()
	defer configMu.Unlock()
	configPath, err := getMinersConfigPath()
	if err != nil {
		return fmt.Errorf("could not determine miners config path: %w", err)
	}
	// Load current config. A missing file is not an error: start from a
	// default config so the first-ever update still works.
	var cfg MinersConfig
	data, err := os.ReadFile(configPath)
	if err != nil {
		if os.IsNotExist(err) {
			cfg = MinersConfig{
				Miners: []MinerAutostartConfig{},
				Database: defaultDatabaseConfig(),
			}
		} else {
			return fmt.Errorf("failed to read miners config file: %w", err)
		}
	} else {
		if err := json.Unmarshal(data, &cfg); err != nil {
			return fmt.Errorf("failed to unmarshal miners config: %w", err)
		}
		// RetentionDays == 0 is treated as "unset": backfill defaults for
		// configs written before the Database section existed.
		if cfg.Database.RetentionDays == 0 {
			cfg.Database = defaultDatabaseConfig()
		}
	}
	// Apply the caller's modification in place.
	if err := fn(&cfg); err != nil {
		return err
	}
	// Save atomically: write a temp file in the same directory, fsync it,
	// then rename over the real path (rename within one filesystem replaces
	// the destination atomically on POSIX systems).
	dir := filepath.Dir(configPath)
	if err := os.MkdirAll(dir, 0755); err != nil {
		return fmt.Errorf("failed to create config directory: %w", err)
	}
	newData, err := json.MarshalIndent(cfg, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal miners config: %w", err)
	}
	tmpFile, err := os.CreateTemp(dir, "miners-config-*.tmp")
	if err != nil {
		return fmt.Errorf("failed to create temp file: %w", err)
	}
	tmpPath := tmpFile.Name()
	success := false
	// Remove the temp file on every failure path below; only the final
	// successful rename sets success and keeps it (as the real config).
	defer func() {
		if !success {
			os.Remove(tmpPath)
		}
	}()
	if _, err := tmpFile.Write(newData); err != nil {
		tmpFile.Close()
		return fmt.Errorf("failed to write temp file: %w", err)
	}
	// Flush to stable storage before the rename so a crash cannot leave an
	// empty or truncated config behind.
	if err := tmpFile.Sync(); err != nil {
		tmpFile.Close()
		return fmt.Errorf("failed to sync temp file: %w", err)
	}
	if err := tmpFile.Close(); err != nil {
		return fmt.Errorf("failed to close temp file: %w", err)
	}
	// Restrict permissions to the owner; the config may hold sensitive
	// miner settings (CreateTemp already uses 0600, this makes it explicit).
	if err := os.Chmod(tmpPath, 0600); err != nil {
		return fmt.Errorf("failed to set temp file permissions: %w", err)
	}
	if err := os.Rename(tmpPath, configPath); err != nil {
		return fmt.Errorf("failed to rename temp file: %w", err)
	}
	success = true
	return nil
}

View file

@ -81,16 +81,31 @@ type EventHub struct {
// Stop signal
stop chan struct{}
// Connection limits
maxConnections int
}
// DefaultMaxConnections is the default maximum WebSocket connections.
const DefaultMaxConnections = 100

// NewEventHub creates a new EventHub with default settings.
func NewEventHub() *EventHub {
	return NewEventHubWithOptions(DefaultMaxConnections)
}
// NewEventHubWithOptions creates a new EventHub with custom settings.
// A non-positive maxConnections falls back to DefaultMaxConnections.
func NewEventHubWithOptions(maxConnections int) *EventHub {
	if maxConnections <= 0 {
		maxConnections = DefaultMaxConnections
	}
	return &EventHub{
		clients:        make(map[*wsClient]bool),
		broadcast:      make(chan Event, 256),
		register:       make(chan *wsClient),
		unregister:     make(chan *wsClient),
		stop:           make(chan struct{}),
		maxConnections: maxConnections,
	}
}
@ -309,8 +324,22 @@ func (c *wsClient) readPump() {
}
}
// ServeWs handles websocket requests from clients
func (h *EventHub) ServeWs(conn *websocket.Conn) {
// ServeWs handles websocket requests from clients.
// Returns false if the connection was rejected due to limits.
func (h *EventHub) ServeWs(conn *websocket.Conn) bool {
// Check connection limit
h.mu.RLock()
currentCount := len(h.clients)
h.mu.RUnlock()
if currentCount >= h.maxConnections {
log.Printf("[EventHub] Connection rejected: limit reached (%d/%d)", currentCount, h.maxConnections)
conn.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseTryAgainLater, "connection limit reached"))
conn.Close()
return false
}
client := &wsClient{
conn: conn,
send: make(chan []byte, 256),
@ -323,4 +352,5 @@ func (h *EventHub) ServeWs(conn *websocket.Conn) {
// Start read/write pumps
go client.writePump()
go client.readPump()
return true
}

View file

@ -356,48 +356,40 @@ func (m *Manager) UninstallMiner(minerType string) error {
return fmt.Errorf("failed to uninstall miner files: %w", err)
}
cfg, err := LoadMinersConfig()
if err != nil {
return fmt.Errorf("failed to load miners config to update uninstall status: %w", err)
}
var updatedMiners []MinerAutostartConfig
for _, minerCfg := range cfg.Miners {
if !strings.EqualFold(minerCfg.MinerType, minerType) {
updatedMiners = append(updatedMiners, minerCfg)
return UpdateMinersConfig(func(cfg *MinersConfig) error {
var updatedMiners []MinerAutostartConfig
for _, minerCfg := range cfg.Miners {
if !strings.EqualFold(minerCfg.MinerType, minerType) {
updatedMiners = append(updatedMiners, minerCfg)
}
}
}
cfg.Miners = updatedMiners
return SaveMinersConfig(cfg)
cfg.Miners = updatedMiners
return nil
})
}
// updateMinerConfig saves the autostart and last-used config for a miner.
func (m *Manager) updateMinerConfig(minerType string, autostart bool, config *Config) error {
cfg, err := LoadMinersConfig()
if err != nil {
return err
}
found := false
for i, minerCfg := range cfg.Miners {
if strings.EqualFold(minerCfg.MinerType, minerType) {
cfg.Miners[i].Autostart = autostart
cfg.Miners[i].Config = config
found = true
break
return UpdateMinersConfig(func(cfg *MinersConfig) error {
found := false
for i, minerCfg := range cfg.Miners {
if strings.EqualFold(minerCfg.MinerType, minerType) {
cfg.Miners[i].Autostart = autostart
cfg.Miners[i].Config = config
found = true
break
}
}
}
if !found {
cfg.Miners = append(cfg.Miners, MinerAutostartConfig{
MinerType: minerType,
Autostart: autostart,
Config: config,
})
}
return SaveMinersConfig(cfg)
if !found {
cfg.Miners = append(cfg.Miners, MinerAutostartConfig{
MinerType: minerType,
Autostart: autostart,
Config: config,
})
}
return nil
})
}
// StopMiner stops a running miner and removes it from the manager.
@ -618,6 +610,9 @@ func (m *Manager) GetMinerHashrateHistory(name string) ([]HashratePoint, error)
return miner.GetHashrateHistory(), nil
}
// ShutdownTimeout is the maximum time to wait for goroutines during shutdown
const ShutdownTimeout = 10 * time.Second
// Stop stops all running miners, background goroutines, and closes resources.
// Safe to call multiple times - subsequent calls are no-ops.
func (m *Manager) Stop() {
@ -632,7 +627,20 @@ func (m *Manager) Stop() {
m.mu.Unlock()
close(m.stopChan)
m.waitGroup.Wait()
// Wait for goroutines with timeout
done := make(chan struct{})
go func() {
m.waitGroup.Wait()
close(done)
}()
select {
case <-done:
log.Printf("All goroutines stopped gracefully")
case <-time.After(ShutdownTimeout):
log.Printf("Warning: shutdown timeout - some goroutines may not have stopped")
}
// Close the database
if m.dbEnabled {

View file

@ -12,6 +12,7 @@ import (
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"github.com/Masterminds/semver/v3"
@ -125,6 +126,65 @@ func generateRequestID() string {
return fmt.Sprintf("%d-%x", time.Now().UnixMilli(), b[:4])
}
// rateLimitMiddleware provides basic rate limiting per IP address using a
// token bucket: each client accrues requestsPerSecond tokens per second up
// to burst, and every request spends one token. Requests arriving with
// less than one token available are rejected with HTTP 429.
func rateLimitMiddleware(requestsPerSecond int, burst int) gin.HandlerFunc {
	type client struct {
		tokens    float64
		lastCheck time.Time
	}
	var (
		clients = make(map[string]*client)
		// Plain Mutex rather than RWMutex: every path, including the
		// success path, mutates the bucket, so read-locking never applies.
		mu sync.Mutex
	)

	// Evict clients idle for more than 5 minutes so the map cannot grow
	// without bound. NOTE(review): this goroutine has no stop signal and
	// lives for the life of the process; that is fine for a middleware
	// installed once at startup, but it would leak if this constructor
	// were called repeatedly.
	go func() {
		for {
			time.Sleep(time.Minute)
			mu.Lock()
			for ip, c := range clients {
				if time.Since(c.lastCheck) > 5*time.Minute {
					delete(clients, ip)
				}
			}
			mu.Unlock()
		}
	}()

	return func(c *gin.Context) {
		ip := c.ClientIP()

		mu.Lock()
		cl, exists := clients[ip]
		if !exists {
			// New clients start with a full bucket so short bursts succeed.
			cl = &client{tokens: float64(burst), lastCheck: time.Now()}
			clients[ip] = cl
		}

		// Token bucket algorithm: refill for the time elapsed since the
		// last request, capped at the burst size.
		now := time.Now()
		elapsed := now.Sub(cl.lastCheck).Seconds()
		cl.tokens += elapsed * float64(requestsPerSecond)
		if cl.tokens > float64(burst) {
			cl.tokens = float64(burst)
		}
		cl.lastCheck = now

		if cl.tokens < 1 {
			mu.Unlock()
			respondWithError(c, http.StatusTooManyRequests, "RATE_LIMITED",
				"too many requests", "rate limit exceeded")
			c.Abort()
			return
		}
		cl.tokens--
		mu.Unlock()

		c.Next()
	}
}
// WebSocket upgrader for the events endpoint
var wsUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
@ -236,6 +296,9 @@ func (s *Service) InitRouter() {
// Add X-Request-ID middleware for request tracing
s.Router.Use(requestIDMiddleware())
// Add rate limiting (10 requests/second with burst of 20)
s.Router.Use(rateLimitMiddleware(10, 20))
s.SetupRoutes()
}
@ -954,5 +1017,7 @@ func (s *Service) handleWebSocketEvents(c *gin.Context) {
}
log.Printf("[WebSocket] New connection from %s", c.Request.RemoteAddr)
s.EventHub.ServeWs(conn)
if !s.EventHub.ServeWs(conn) {
log.Printf("[WebSocket] Connection from %s rejected (limit reached)", c.Request.RemoteAddr)
}
}