diff --git a/pkg/mining/config_manager.go b/pkg/mining/config_manager.go index 84fa9fd..0a2282a 100644 --- a/pkg/mining/config_manager.go +++ b/pkg/mining/config_manager.go @@ -143,3 +143,90 @@ func SaveMinersConfig(cfg *MinersConfig) error { success = true return nil } + +// UpdateMinersConfig atomically loads, modifies, and saves the miners config. +// This prevents race conditions in read-modify-write operations. +func UpdateMinersConfig(fn func(*MinersConfig) error) error { + configMu.Lock() + defer configMu.Unlock() + + configPath, err := getMinersConfigPath() + if err != nil { + return fmt.Errorf("could not determine miners config path: %w", err) + } + + // Load current config + var cfg MinersConfig + data, err := os.ReadFile(configPath) + if err != nil { + if os.IsNotExist(err) { + cfg = MinersConfig{ + Miners: []MinerAutostartConfig{}, + Database: defaultDatabaseConfig(), + } + } else { + return fmt.Errorf("failed to read miners config file: %w", err) + } + } else { + if err := json.Unmarshal(data, &cfg); err != nil { + return fmt.Errorf("failed to unmarshal miners config: %w", err) + } + if cfg.Database.RetentionDays == 0 { + cfg.Database = defaultDatabaseConfig() + } + } + + // Apply the modification + if err := fn(&cfg); err != nil { + return err + } + + // Save atomically + dir := filepath.Dir(configPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create config directory: %w", err) + } + + newData, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal miners config: %w", err) + } + + tmpFile, err := os.CreateTemp(dir, "miners-config-*.tmp") + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + tmpPath := tmpFile.Name() + + success := false + defer func() { + if !success { + os.Remove(tmpPath) + } + }() + + if _, err := tmpFile.Write(newData); err != nil { + tmpFile.Close() + return fmt.Errorf("failed to write temp file: %w", err) + } + + if err := tmpFile.Sync(); err != nil { + tmpFile.Close() + return fmt.Errorf("failed to sync temp file: %w", err) + } + + if err := tmpFile.Close(); err != nil { + return fmt.Errorf("failed to close temp file: %w", err) + } + + if err := os.Chmod(tmpPath, 0600); err != nil { + return fmt.Errorf("failed to set temp file permissions: %w", err) + } + + if err := os.Rename(tmpPath, configPath); err != nil { + return fmt.Errorf("failed to rename temp file: %w", err) + } + + success = true + return nil +} diff --git a/pkg/mining/events.go b/pkg/mining/events.go index 9f9220e..73bf5ca 100644 --- a/pkg/mining/events.go +++ b/pkg/mining/events.go @@ -81,16 +81,31 @@ type EventHub struct { // Stop signal stop chan struct{} + + // Connection limits + maxConnections int } -// NewEventHub creates a new EventHub +// DefaultMaxConnections is the default maximum WebSocket connections +const DefaultMaxConnections = 100 + +// NewEventHub creates a new EventHub with default settings func NewEventHub() *EventHub { + return NewEventHubWithOptions(DefaultMaxConnections) +} + +// NewEventHubWithOptions creates a new EventHub with custom settings +func NewEventHubWithOptions(maxConnections int) *EventHub { + if maxConnections <= 0 { + maxConnections = DefaultMaxConnections + } return &EventHub{ - clients: make(map[*wsClient]bool), - broadcast: make(chan Event, 256), - register: make(chan *wsClient), - unregister: make(chan *wsClient), - stop: make(chan struct{}), + clients: make(map[*wsClient]bool), + broadcast: make(chan Event, 256), + register: make(chan *wsClient), + unregister: make(chan *wsClient), + stop: make(chan struct{}), + maxConnections: maxConnections, } } @@ -309,8 +324,22 @@ func (c *wsClient) readPump() { } } -// ServeWs handles websocket requests from clients -func (h *EventHub) ServeWs(conn *websocket.Conn) { +// ServeWs handles websocket requests from clients. +// Returns false if the connection was rejected due to limits. +func (h *EventHub) ServeWs(conn *websocket.Conn) bool { + // Check connection limit + h.mu.RLock() + currentCount := len(h.clients) + h.mu.RUnlock() + + if currentCount >= h.maxConnections { + log.Printf("[EventHub] Connection rejected: limit reached (%d/%d)", currentCount, h.maxConnections) + conn.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseTryAgainLater, "connection limit reached")) + conn.Close() + return false + } + client := &wsClient{ conn: conn, send: make(chan []byte, 256), @@ -323,4 +352,5 @@ func (h *EventHub) ServeWs(conn *websocket.Conn) { // Start read/write pumps go client.writePump() go client.readPump() + return true } diff --git a/pkg/mining/manager.go b/pkg/mining/manager.go index 8f1764e..2ceffa0 100644 --- a/pkg/mining/manager.go +++ b/pkg/mining/manager.go @@ -356,48 +356,40 @@ func (m *Manager) UninstallMiner(minerType string) error { return fmt.Errorf("failed to uninstall miner files: %w", err) } - cfg, err := LoadMinersConfig() - if err != nil { - return fmt.Errorf("failed to load miners config to update uninstall status: %w", err) - } - - var updatedMiners []MinerAutostartConfig - for _, minerCfg := range cfg.Miners { - if !strings.EqualFold(minerCfg.MinerType, minerType) { - updatedMiners = append(updatedMiners, minerCfg) + return UpdateMinersConfig(func(cfg *MinersConfig) error { + var updatedMiners []MinerAutostartConfig + for _, minerCfg := range cfg.Miners { + if !strings.EqualFold(minerCfg.MinerType, minerType) { + updatedMiners = append(updatedMiners, minerCfg) + } } - } - cfg.Miners = updatedMiners - - return SaveMinersConfig(cfg) + cfg.Miners = updatedMiners + return nil + }) } // updateMinerConfig saves the autostart and last-used config for a miner. func (m *Manager) updateMinerConfig(minerType string, autostart bool, config *Config) error { - cfg, err := LoadMinersConfig() - if err != nil { - return err - } - - found := false - for i, minerCfg := range cfg.Miners { - if strings.EqualFold(minerCfg.MinerType, minerType) { - cfg.Miners[i].Autostart = autostart - cfg.Miners[i].Config = config - found = true - break + return UpdateMinersConfig(func(cfg *MinersConfig) error { + found := false + for i, minerCfg := range cfg.Miners { + if strings.EqualFold(minerCfg.MinerType, minerType) { + cfg.Miners[i].Autostart = autostart + cfg.Miners[i].Config = config + found = true + break + } } - } - if !found { - cfg.Miners = append(cfg.Miners, MinerAutostartConfig{ - MinerType: minerType, - Autostart: autostart, - Config: config, - }) - } - - return SaveMinersConfig(cfg) + if !found { + cfg.Miners = append(cfg.Miners, MinerAutostartConfig{ + MinerType: minerType, + Autostart: autostart, + Config: config, + }) + } + return nil + }) } // StopMiner stops a running miner and removes it from the manager. @@ -618,6 +610,9 @@ func (m *Manager) GetMinerHashrateHistory(name string) ([]HashratePoint, error) return miner.GetHashrateHistory(), nil } +// ShutdownTimeout is the maximum time to wait for goroutines during shutdown +const ShutdownTimeout = 10 * time.Second + // Stop stops all running miners, background goroutines, and closes resources. // Safe to call multiple times - subsequent calls are no-ops. func (m *Manager) Stop() { @@ -632,7 +627,20 @@ func (m *Manager) Stop() { m.mu.Unlock() close(m.stopChan) - m.waitGroup.Wait() + + // Wait for goroutines with timeout + done := make(chan struct{}) + go func() { + m.waitGroup.Wait() + close(done) + }() + + select { + case <-done: + log.Printf("All goroutines stopped gracefully") + case <-time.After(ShutdownTimeout): + log.Printf("Warning: shutdown timeout - some goroutines may not have stopped") + } // Close the database if m.dbEnabled { diff --git a/pkg/mining/service.go b/pkg/mining/service.go index cb6b564..9d1e3d2 100644 --- a/pkg/mining/service.go +++ b/pkg/mining/service.go @@ -12,6 +12,7 @@ import ( "path/filepath" "runtime" "strings" + "sync" "time" "github.com/Masterminds/semver/v3" @@ -125,6 +126,65 @@ func generateRequestID() string { return fmt.Sprintf("%d-%x", time.Now().UnixMilli(), b[:4]) } +// rateLimitMiddleware provides basic rate limiting per IP address +func rateLimitMiddleware(requestsPerSecond int, burst int) gin.HandlerFunc { + type client struct { + tokens float64 + lastCheck time.Time + } + + var ( + clients = make(map[string]*client) + mu = &sync.RWMutex{} + ) + + // Cleanup old clients every minute + go func() { + for { + time.Sleep(time.Minute) + mu.Lock() + for ip, c := range clients { + if time.Since(c.lastCheck) > 5*time.Minute { + delete(clients, ip) + } + } + mu.Unlock() + } + }() + + return func(c *gin.Context) { + ip := c.ClientIP() + + mu.Lock() + cl, exists := clients[ip] + if !exists { + cl = &client{tokens: float64(burst), lastCheck: time.Now()} + clients[ip] = cl + } + + // Token bucket algorithm + now := time.Now() + elapsed := now.Sub(cl.lastCheck).Seconds() + cl.tokens += elapsed * float64(requestsPerSecond) + if cl.tokens > float64(burst) { + cl.tokens = float64(burst) + } + cl.lastCheck = now + + if cl.tokens < 1 { + mu.Unlock() + respondWithError(c, http.StatusTooManyRequests, "RATE_LIMITED", + "too many requests", "rate limit exceeded") + c.Abort() + return + } + + cl.tokens-- + mu.Unlock() + c.Next() + } +} + // WebSocket upgrader for the events endpoint var wsUpgrader = websocket.Upgrader{ ReadBufferSize: 1024, @@ -236,6 +296,9 @@ func (s *Service) InitRouter() { // Add X-Request-ID middleware for request tracing s.Router.Use(requestIDMiddleware()) + // Add rate limiting (10 requests/second with burst of 20) + s.Router.Use(rateLimitMiddleware(10, 20)) + s.SetupRoutes() } @@ -954,5 +1017,7 @@ func (s *Service) handleWebSocketEvents(c *gin.Context) { } log.Printf("[WebSocket] New connection from %s", c.Request.RemoteAddr) - s.EventHub.ServeWs(conn) + if !s.EventHub.ServeWs(conn) { + log.Printf("[WebSocket] Connection from %s rejected (limit reached)", c.Request.RemoteAddr) + } }