feat: Add rate limiting, race condition fix, and shutdown improvements
- Add rate limiting middleware (10 req/s with burst of 20) - Add atomic UpdateMinersConfig to fix config race conditions - Add WebSocket connection limits (max 100 connections) - Add graceful shutdown timeout (10s max wait for goroutines) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
9e98f58795
commit
0c8b2d999b
4 changed files with 235 additions and 45 deletions
|
|
@ -143,3 +143,90 @@ func SaveMinersConfig(cfg *MinersConfig) error {
|
|||
success = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateMinersConfig atomically loads, modifies, and saves the miners config.
|
||||
// This prevents race conditions in read-modify-write operations.
|
||||
func UpdateMinersConfig(fn func(*MinersConfig) error) error {
|
||||
configMu.Lock()
|
||||
defer configMu.Unlock()
|
||||
|
||||
configPath, err := getMinersConfigPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not determine miners config path: %w", err)
|
||||
}
|
||||
|
||||
// Load current config
|
||||
var cfg MinersConfig
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
cfg = MinersConfig{
|
||||
Miners: []MinerAutostartConfig{},
|
||||
Database: defaultDatabaseConfig(),
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to read miners config file: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal miners config: %w", err)
|
||||
}
|
||||
if cfg.Database.RetentionDays == 0 {
|
||||
cfg.Database = defaultDatabaseConfig()
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the modification
|
||||
if err := fn(&cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Save atomically
|
||||
dir := filepath.Dir(configPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
newData, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal miners config: %w", err)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp(dir, "miners-config-*.tmp")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := tmpFile.Write(newData); err != nil {
|
||||
tmpFile.Close()
|
||||
return fmt.Errorf("failed to write temp file: %w", err)
|
||||
}
|
||||
|
||||
if err := tmpFile.Sync(); err != nil {
|
||||
tmpFile.Close()
|
||||
return fmt.Errorf("failed to sync temp file: %w", err)
|
||||
}
|
||||
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close temp file: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tmpPath, 0600); err != nil {
|
||||
return fmt.Errorf("failed to set temp file permissions: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, configPath); err != nil {
|
||||
return fmt.Errorf("failed to rename temp file: %w", err)
|
||||
}
|
||||
|
||||
success = true
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -81,16 +81,31 @@ type EventHub struct {
|
|||
|
||||
// Stop signal
|
||||
stop chan struct{}
|
||||
|
||||
// Connection limits
|
||||
maxConnections int
|
||||
}
|
||||
|
||||
// NewEventHub creates a new EventHub
|
||||
// DefaultMaxConnections is the default maximum WebSocket connections
|
||||
const DefaultMaxConnections = 100
|
||||
|
||||
// NewEventHub creates a new EventHub with default settings
|
||||
func NewEventHub() *EventHub {
|
||||
return NewEventHubWithOptions(DefaultMaxConnections)
|
||||
}
|
||||
|
||||
// NewEventHubWithOptions creates a new EventHub with custom settings
|
||||
func NewEventHubWithOptions(maxConnections int) *EventHub {
|
||||
if maxConnections <= 0 {
|
||||
maxConnections = DefaultMaxConnections
|
||||
}
|
||||
return &EventHub{
|
||||
clients: make(map[*wsClient]bool),
|
||||
broadcast: make(chan Event, 256),
|
||||
register: make(chan *wsClient),
|
||||
unregister: make(chan *wsClient),
|
||||
stop: make(chan struct{}),
|
||||
clients: make(map[*wsClient]bool),
|
||||
broadcast: make(chan Event, 256),
|
||||
register: make(chan *wsClient),
|
||||
unregister: make(chan *wsClient),
|
||||
stop: make(chan struct{}),
|
||||
maxConnections: maxConnections,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -309,8 +324,22 @@ func (c *wsClient) readPump() {
|
|||
}
|
||||
}
|
||||
|
||||
// ServeWs handles websocket requests from clients
|
||||
func (h *EventHub) ServeWs(conn *websocket.Conn) {
|
||||
// ServeWs handles websocket requests from clients.
|
||||
// Returns false if the connection was rejected due to limits.
|
||||
func (h *EventHub) ServeWs(conn *websocket.Conn) bool {
|
||||
// Check connection limit
|
||||
h.mu.RLock()
|
||||
currentCount := len(h.clients)
|
||||
h.mu.RUnlock()
|
||||
|
||||
if currentCount >= h.maxConnections {
|
||||
log.Printf("[EventHub] Connection rejected: limit reached (%d/%d)", currentCount, h.maxConnections)
|
||||
conn.WriteMessage(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseTryAgainLater, "connection limit reached"))
|
||||
conn.Close()
|
||||
return false
|
||||
}
|
||||
|
||||
client := &wsClient{
|
||||
conn: conn,
|
||||
send: make(chan []byte, 256),
|
||||
|
|
@ -323,4 +352,5 @@ func (h *EventHub) ServeWs(conn *websocket.Conn) {
|
|||
// Start read/write pumps
|
||||
go client.writePump()
|
||||
go client.readPump()
|
||||
return true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -356,48 +356,40 @@ func (m *Manager) UninstallMiner(minerType string) error {
|
|||
return fmt.Errorf("failed to uninstall miner files: %w", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadMinersConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load miners config to update uninstall status: %w", err)
|
||||
}
|
||||
|
||||
var updatedMiners []MinerAutostartConfig
|
||||
for _, minerCfg := range cfg.Miners {
|
||||
if !strings.EqualFold(minerCfg.MinerType, minerType) {
|
||||
updatedMiners = append(updatedMiners, minerCfg)
|
||||
return UpdateMinersConfig(func(cfg *MinersConfig) error {
|
||||
var updatedMiners []MinerAutostartConfig
|
||||
for _, minerCfg := range cfg.Miners {
|
||||
if !strings.EqualFold(minerCfg.MinerType, minerType) {
|
||||
updatedMiners = append(updatedMiners, minerCfg)
|
||||
}
|
||||
}
|
||||
}
|
||||
cfg.Miners = updatedMiners
|
||||
|
||||
return SaveMinersConfig(cfg)
|
||||
cfg.Miners = updatedMiners
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// updateMinerConfig saves the autostart and last-used config for a miner.
|
||||
func (m *Manager) updateMinerConfig(minerType string, autostart bool, config *Config) error {
|
||||
cfg, err := LoadMinersConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
found := false
|
||||
for i, minerCfg := range cfg.Miners {
|
||||
if strings.EqualFold(minerCfg.MinerType, minerType) {
|
||||
cfg.Miners[i].Autostart = autostart
|
||||
cfg.Miners[i].Config = config
|
||||
found = true
|
||||
break
|
||||
return UpdateMinersConfig(func(cfg *MinersConfig) error {
|
||||
found := false
|
||||
for i, minerCfg := range cfg.Miners {
|
||||
if strings.EqualFold(minerCfg.MinerType, minerType) {
|
||||
cfg.Miners[i].Autostart = autostart
|
||||
cfg.Miners[i].Config = config
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
cfg.Miners = append(cfg.Miners, MinerAutostartConfig{
|
||||
MinerType: minerType,
|
||||
Autostart: autostart,
|
||||
Config: config,
|
||||
})
|
||||
}
|
||||
|
||||
return SaveMinersConfig(cfg)
|
||||
if !found {
|
||||
cfg.Miners = append(cfg.Miners, MinerAutostartConfig{
|
||||
MinerType: minerType,
|
||||
Autostart: autostart,
|
||||
Config: config,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// StopMiner stops a running miner and removes it from the manager.
|
||||
|
|
@ -618,6 +610,9 @@ func (m *Manager) GetMinerHashrateHistory(name string) ([]HashratePoint, error)
|
|||
return miner.GetHashrateHistory(), nil
|
||||
}
|
||||
|
||||
// ShutdownTimeout is the maximum time to wait for goroutines during shutdown
|
||||
const ShutdownTimeout = 10 * time.Second
|
||||
|
||||
// Stop stops all running miners, background goroutines, and closes resources.
|
||||
// Safe to call multiple times - subsequent calls are no-ops.
|
||||
func (m *Manager) Stop() {
|
||||
|
|
@ -632,7 +627,20 @@ func (m *Manager) Stop() {
|
|||
m.mu.Unlock()
|
||||
|
||||
close(m.stopChan)
|
||||
m.waitGroup.Wait()
|
||||
|
||||
// Wait for goroutines with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
m.waitGroup.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
log.Printf("All goroutines stopped gracefully")
|
||||
case <-time.After(ShutdownTimeout):
|
||||
log.Printf("Warning: shutdown timeout - some goroutines may not have stopped")
|
||||
}
|
||||
|
||||
// Close the database
|
||||
if m.dbEnabled {
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import (
|
|||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
|
|
@ -125,6 +126,65 @@ func generateRequestID() string {
|
|||
return fmt.Sprintf("%d-%x", time.Now().UnixMilli(), b[:4])
|
||||
}
|
||||
|
||||
// rateLimitMiddleware provides basic rate limiting per IP address
|
||||
func rateLimitMiddleware(requestsPerSecond int, burst int) gin.HandlerFunc {
|
||||
type client struct {
|
||||
tokens float64
|
||||
lastCheck time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
clients = make(map[string]*client)
|
||||
mu = &sync.RWMutex{}
|
||||
)
|
||||
|
||||
// Cleanup old clients every minute
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(time.Minute)
|
||||
mu.Lock()
|
||||
for ip, c := range clients {
|
||||
if time.Since(c.lastCheck) > 5*time.Minute {
|
||||
delete(clients, ip)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
|
||||
mu.Lock()
|
||||
cl, exists := clients[ip]
|
||||
if !exists {
|
||||
cl = &client{tokens: float64(burst), lastCheck: time.Now()}
|
||||
clients[ip] = cl
|
||||
}
|
||||
|
||||
// Token bucket algorithm
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(cl.lastCheck).Seconds()
|
||||
cl.tokens += elapsed * float64(requestsPerSecond)
|
||||
if cl.tokens > float64(burst) {
|
||||
cl.tokens = float64(burst)
|
||||
}
|
||||
cl.lastCheck = now
|
||||
|
||||
if cl.tokens < 1 {
|
||||
mu.Unlock()
|
||||
respondWithError(c, http.StatusTooManyRequests, "RATE_LIMITED",
|
||||
"too many requests", "rate limit exceeded")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
cl.tokens--
|
||||
mu.Unlock()
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// WebSocket upgrader for the events endpoint
|
||||
var wsUpgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
|
|
@ -236,6 +296,9 @@ func (s *Service) InitRouter() {
|
|||
// Add X-Request-ID middleware for request tracing
|
||||
s.Router.Use(requestIDMiddleware())
|
||||
|
||||
// Add rate limiting (10 requests/second with burst of 20)
|
||||
s.Router.Use(rateLimitMiddleware(10, 20))
|
||||
|
||||
s.SetupRoutes()
|
||||
}
|
||||
|
||||
|
|
@ -954,5 +1017,7 @@ func (s *Service) handleWebSocketEvents(c *gin.Context) {
|
|||
}
|
||||
|
||||
log.Printf("[WebSocket] New connection from %s", c.Request.RemoteAddr)
|
||||
s.EventHub.ServeWs(conn)
|
||||
if !s.EventHub.ServeWs(conn) {
|
||||
log.Printf("[WebSocket] Connection from %s rejected (limit reached)", c.Request.RemoteAddr)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue