From 95ae55e4fafc8563c2520f5d3e3e21ff80ceba97 Mon Sep 17 00:00:00 2001 From: snider Date: Wed, 31 Dec 2025 10:56:26 +0000 Subject: [PATCH] feat: Add rate limiter with cleanup and custom error types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rate Limiter: - Extract rate limiting to pkg/mining/ratelimiter.go with proper lifecycle - Add Stop() method to gracefully shutdown cleanup goroutine - Add RateLimiter.Middleware() for Gin integration - Add ClientCount() for monitoring - Fix goroutine leak in previous inline implementation Custom Errors: - Add pkg/mining/errors.go with MiningError type - Define error codes: MINER_NOT_FOUND, INSTALL_FAILED, TIMEOUT, etc. - Add predefined error constructors (ErrMinerNotFound, ErrStartFailed, etc.) - Support error chaining with WithCause, WithDetails, WithSuggestion - Include HTTP status codes and retry policies Service: - Add Service.Stop() method for graceful cleanup - Update CLI commands to use context.Background() for Manager methods Tests: - Add comprehensive tests for RateLimiter (token bucket, multi-IP, refill) - Add comprehensive tests for MiningError (codes, status, retryable) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- cmd/mining/cmd/serve.go | 4 +- cmd/mining/cmd/simulate.go | 2 +- cmd/mining/cmd/start.go | 3 +- cmd/mining/cmd/stop.go | 3 +- cmd/mining/cmd/uninstall.go | 3 +- pkg/mining/errors.go | 247 +++++++++++++++++++++++++++++++++ pkg/mining/errors_test.go | 151 ++++++++++++++++++++ pkg/mining/ratelimiter.go | 119 ++++++++++++++++ pkg/mining/ratelimiter_test.go | 194 ++++++++++++++++++++++++++ pkg/mining/service.go | 86 ++---------- 10 files changed, 734 insertions(+), 78 deletions(-) create mode 100644 pkg/mining/errors.go create mode 100644 pkg/mining/errors_test.go create mode 100644 pkg/mining/ratelimiter.go create mode 100644 pkg/mining/ratelimiter_test.go diff --git a/cmd/mining/cmd/serve.go b/cmd/mining/cmd/serve.go index 14a5fda..5ed8e21 100644 --- a/cmd/mining/cmd/serve.go +++ b/cmd/mining/cmd/serve.go @@ -103,7 +103,7 @@ var serveCmd = &cobra.Command{ Wallet: cmdArgs[2], LogOutput: true, } - miner, err := mgr.StartMiner(minerType, config) + miner, err := mgr.StartMiner(context.Background(), minerType, config) if err != nil { fmt.Fprintf(os.Stderr, "Error starting miner: %v\n", err) } else { @@ -137,7 +137,7 @@ var serveCmd = &cobra.Command{ fmt.Println("Error: stop command requires miner name (e.g., 'stop xmrig')") } else { minerName := cmdArgs[0] - err := mgr.StopMiner(minerName) + err := mgr.StopMiner(context.Background(), minerName) if err != nil { fmt.Fprintf(os.Stderr, "Error stopping miner: %v\n", err) } else { diff --git a/cmd/mining/cmd/simulate.go b/cmd/mining/cmd/simulate.go index 3c1d9b5..83675f1 100644 --- a/cmd/mining/cmd/simulate.go +++ b/cmd/mining/cmd/simulate.go @@ -115,7 +115,7 @@ Available presets: // Stop all simulated miners for _, miner := range mgr.ListMiners() { - mgr.StopMiner(miner.GetName()) + mgr.StopMiner(context.Background(), miner.GetName()) } fmt.Println("Simulation stopped.") diff --git a/cmd/mining/cmd/start.go b/cmd/mining/cmd/start.go index fa67acc..5cbbe41 100644 --- a/cmd/mining/cmd/start.go +++ b/cmd/mining/cmd/start.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "github.com/Snider/Mining/pkg/mining" @@ -25,7 +26,7 @@ var startCmd = &cobra.Command{ Wallet: minerWallet, } - miner, err := getManager().StartMiner(minerType, config) + miner, err := getManager().StartMiner(context.Background(), minerType, config) if err != nil { return fmt.Errorf("failed to start miner: %w", err) } diff --git a/cmd/mining/cmd/stop.go b/cmd/mining/cmd/stop.go index d157b39..ad22e8b 100644 --- a/cmd/mining/cmd/stop.go +++ b/cmd/mining/cmd/stop.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "github.com/spf13/cobra" @@ -16,7 +17,7 @@ var stopCmd = &cobra.Command{ minerName := args[0] mgr := getManager() - if err := mgr.StopMiner(minerName); err != nil { + if err := mgr.StopMiner(context.Background(), minerName); err != nil { return fmt.Errorf("failed to stop miner: %w", err) } diff --git a/cmd/mining/cmd/uninstall.go b/cmd/mining/cmd/uninstall.go index b01a162..cc4eaae 100644 --- a/cmd/mining/cmd/uninstall.go +++ b/cmd/mining/cmd/uninstall.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "github.com/spf13/cobra" @@ -17,7 +18,7 @@ var uninstallCmd = &cobra.Command{ manager := getManager() // Assuming getManager() provides the singleton manager instance fmt.Printf("Uninstalling %s...\n", minerType) - if err := manager.UninstallMiner(minerType); err != nil { + if err := manager.UninstallMiner(context.Background(), minerType); err != nil { return fmt.Errorf("failed to uninstall miner: %w", err) } diff --git a/pkg/mining/errors.go b/pkg/mining/errors.go new file mode 100644 index 0000000..a88a7f3 --- /dev/null +++ b/pkg/mining/errors.go @@ -0,0 +1,247 @@ +package mining + +import ( + "fmt" + "net/http" +) + +// Error codes for the mining package +const ( + ErrCodeMinerNotFound = "MINER_NOT_FOUND" + ErrCodeMinerExists = "MINER_EXISTS" + ErrCodeMinerNotRunning = "MINER_NOT_RUNNING" + ErrCodeInstallFailed = "INSTALL_FAILED" + ErrCodeStartFailed = "START_FAILED" + ErrCodeStopFailed = "STOP_FAILED" + ErrCodeInvalidConfig = "INVALID_CONFIG" + ErrCodeInvalidInput = "INVALID_INPUT" + ErrCodeUnsupportedMiner = "UNSUPPORTED_MINER" + ErrCodeNotSupported = "NOT_SUPPORTED" + ErrCodeConnectionFailed = "CONNECTION_FAILED" + ErrCodeServiceUnavailable = "SERVICE_UNAVAILABLE" + ErrCodeTimeout = "TIMEOUT" + ErrCodeDatabaseError = "DATABASE_ERROR" + ErrCodeProfileNotFound = "PROFILE_NOT_FOUND" + ErrCodeProfileExists = "PROFILE_EXISTS" + ErrCodeInternalError = "INTERNAL_ERROR" +) + +// MiningError is a structured error type for the mining package +type MiningError struct { + Code string // Machine-readable error code + Message string // Human-readable message + Details string // Technical details (for debugging) + Suggestion string // What to do next + Retryable bool // Can the client retry? + HTTPStatus int // HTTP status code to return + Cause error // Underlying error +} + +// Error implements the error interface +func (e *MiningError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("%s: %s (%v)", e.Code, e.Message, e.Cause) + } + return fmt.Sprintf("%s: %s", e.Code, e.Message) +} + +// Unwrap returns the underlying error +func (e *MiningError) Unwrap() error { + return e.Cause +} + +// WithCause adds an underlying error +func (e *MiningError) WithCause(err error) *MiningError { + e.Cause = err + return e +} + +// WithDetails adds technical details +func (e *MiningError) WithDetails(details string) *MiningError { + e.Details = details + return e +} + +// WithSuggestion adds a suggestion for the user +func (e *MiningError) WithSuggestion(suggestion string) *MiningError { + e.Suggestion = suggestion + return e +} + +// IsRetryable returns whether the error is retryable +func (e *MiningError) IsRetryable() bool { + return e.Retryable +} + +// StatusCode returns the HTTP status code for this error +func (e *MiningError) StatusCode() int { + if e.HTTPStatus == 0 { + return http.StatusInternalServerError + } + return e.HTTPStatus +} + +// NewMiningError creates a new MiningError +func NewMiningError(code, message string) *MiningError { + return &MiningError{ + Code: code, + Message: message, + HTTPStatus: http.StatusInternalServerError, + } +} + +// Predefined error constructors for common errors + +// ErrMinerNotFound creates a miner not found error +func ErrMinerNotFound(name string) *MiningError { + return &MiningError{ + Code: ErrCodeMinerNotFound, + Message: fmt.Sprintf("miner '%s' not found", name), + Suggestion: "Check that the miner name is correct and that it is running", + Retryable: false, + HTTPStatus: http.StatusNotFound, + } +} + +// ErrMinerExists creates a miner already exists error +func ErrMinerExists(name string) *MiningError { + return &MiningError{ + Code: ErrCodeMinerExists, + Message: fmt.Sprintf("miner '%s' is already running", name), + Suggestion: "Stop the existing miner first or use a different configuration", + Retryable: false, + HTTPStatus: http.StatusConflict, + } +} + +// ErrMinerNotRunning creates a miner not running error +func ErrMinerNotRunning(name string) *MiningError { + return &MiningError{ + Code: ErrCodeMinerNotRunning, + Message: fmt.Sprintf("miner '%s' is not running", name), + Suggestion: "Start the miner first before performing this operation", + Retryable: false, + HTTPStatus: http.StatusBadRequest, + } +} + +// ErrInstallFailed creates an installation failed error +func ErrInstallFailed(minerType string) *MiningError { + return &MiningError{ + Code: ErrCodeInstallFailed, + Message: fmt.Sprintf("failed to install %s", minerType), + Suggestion: "Check your internet connection and try again", + Retryable: true, + HTTPStatus: http.StatusInternalServerError, + } +} + +// ErrStartFailed creates a start failed error +func ErrStartFailed(name string) *MiningError { + return &MiningError{ + Code: ErrCodeStartFailed, + Message: fmt.Sprintf("failed to start miner '%s'", name), + Suggestion: "Check the miner configuration and logs for details", + Retryable: true, + HTTPStatus: http.StatusInternalServerError, + } +} + +// ErrStopFailed creates a stop failed error +func ErrStopFailed(name string) *MiningError { + return &MiningError{ + Code: ErrCodeStopFailed, + Message: fmt.Sprintf("failed to stop miner '%s'", name), + Suggestion: "The miner process may need to be terminated manually", + Retryable: true, + HTTPStatus: http.StatusInternalServerError, + } +} + +// ErrInvalidConfig creates an invalid configuration error +func ErrInvalidConfig(reason string) *MiningError { + return &MiningError{ + Code: ErrCodeInvalidConfig, + Message: fmt.Sprintf("invalid configuration: %s", reason), + Suggestion: "Review the configuration and ensure all required fields are provided", + Retryable: false, + HTTPStatus: http.StatusBadRequest, + } +} + +// ErrUnsupportedMiner creates an unsupported miner type error +func ErrUnsupportedMiner(minerType string) *MiningError { + return &MiningError{ + Code: ErrCodeUnsupportedMiner, + Message: fmt.Sprintf("unsupported miner type: %s", minerType), + Suggestion: "Use one of the supported miner types: xmrig, tt-miner", + Retryable: false, + HTTPStatus: http.StatusBadRequest, + } +} + +// ErrConnectionFailed creates a connection failed error +func ErrConnectionFailed(target string) *MiningError { + return &MiningError{ + Code: ErrCodeConnectionFailed, + Message: fmt.Sprintf("failed to connect to %s", target), + Suggestion: "Check network connectivity and try again", + Retryable: true, + HTTPStatus: http.StatusServiceUnavailable, + } +} + +// ErrTimeout creates a timeout error +func ErrTimeout(operation string) *MiningError { + return &MiningError{ + Code: ErrCodeTimeout, + Message: fmt.Sprintf("operation timed out: %s", operation), + Suggestion: "The operation is taking longer than expected, try again later", + Retryable: true, + HTTPStatus: http.StatusGatewayTimeout, + } +} + +// ErrDatabaseError creates a database error +func ErrDatabaseError(operation string) *MiningError { + return &MiningError{ + Code: ErrCodeDatabaseError, + Message: fmt.Sprintf("database error during %s", operation), + Suggestion: "This may be a temporary issue, try again", + Retryable: true, + HTTPStatus: http.StatusInternalServerError, + } +} + +// ErrProfileNotFound creates a profile not found error +func ErrProfileNotFound(id string) *MiningError { + return &MiningError{ + Code: ErrCodeProfileNotFound, + Message: fmt.Sprintf("profile '%s' not found", id), + Suggestion: "Check that the profile ID is correct", + Retryable: false, + HTTPStatus: http.StatusNotFound, + } +} + +// ErrProfileExists creates a profile already exists error +func ErrProfileExists(name string) *MiningError { + return &MiningError{ + Code: ErrCodeProfileExists, + Message: fmt.Sprintf("profile '%s' already exists", name), + Suggestion: "Use a different name or update the existing profile", + Retryable: false, + HTTPStatus: http.StatusConflict, + } +} + +// ErrInternal creates a generic internal error +func ErrInternal(message string) *MiningError { + return &MiningError{ + Code: ErrCodeInternalError, + Message: message, + Suggestion: "Please report this issue if it persists", + Retryable: true, + HTTPStatus: http.StatusInternalServerError, + } +} diff --git a/pkg/mining/errors_test.go b/pkg/mining/errors_test.go new file mode 100644 index 0000000..06f6245 --- /dev/null +++ b/pkg/mining/errors_test.go @@ -0,0 +1,151 @@ +package mining + +import ( + "errors" + "net/http" + "testing" +) + +func TestMiningError_Error(t *testing.T) { + err := NewMiningError(ErrCodeMinerNotFound, "miner not found") + expected := "MINER_NOT_FOUND: miner not found" + if err.Error() != expected { + t.Errorf("Expected %q, got %q", expected, err.Error()) + } +} + +func TestMiningError_ErrorWithCause(t *testing.T) { + cause := errors.New("underlying error") + err := NewMiningError(ErrCodeStartFailed, "failed to start").WithCause(cause) + + // Should include cause in error message + if err.Cause != cause { + t.Error("Cause was not set") + } + + // Should be unwrappable + if errors.Unwrap(err) != cause { + t.Error("Unwrap did not return cause") + } +} + +func TestMiningError_WithDetails(t *testing.T) { + err := NewMiningError(ErrCodeInvalidConfig, "invalid config"). + WithDetails("port must be between 1024 and 65535") + + if err.Details != "port must be between 1024 and 65535" { + t.Errorf("Details not set correctly: %s", err.Details) + } +} + +func TestMiningError_WithSuggestion(t *testing.T) { + err := NewMiningError(ErrCodeConnectionFailed, "connection failed"). + WithSuggestion("check your network") + + if err.Suggestion != "check your network" { + t.Errorf("Suggestion not set correctly: %s", err.Suggestion) + } +} + +func TestMiningError_StatusCode(t *testing.T) { + tests := []struct { + name string + err *MiningError + expected int + }{ + {"default", NewMiningError("TEST", "test"), http.StatusInternalServerError}, + {"not found", ErrMinerNotFound("test"), http.StatusNotFound}, + {"conflict", ErrMinerExists("test"), http.StatusConflict}, + {"bad request", ErrInvalidConfig("bad"), http.StatusBadRequest}, + {"service unavailable", ErrConnectionFailed("pool"), http.StatusServiceUnavailable}, + {"timeout", ErrTimeout("operation"), http.StatusGatewayTimeout}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.err.StatusCode() != tt.expected { + t.Errorf("Expected status %d, got %d", tt.expected, tt.err.StatusCode()) + } + }) + } +} + +func TestMiningError_IsRetryable(t *testing.T) { + tests := []struct { + name string + err *MiningError + retryable bool + }{ + {"not found", ErrMinerNotFound("test"), false}, + {"exists", ErrMinerExists("test"), false}, + {"invalid config", ErrInvalidConfig("bad"), false}, + {"install failed", ErrInstallFailed("xmrig"), true}, + {"start failed", ErrStartFailed("test"), true}, + {"connection failed", ErrConnectionFailed("pool"), true}, + {"timeout", ErrTimeout("operation"), true}, + {"database error", ErrDatabaseError("query"), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.err.IsRetryable() != tt.retryable { + t.Errorf("Expected retryable=%v, got %v", tt.retryable, tt.err.IsRetryable()) + } + }) + } +} + +func TestPredefinedErrors(t *testing.T) { + tests := []struct { + name string + err *MiningError + code string + }{ + {"ErrMinerNotFound", ErrMinerNotFound("test"), ErrCodeMinerNotFound}, + {"ErrMinerExists", ErrMinerExists("test"), ErrCodeMinerExists}, + {"ErrMinerNotRunning", ErrMinerNotRunning("test"), ErrCodeMinerNotRunning}, + {"ErrInstallFailed", ErrInstallFailed("xmrig"), ErrCodeInstallFailed}, + {"ErrStartFailed", ErrStartFailed("test"), ErrCodeStartFailed}, + {"ErrStopFailed", ErrStopFailed("test"), ErrCodeStopFailed}, + {"ErrInvalidConfig", ErrInvalidConfig("bad port"), ErrCodeInvalidConfig}, + {"ErrUnsupportedMiner", ErrUnsupportedMiner("unknown"), ErrCodeUnsupportedMiner}, + {"ErrConnectionFailed", ErrConnectionFailed("pool:3333"), ErrCodeConnectionFailed}, + {"ErrTimeout", ErrTimeout("GetStats"), ErrCodeTimeout}, + {"ErrDatabaseError", ErrDatabaseError("insert"), ErrCodeDatabaseError}, + {"ErrProfileNotFound", ErrProfileNotFound("abc123"), ErrCodeProfileNotFound}, + {"ErrProfileExists", ErrProfileExists("My Profile"), ErrCodeProfileExists}, + {"ErrInternal", ErrInternal("unexpected error"), ErrCodeInternalError}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.err.Code != tt.code { + t.Errorf("Expected code %s, got %s", tt.code, tt.err.Code) + } + if tt.err.Message == "" { + t.Error("Message should not be empty") + } + }) + } +} + +func TestMiningError_Chaining(t *testing.T) { + cause := errors.New("network timeout") + err := ErrConnectionFailed("pool:3333"). + WithCause(cause). + WithDetails("timeout after 30s"). + WithSuggestion("check firewall settings") + + if err.Code != ErrCodeConnectionFailed { + t.Errorf("Code changed: %s", err.Code) + } + if err.Cause != cause { + t.Error("Cause not set") + } + if err.Details != "timeout after 30s" { + t.Errorf("Details not set: %s", err.Details) + } + if err.Suggestion != "check firewall settings" { + t.Errorf("Suggestion not set: %s", err.Suggestion) + } +} diff --git a/pkg/mining/ratelimiter.go b/pkg/mining/ratelimiter.go new file mode 100644 index 0000000..15a872c --- /dev/null +++ b/pkg/mining/ratelimiter.go @@ -0,0 +1,119 @@ +package mining + +import ( + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +// RateLimiter provides token bucket rate limiting per IP address +type RateLimiter struct { + requestsPerSecond int + burst int + clients map[string]*rateLimitClient + mu sync.RWMutex + stopChan chan struct{} + stopped bool +} + +type rateLimitClient struct { + tokens float64 + lastCheck time.Time +} + +// NewRateLimiter creates a new rate limiter with the specified limits +func NewRateLimiter(requestsPerSecond, burst int) *RateLimiter { + rl := &RateLimiter{ + requestsPerSecond: requestsPerSecond, + burst: burst, + clients: make(map[string]*rateLimitClient), + stopChan: make(chan struct{}), + } + + // Start cleanup goroutine + go rl.cleanupLoop() + + return rl +} + +// cleanupLoop removes stale clients periodically +func (rl *RateLimiter) cleanupLoop() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for { + select { + case <-rl.stopChan: + return + case <-ticker.C: + rl.cleanup() + } + } +} + +// cleanup removes clients that haven't made requests in 5 minutes +func (rl *RateLimiter) cleanup() { + rl.mu.Lock() + defer rl.mu.Unlock() + + for ip, c := range rl.clients { + if time.Since(c.lastCheck) > 5*time.Minute { + delete(rl.clients, ip) + } + } +} + +// Stop stops the rate limiter's cleanup goroutine +func (rl *RateLimiter) Stop() { + rl.mu.Lock() + defer rl.mu.Unlock() + + if !rl.stopped { + close(rl.stopChan) + rl.stopped = true + } +} + +// Middleware returns a Gin middleware handler for rate limiting +func (rl *RateLimiter) Middleware() gin.HandlerFunc { + return func(c *gin.Context) { + ip := c.ClientIP() + + rl.mu.Lock() + cl, exists := rl.clients[ip] + if !exists { + cl = &rateLimitClient{tokens: float64(rl.burst), lastCheck: time.Now()} + rl.clients[ip] = cl + } + + // Token bucket algorithm + now := time.Now() + elapsed := now.Sub(cl.lastCheck).Seconds() + cl.tokens += elapsed * float64(rl.requestsPerSecond) + if cl.tokens > float64(rl.burst) { + cl.tokens = float64(rl.burst) + } + cl.lastCheck = now + + if cl.tokens < 1 { + rl.mu.Unlock() + respondWithError(c, http.StatusTooManyRequests, "RATE_LIMITED", + "too many requests", "rate limit exceeded") + c.Abort() + return + } + + cl.tokens-- + rl.mu.Unlock() + c.Next() + } +} + +// ClientCount returns the number of tracked clients (for testing/monitoring) +func (rl *RateLimiter) ClientCount() int { + rl.mu.RLock() + defer rl.mu.RUnlock() + return len(rl.clients) +} diff --git a/pkg/mining/ratelimiter_test.go b/pkg/mining/ratelimiter_test.go new file mode 100644 index 0000000..9dfa469 --- /dev/null +++ b/pkg/mining/ratelimiter_test.go @@ -0,0 +1,194 @@ +package mining + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" +) + +func TestNewRateLimiter(t *testing.T) { + rl := NewRateLimiter(10, 20) + if rl == nil { + t.Fatal("NewRateLimiter returned nil") + } + defer rl.Stop() + + if rl.requestsPerSecond != 10 { + t.Errorf("Expected requestsPerSecond 10, got %d", rl.requestsPerSecond) + } + if rl.burst != 20 { + t.Errorf("Expected burst 20, got %d", rl.burst) + } +} + +func TestRateLimiterStop(t *testing.T) { + rl := NewRateLimiter(10, 20) + + // Stop should not panic + defer func() { + if r := recover(); r != nil { + t.Errorf("Stop panicked: %v", r) + } + }() + + rl.Stop() + + // Calling Stop again should not panic (idempotent) + rl.Stop() +} + +func TestRateLimiterMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + rl := NewRateLimiter(10, 5) // 10 req/s, burst of 5 + defer rl.Stop() + + router := gin.New() + router.Use(rl.Middleware()) + router.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + // First 5 requests should succeed (burst) + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Request %d: expected 200, got %d", i+1, w.Code) + } + } + + // 6th request should be rate limited + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusTooManyRequests { + t.Errorf("Expected 429 Too Many Requests, got %d", w.Code) + } +} + +func TestRateLimiterDifferentIPs(t *testing.T) { + gin.SetMode(gin.TestMode) + rl := NewRateLimiter(10, 2) // 10 req/s, burst of 2 + defer rl.Stop() + + router := gin.New() + router.Use(rl.Middleware()) + router.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + // Exhaust rate limit for IP1 + for i := 0; i < 2; i++ { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + } + + // IP1 should be rate limited + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + if w.Code != http.StatusTooManyRequests { + t.Errorf("IP1 should be rate limited, got %d", w.Code) + } + + // IP2 should still be able to make requests + req = httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.2:12345" + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("IP2 should not be rate limited, got %d", w.Code) + } +} + +func TestRateLimiterClientCount(t *testing.T) { + rl := NewRateLimiter(10, 5) + defer rl.Stop() + + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(rl.Middleware()) + router.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + // Initial count should be 0 + if count := rl.ClientCount(); count != 0 { + t.Errorf("Expected 0 clients, got %d", count) + } + + // Make a request + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + // Should have 1 client now + if count := rl.ClientCount(); count != 1 { + t.Errorf("Expected 1 client, got %d", count) + } + + // Make request from different IP + req = httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.2:12345" + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + + // Should have 2 clients now + if count := rl.ClientCount(); count != 2 { + t.Errorf("Expected 2 clients, got %d", count) + } +} + +func TestRateLimiterTokenRefill(t *testing.T) { + gin.SetMode(gin.TestMode) + rl := NewRateLimiter(100, 1) // 100 req/s, burst of 1 (refills quickly) + defer rl.Stop() + + router := gin.New() + router.Use(rl.Middleware()) + router.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + // First request succeeds + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("First request should succeed, got %d", w.Code) + } + + // Second request should fail (burst exhausted) + req = httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + if w.Code != http.StatusTooManyRequests { + t.Errorf("Second request should be rate limited, got %d", w.Code) + } + + // Wait for token to refill (at 100 req/s, 1 token takes 10ms) + time.Sleep(20 * time.Millisecond) + + // Third request should succeed (token refilled) + req = httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("Third request should succeed after refill, got %d", w.Code) + } +} diff --git a/pkg/mining/service.go b/pkg/mining/service.go index fd491a4..55410d0 100644 --- a/pkg/mining/service.go +++ b/pkg/mining/service.go @@ -12,7 +12,6 @@ import ( "path/filepath" "runtime" "strings" - "sync" "time" "github.com/Masterminds/semver/v3" @@ -40,6 +39,7 @@ type Service struct { SwaggerInstanceName string APIBasePath string SwaggerUIPath string + rateLimiter *RateLimiter } // APIError represents a structured error response for the API @@ -51,18 +51,7 @@ type APIError struct { Retryable bool `json:"retryable"` // Can the client retry? } -// Error codes for API responses -const ( - ErrCodeMinerNotFound = "MINER_NOT_FOUND" - ErrCodeProfileNotFound = "PROFILE_NOT_FOUND" - ErrCodeInstallFailed = "INSTALL_FAILED" - ErrCodeStartFailed = "START_FAILED" - ErrCodeStopFailed = "STOP_FAILED" - ErrCodeInvalidInput = "INVALID_INPUT" - ErrCodeInternalError = "INTERNAL_ERROR" - ErrCodeNotSupported = "NOT_SUPPORTED" - ErrCodeServiceUnavailable = "SERVICE_UNAVAILABLE" -) +// Error codes are defined in errors.go // respondWithError sends a structured error response func respondWithError(c *gin.Context, status int, code string, message string, details string) { @@ -126,64 +115,6 @@ func generateRequestID() string { return fmt.Sprintf("%d-%x", time.Now().UnixMilli(), b[:4]) } -// rateLimitMiddleware provides basic rate limiting per IP address -func rateLimitMiddleware(requestsPerSecond int, burst int) gin.HandlerFunc { - type client struct { - tokens float64 - lastCheck time.Time - } - - var ( - clients = make(map[string]*client) - mu = &sync.RWMutex{} - ) - - // Cleanup old clients every minute - go func() { - for { - time.Sleep(time.Minute) - mu.Lock() - for ip, c := range clients { - if time.Since(c.lastCheck) > 5*time.Minute { - delete(clients, ip) - } - } - mu.Unlock() - } - }() - - return func(c *gin.Context) { - ip := c.ClientIP() - - mu.Lock() - cl, exists := clients[ip] - if !exists { - cl = &client{tokens: float64(burst), lastCheck: time.Now()} - clients[ip] = cl - } - - // Token bucket algorithm - now := time.Now() - elapsed := now.Sub(cl.lastCheck).Seconds() - cl.tokens += elapsed * float64(requestsPerSecond) - if cl.tokens > float64(burst) { - cl.tokens = float64(burst) - } - cl.lastCheck = now - - if cl.tokens < 1 { - mu.Unlock() - respondWithError(c, http.StatusTooManyRequests, "RATE_LIMITED", - "too many requests", "rate limit exceeded") - c.Abort() - return - } - - cl.tokens-- - mu.Unlock() - c.Next() - } -} // WebSocket upgrader for the events endpoint var wsUpgrader = websocket.Upgrader{ @@ -324,11 +255,22 @@ func (s *Service) InitRouter() { s.Router.Use(requestIDMiddleware()) // Add rate limiting (10 requests/second with burst of 20) - s.Router.Use(rateLimitMiddleware(10, 20)) + s.rateLimiter = NewRateLimiter(10, 20) + s.Router.Use(s.rateLimiter.Middleware()) s.SetupRoutes() } +// Stop gracefully stops the service and cleans up resources +func (s *Service) Stop() { + if s.rateLimiter != nil { + s.rateLimiter.Stop() + } + if s.EventHub != nil { + s.EventHub.Stop() + } +} + // ServiceStartup initializes the router and starts the HTTP server. // For embedding without a standalone server, use InitRouter() instead. func (s *Service) ServiceStartup(ctx context.Context) error {