feat: Add rate limiter with cleanup and custom error types

Rate Limiter:
- Extract rate limiting to pkg/mining/ratelimiter.go with proper lifecycle
- Add Stop() method to gracefully shutdown cleanup goroutine
- Add RateLimiter.Middleware() for Gin integration
- Add ClientCount() for monitoring
- Fix goroutine leak in previous inline implementation

Custom Errors:
- Add pkg/mining/errors.go with MiningError type
- Define error codes: MINER_NOT_FOUND, INSTALL_FAILED, TIMEOUT, etc.
- Add predefined error constructors (ErrMinerNotFound, ErrStartFailed, etc.)
- Support error chaining with WithCause, WithDetails, WithSuggestion
- Include HTTP status codes and retry policies

Service:
- Add Service.Stop() method for graceful cleanup
- Update CLI commands to use context.Background() for Manager methods

Tests:
- Add comprehensive tests for RateLimiter (token bucket, multi-IP, refill)
- Add comprehensive tests for MiningError (codes, status, retryable)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
snider 2025-12-31 10:56:26 +00:00
parent d1417a1a3c
commit 95ae55e4fa
10 changed files with 734 additions and 78 deletions

View file

@ -103,7 +103,7 @@ var serveCmd = &cobra.Command{
Wallet: cmdArgs[2],
LogOutput: true,
}
miner, err := mgr.StartMiner(minerType, config)
miner, err := mgr.StartMiner(context.Background(), minerType, config)
if err != nil {
fmt.Fprintf(os.Stderr, "Error starting miner: %v\n", err)
} else {
@ -137,7 +137,7 @@ var serveCmd = &cobra.Command{
fmt.Println("Error: stop command requires miner name (e.g., 'stop xmrig')")
} else {
minerName := cmdArgs[0]
err := mgr.StopMiner(minerName)
err := mgr.StopMiner(context.Background(), minerName)
if err != nil {
fmt.Fprintf(os.Stderr, "Error stopping miner: %v\n", err)
} else {

View file

@ -115,7 +115,7 @@ Available presets:
// Stop all simulated miners
for _, miner := range mgr.ListMiners() {
mgr.StopMiner(miner.GetName())
mgr.StopMiner(context.Background(), miner.GetName())
}
fmt.Println("Simulation stopped.")

View file

@ -1,6 +1,7 @@
package cmd
import (
"context"
"fmt"
"github.com/Snider/Mining/pkg/mining"
@ -25,7 +26,7 @@ var startCmd = &cobra.Command{
Wallet: minerWallet,
}
miner, err := getManager().StartMiner(minerType, config)
miner, err := getManager().StartMiner(context.Background(), minerType, config)
if err != nil {
return fmt.Errorf("failed to start miner: %w", err)
}

View file

@ -1,6 +1,7 @@
package cmd
import (
"context"
"fmt"
"github.com/spf13/cobra"
@ -16,7 +17,7 @@ var stopCmd = &cobra.Command{
minerName := args[0]
mgr := getManager()
if err := mgr.StopMiner(minerName); err != nil {
if err := mgr.StopMiner(context.Background(), minerName); err != nil {
return fmt.Errorf("failed to stop miner: %w", err)
}

View file

@ -1,6 +1,7 @@
package cmd
import (
"context"
"fmt"
"github.com/spf13/cobra"
@ -17,7 +18,7 @@ var uninstallCmd = &cobra.Command{
manager := getManager() // Assuming getManager() provides the singleton manager instance
fmt.Printf("Uninstalling %s...\n", minerType)
if err := manager.UninstallMiner(minerType); err != nil {
if err := manager.UninstallMiner(context.Background(), minerType); err != nil {
return fmt.Errorf("failed to uninstall miner: %w", err)
}

247
pkg/mining/errors.go Normal file
View file

@ -0,0 +1,247 @@
package mining
import (
"fmt"
"net/http"
)
// Error codes for the mining package
const (
ErrCodeMinerNotFound = "MINER_NOT_FOUND"
ErrCodeMinerExists = "MINER_EXISTS"
ErrCodeMinerNotRunning = "MINER_NOT_RUNNING"
ErrCodeInstallFailed = "INSTALL_FAILED"
ErrCodeStartFailed = "START_FAILED"
ErrCodeStopFailed = "STOP_FAILED"
ErrCodeInvalidConfig = "INVALID_CONFIG"
ErrCodeInvalidInput = "INVALID_INPUT"
ErrCodeUnsupportedMiner = "UNSUPPORTED_MINER"
ErrCodeNotSupported = "NOT_SUPPORTED"
ErrCodeConnectionFailed = "CONNECTION_FAILED"
ErrCodeServiceUnavailable = "SERVICE_UNAVAILABLE"
ErrCodeTimeout = "TIMEOUT"
ErrCodeDatabaseError = "DATABASE_ERROR"
ErrCodeProfileNotFound = "PROFILE_NOT_FOUND"
ErrCodeProfileExists = "PROFILE_EXISTS"
ErrCodeInternalError = "INTERNAL_ERROR"
)
// MiningError is a structured error type for the mining package
type MiningError struct {
Code string // Machine-readable error code
Message string // Human-readable message
Details string // Technical details (for debugging)
Suggestion string // What to do next
Retryable bool // Can the client retry?
HTTPStatus int // HTTP status code to return
Cause error // Underlying error
}
// Error implements the error interface
func (e *MiningError) Error() string {
if e.Cause != nil {
return fmt.Sprintf("%s: %s (%v)", e.Code, e.Message, e.Cause)
}
return fmt.Sprintf("%s: %s", e.Code, e.Message)
}
// Unwrap returns the underlying error
func (e *MiningError) Unwrap() error {
return e.Cause
}
// WithCause adds an underlying error
func (e *MiningError) WithCause(err error) *MiningError {
e.Cause = err
return e
}
// WithDetails adds technical details
func (e *MiningError) WithDetails(details string) *MiningError {
e.Details = details
return e
}
// WithSuggestion adds a suggestion for the user
func (e *MiningError) WithSuggestion(suggestion string) *MiningError {
e.Suggestion = suggestion
return e
}
// IsRetryable returns whether the error is retryable
func (e *MiningError) IsRetryable() bool {
return e.Retryable
}
// StatusCode returns the HTTP status code for this error
func (e *MiningError) StatusCode() int {
if e.HTTPStatus == 0 {
return http.StatusInternalServerError
}
return e.HTTPStatus
}
// NewMiningError creates a new MiningError
func NewMiningError(code, message string) *MiningError {
return &MiningError{
Code: code,
Message: message,
HTTPStatus: http.StatusInternalServerError,
}
}
// Predefined error constructors for common errors
// ErrMinerNotFound creates a miner not found error
func ErrMinerNotFound(name string) *MiningError {
return &MiningError{
Code: ErrCodeMinerNotFound,
Message: fmt.Sprintf("miner '%s' not found", name),
Suggestion: "Check that the miner name is correct and that it is running",
Retryable: false,
HTTPStatus: http.StatusNotFound,
}
}
// ErrMinerExists creates a miner already exists error
func ErrMinerExists(name string) *MiningError {
return &MiningError{
Code: ErrCodeMinerExists,
Message: fmt.Sprintf("miner '%s' is already running", name),
Suggestion: "Stop the existing miner first or use a different configuration",
Retryable: false,
HTTPStatus: http.StatusConflict,
}
}
// ErrMinerNotRunning creates a miner not running error
func ErrMinerNotRunning(name string) *MiningError {
return &MiningError{
Code: ErrCodeMinerNotRunning,
Message: fmt.Sprintf("miner '%s' is not running", name),
Suggestion: "Start the miner first before performing this operation",
Retryable: false,
HTTPStatus: http.StatusBadRequest,
}
}
// ErrInstallFailed creates an installation failed error
func ErrInstallFailed(minerType string) *MiningError {
return &MiningError{
Code: ErrCodeInstallFailed,
Message: fmt.Sprintf("failed to install %s", minerType),
Suggestion: "Check your internet connection and try again",
Retryable: true,
HTTPStatus: http.StatusInternalServerError,
}
}
// ErrStartFailed creates a start failed error
func ErrStartFailed(name string) *MiningError {
return &MiningError{
Code: ErrCodeStartFailed,
Message: fmt.Sprintf("failed to start miner '%s'", name),
Suggestion: "Check the miner configuration and logs for details",
Retryable: true,
HTTPStatus: http.StatusInternalServerError,
}
}
// ErrStopFailed creates a stop failed error
func ErrStopFailed(name string) *MiningError {
return &MiningError{
Code: ErrCodeStopFailed,
Message: fmt.Sprintf("failed to stop miner '%s'", name),
Suggestion: "The miner process may need to be terminated manually",
Retryable: true,
HTTPStatus: http.StatusInternalServerError,
}
}
// ErrInvalidConfig creates an invalid configuration error
func ErrInvalidConfig(reason string) *MiningError {
return &MiningError{
Code: ErrCodeInvalidConfig,
Message: fmt.Sprintf("invalid configuration: %s", reason),
Suggestion: "Review the configuration and ensure all required fields are provided",
Retryable: false,
HTTPStatus: http.StatusBadRequest,
}
}
// ErrUnsupportedMiner creates an unsupported miner type error
func ErrUnsupportedMiner(minerType string) *MiningError {
return &MiningError{
Code: ErrCodeUnsupportedMiner,
Message: fmt.Sprintf("unsupported miner type: %s", minerType),
Suggestion: "Use one of the supported miner types: xmrig, tt-miner",
Retryable: false,
HTTPStatus: http.StatusBadRequest,
}
}
// ErrConnectionFailed creates a connection failed error
func ErrConnectionFailed(target string) *MiningError {
return &MiningError{
Code: ErrCodeConnectionFailed,
Message: fmt.Sprintf("failed to connect to %s", target),
Suggestion: "Check network connectivity and try again",
Retryable: true,
HTTPStatus: http.StatusServiceUnavailable,
}
}
// ErrTimeout creates a timeout error
func ErrTimeout(operation string) *MiningError {
return &MiningError{
Code: ErrCodeTimeout,
Message: fmt.Sprintf("operation timed out: %s", operation),
Suggestion: "The operation is taking longer than expected, try again later",
Retryable: true,
HTTPStatus: http.StatusGatewayTimeout,
}
}
// ErrDatabaseError creates a database error
func ErrDatabaseError(operation string) *MiningError {
return &MiningError{
Code: ErrCodeDatabaseError,
Message: fmt.Sprintf("database error during %s", operation),
Suggestion: "This may be a temporary issue, try again",
Retryable: true,
HTTPStatus: http.StatusInternalServerError,
}
}
// ErrProfileNotFound creates a profile not found error
func ErrProfileNotFound(id string) *MiningError {
return &MiningError{
Code: ErrCodeProfileNotFound,
Message: fmt.Sprintf("profile '%s' not found", id),
Suggestion: "Check that the profile ID is correct",
Retryable: false,
HTTPStatus: http.StatusNotFound,
}
}
// ErrProfileExists creates a profile already exists error
func ErrProfileExists(name string) *MiningError {
return &MiningError{
Code: ErrCodeProfileExists,
Message: fmt.Sprintf("profile '%s' already exists", name),
Suggestion: "Use a different name or update the existing profile",
Retryable: false,
HTTPStatus: http.StatusConflict,
}
}
// ErrInternal creates a generic internal error
func ErrInternal(message string) *MiningError {
return &MiningError{
Code: ErrCodeInternalError,
Message: message,
Suggestion: "Please report this issue if it persists",
Retryable: true,
HTTPStatus: http.StatusInternalServerError,
}
}

151
pkg/mining/errors_test.go Normal file
View file

@ -0,0 +1,151 @@
package mining
import (
"errors"
"net/http"
"testing"
)
func TestMiningError_Error(t *testing.T) {
err := NewMiningError(ErrCodeMinerNotFound, "miner not found")
expected := "MINER_NOT_FOUND: miner not found"
if err.Error() != expected {
t.Errorf("Expected %q, got %q", expected, err.Error())
}
}
func TestMiningError_ErrorWithCause(t *testing.T) {
cause := errors.New("underlying error")
err := NewMiningError(ErrCodeStartFailed, "failed to start").WithCause(cause)
// Should include cause in error message
if err.Cause != cause {
t.Error("Cause was not set")
}
// Should be unwrappable
if errors.Unwrap(err) != cause {
t.Error("Unwrap did not return cause")
}
}
func TestMiningError_WithDetails(t *testing.T) {
err := NewMiningError(ErrCodeInvalidConfig, "invalid config").
WithDetails("port must be between 1024 and 65535")
if err.Details != "port must be between 1024 and 65535" {
t.Errorf("Details not set correctly: %s", err.Details)
}
}
func TestMiningError_WithSuggestion(t *testing.T) {
err := NewMiningError(ErrCodeConnectionFailed, "connection failed").
WithSuggestion("check your network")
if err.Suggestion != "check your network" {
t.Errorf("Suggestion not set correctly: %s", err.Suggestion)
}
}
func TestMiningError_StatusCode(t *testing.T) {
tests := []struct {
name string
err *MiningError
expected int
}{
{"default", NewMiningError("TEST", "test"), http.StatusInternalServerError},
{"not found", ErrMinerNotFound("test"), http.StatusNotFound},
{"conflict", ErrMinerExists("test"), http.StatusConflict},
{"bad request", ErrInvalidConfig("bad"), http.StatusBadRequest},
{"service unavailable", ErrConnectionFailed("pool"), http.StatusServiceUnavailable},
{"timeout", ErrTimeout("operation"), http.StatusGatewayTimeout},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.err.StatusCode() != tt.expected {
t.Errorf("Expected status %d, got %d", tt.expected, tt.err.StatusCode())
}
})
}
}
func TestMiningError_IsRetryable(t *testing.T) {
tests := []struct {
name string
err *MiningError
retryable bool
}{
{"not found", ErrMinerNotFound("test"), false},
{"exists", ErrMinerExists("test"), false},
{"invalid config", ErrInvalidConfig("bad"), false},
{"install failed", ErrInstallFailed("xmrig"), true},
{"start failed", ErrStartFailed("test"), true},
{"connection failed", ErrConnectionFailed("pool"), true},
{"timeout", ErrTimeout("operation"), true},
{"database error", ErrDatabaseError("query"), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.err.IsRetryable() != tt.retryable {
t.Errorf("Expected retryable=%v, got %v", tt.retryable, tt.err.IsRetryable())
}
})
}
}
func TestPredefinedErrors(t *testing.T) {
tests := []struct {
name string
err *MiningError
code string
}{
{"ErrMinerNotFound", ErrMinerNotFound("test"), ErrCodeMinerNotFound},
{"ErrMinerExists", ErrMinerExists("test"), ErrCodeMinerExists},
{"ErrMinerNotRunning", ErrMinerNotRunning("test"), ErrCodeMinerNotRunning},
{"ErrInstallFailed", ErrInstallFailed("xmrig"), ErrCodeInstallFailed},
{"ErrStartFailed", ErrStartFailed("test"), ErrCodeStartFailed},
{"ErrStopFailed", ErrStopFailed("test"), ErrCodeStopFailed},
{"ErrInvalidConfig", ErrInvalidConfig("bad port"), ErrCodeInvalidConfig},
{"ErrUnsupportedMiner", ErrUnsupportedMiner("unknown"), ErrCodeUnsupportedMiner},
{"ErrConnectionFailed", ErrConnectionFailed("pool:3333"), ErrCodeConnectionFailed},
{"ErrTimeout", ErrTimeout("GetStats"), ErrCodeTimeout},
{"ErrDatabaseError", ErrDatabaseError("insert"), ErrCodeDatabaseError},
{"ErrProfileNotFound", ErrProfileNotFound("abc123"), ErrCodeProfileNotFound},
{"ErrProfileExists", ErrProfileExists("My Profile"), ErrCodeProfileExists},
{"ErrInternal", ErrInternal("unexpected error"), ErrCodeInternalError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.err.Code != tt.code {
t.Errorf("Expected code %s, got %s", tt.code, tt.err.Code)
}
if tt.err.Message == "" {
t.Error("Message should not be empty")
}
})
}
}
func TestMiningError_Chaining(t *testing.T) {
cause := errors.New("network timeout")
err := ErrConnectionFailed("pool:3333").
WithCause(cause).
WithDetails("timeout after 30s").
WithSuggestion("check firewall settings")
if err.Code != ErrCodeConnectionFailed {
t.Errorf("Code changed: %s", err.Code)
}
if err.Cause != cause {
t.Error("Cause not set")
}
if err.Details != "timeout after 30s" {
t.Errorf("Details not set: %s", err.Details)
}
if err.Suggestion != "check firewall settings" {
t.Errorf("Suggestion not set: %s", err.Suggestion)
}
}

119
pkg/mining/ratelimiter.go Normal file
View file

@ -0,0 +1,119 @@
package mining
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// RateLimiter provides token bucket rate limiting per IP address
type RateLimiter struct {
requestsPerSecond int
burst int
clients map[string]*rateLimitClient
mu sync.RWMutex
stopChan chan struct{}
stopped bool
}
type rateLimitClient struct {
tokens float64
lastCheck time.Time
}
// NewRateLimiter creates a new rate limiter with the specified limits
func NewRateLimiter(requestsPerSecond, burst int) *RateLimiter {
rl := &RateLimiter{
requestsPerSecond: requestsPerSecond,
burst: burst,
clients: make(map[string]*rateLimitClient),
stopChan: make(chan struct{}),
}
// Start cleanup goroutine
go rl.cleanupLoop()
return rl
}
// cleanupLoop removes stale clients periodically
func (rl *RateLimiter) cleanupLoop() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-rl.stopChan:
return
case <-ticker.C:
rl.cleanup()
}
}
}
// cleanup removes clients that haven't made requests in 5 minutes
func (rl *RateLimiter) cleanup() {
rl.mu.Lock()
defer rl.mu.Unlock()
for ip, c := range rl.clients {
if time.Since(c.lastCheck) > 5*time.Minute {
delete(rl.clients, ip)
}
}
}
// Stop stops the rate limiter's cleanup goroutine
func (rl *RateLimiter) Stop() {
rl.mu.Lock()
defer rl.mu.Unlock()
if !rl.stopped {
close(rl.stopChan)
rl.stopped = true
}
}
// Middleware returns a Gin middleware handler for rate limiting
func (rl *RateLimiter) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
ip := c.ClientIP()
rl.mu.Lock()
cl, exists := rl.clients[ip]
if !exists {
cl = &rateLimitClient{tokens: float64(rl.burst), lastCheck: time.Now()}
rl.clients[ip] = cl
}
// Token bucket algorithm
now := time.Now()
elapsed := now.Sub(cl.lastCheck).Seconds()
cl.tokens += elapsed * float64(rl.requestsPerSecond)
if cl.tokens > float64(rl.burst) {
cl.tokens = float64(rl.burst)
}
cl.lastCheck = now
if cl.tokens < 1 {
rl.mu.Unlock()
respondWithError(c, http.StatusTooManyRequests, "RATE_LIMITED",
"too many requests", "rate limit exceeded")
c.Abort()
return
}
cl.tokens--
rl.mu.Unlock()
c.Next()
}
}
// ClientCount returns the number of tracked clients (for testing/monitoring)
func (rl *RateLimiter) ClientCount() int {
rl.mu.RLock()
defer rl.mu.RUnlock()
return len(rl.clients)
}

View file

@ -0,0 +1,194 @@
package mining
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
)
func TestNewRateLimiter(t *testing.T) {
rl := NewRateLimiter(10, 20)
if rl == nil {
t.Fatal("NewRateLimiter returned nil")
}
defer rl.Stop()
if rl.requestsPerSecond != 10 {
t.Errorf("Expected requestsPerSecond 10, got %d", rl.requestsPerSecond)
}
if rl.burst != 20 {
t.Errorf("Expected burst 20, got %d", rl.burst)
}
}
func TestRateLimiterStop(t *testing.T) {
rl := NewRateLimiter(10, 20)
// Stop should not panic
defer func() {
if r := recover(); r != nil {
t.Errorf("Stop panicked: %v", r)
}
}()
rl.Stop()
// Calling Stop again should not panic (idempotent)
rl.Stop()
}
func TestRateLimiterMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode)
rl := NewRateLimiter(10, 5) // 10 req/s, burst of 5
defer rl.Stop()
router := gin.New()
router.Use(rl.Middleware())
router.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
// First 5 requests should succeed (burst)
for i := 0; i < 5; i++ {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Request %d: expected 200, got %d", i+1, w.Code)
}
}
// 6th request should be rate limited
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("Expected 429 Too Many Requests, got %d", w.Code)
}
}
func TestRateLimiterDifferentIPs(t *testing.T) {
gin.SetMode(gin.TestMode)
rl := NewRateLimiter(10, 2) // 10 req/s, burst of 2
defer rl.Stop()
router := gin.New()
router.Use(rl.Middleware())
router.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
// Exhaust rate limit for IP1
for i := 0; i < 2; i++ {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
// IP1 should be rate limited
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("IP1 should be rate limited, got %d", w.Code)
}
// IP2 should still be able to make requests
req = httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.2:12345"
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("IP2 should not be rate limited, got %d", w.Code)
}
}
func TestRateLimiterClientCount(t *testing.T) {
rl := NewRateLimiter(10, 5)
defer rl.Stop()
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(rl.Middleware())
router.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
// Initial count should be 0
if count := rl.ClientCount(); count != 0 {
t.Errorf("Expected 0 clients, got %d", count)
}
// Make a request
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// Should have 1 client now
if count := rl.ClientCount(); count != 1 {
t.Errorf("Expected 1 client, got %d", count)
}
// Make request from different IP
req = httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.2:12345"
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
// Should have 2 clients now
if count := rl.ClientCount(); count != 2 {
t.Errorf("Expected 2 clients, got %d", count)
}
}
func TestRateLimiterTokenRefill(t *testing.T) {
gin.SetMode(gin.TestMode)
rl := NewRateLimiter(100, 1) // 100 req/s, burst of 1 (refills quickly)
defer rl.Stop()
router := gin.New()
router.Use(rl.Middleware())
router.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
// First request succeeds
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("First request should succeed, got %d", w.Code)
}
// Second request should fail (burst exhausted)
req = httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("Second request should be rate limited, got %d", w.Code)
}
// Wait for token to refill (at 100 req/s, 1 token takes 10ms)
time.Sleep(20 * time.Millisecond)
// Third request should succeed (token refilled)
req = httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Third request should succeed after refill, got %d", w.Code)
}
}

View file

@ -12,7 +12,6 @@ import (
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"github.com/Masterminds/semver/v3"
@ -40,6 +39,7 @@ type Service struct {
SwaggerInstanceName string
APIBasePath string
SwaggerUIPath string
rateLimiter *RateLimiter
}
// APIError represents a structured error response for the API
@ -51,18 +51,7 @@ type APIError struct {
Retryable bool `json:"retryable"` // Can the client retry?
}
// Error codes for API responses
const (
ErrCodeMinerNotFound = "MINER_NOT_FOUND"
ErrCodeProfileNotFound = "PROFILE_NOT_FOUND"
ErrCodeInstallFailed = "INSTALL_FAILED"
ErrCodeStartFailed = "START_FAILED"
ErrCodeStopFailed = "STOP_FAILED"
ErrCodeInvalidInput = "INVALID_INPUT"
ErrCodeInternalError = "INTERNAL_ERROR"
ErrCodeNotSupported = "NOT_SUPPORTED"
ErrCodeServiceUnavailable = "SERVICE_UNAVAILABLE"
)
// Error codes are defined in errors.go
// respondWithError sends a structured error response
func respondWithError(c *gin.Context, status int, code string, message string, details string) {
@ -126,64 +115,6 @@ func generateRequestID() string {
return fmt.Sprintf("%d-%x", time.Now().UnixMilli(), b[:4])
}
// rateLimitMiddleware provides basic rate limiting per IP address
func rateLimitMiddleware(requestsPerSecond int, burst int) gin.HandlerFunc {
type client struct {
tokens float64
lastCheck time.Time
}
var (
clients = make(map[string]*client)
mu = &sync.RWMutex{}
)
// Cleanup old clients every minute
go func() {
for {
time.Sleep(time.Minute)
mu.Lock()
for ip, c := range clients {
if time.Since(c.lastCheck) > 5*time.Minute {
delete(clients, ip)
}
}
mu.Unlock()
}
}()
return func(c *gin.Context) {
ip := c.ClientIP()
mu.Lock()
cl, exists := clients[ip]
if !exists {
cl = &client{tokens: float64(burst), lastCheck: time.Now()}
clients[ip] = cl
}
// Token bucket algorithm
now := time.Now()
elapsed := now.Sub(cl.lastCheck).Seconds()
cl.tokens += elapsed * float64(requestsPerSecond)
if cl.tokens > float64(burst) {
cl.tokens = float64(burst)
}
cl.lastCheck = now
if cl.tokens < 1 {
mu.Unlock()
respondWithError(c, http.StatusTooManyRequests, "RATE_LIMITED",
"too many requests", "rate limit exceeded")
c.Abort()
return
}
cl.tokens--
mu.Unlock()
c.Next()
}
}
// WebSocket upgrader for the events endpoint
var wsUpgrader = websocket.Upgrader{
@ -324,11 +255,22 @@ func (s *Service) InitRouter() {
s.Router.Use(requestIDMiddleware())
// Add rate limiting (10 requests/second with burst of 20)
s.Router.Use(rateLimitMiddleware(10, 20))
s.rateLimiter = NewRateLimiter(10, 20)
s.Router.Use(s.rateLimiter.Middleware())
s.SetupRoutes()
}
// Stop gracefully stops the service and cleans up resources
func (s *Service) Stop() {
if s.rateLimiter != nil {
s.rateLimiter.Stop()
}
if s.EventHub != nil {
s.EventHub.Stop()
}
}
// ServiceStartup initializes the router and starts the HTTP server.
// For embedding without a standalone server, use InitRouter() instead.
func (s *Service) ServiceStartup(ctx context.Context) error {