feat: Add rate limiter with cleanup and custom error types
Rate Limiter: - Extract rate limiting to pkg/mining/ratelimiter.go with proper lifecycle - Add Stop() method to gracefully shutdown cleanup goroutine - Add RateLimiter.Middleware() for Gin integration - Add ClientCount() for monitoring - Fix goroutine leak in previous inline implementation Custom Errors: - Add pkg/mining/errors.go with MiningError type - Define error codes: MINER_NOT_FOUND, INSTALL_FAILED, TIMEOUT, etc. - Add predefined error constructors (ErrMinerNotFound, ErrStartFailed, etc.) - Support error chaining with WithCause, WithDetails, WithSuggestion - Include HTTP status codes and retry policies Service: - Add Service.Stop() method for graceful cleanup - Update CLI commands to use context.Background() for Manager methods Tests: - Add comprehensive tests for RateLimiter (token bucket, multi-IP, refill) - Add comprehensive tests for MiningError (codes, status, retryable) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
d1417a1a3c
commit
95ae55e4fa
10 changed files with 734 additions and 78 deletions
|
|
@ -103,7 +103,7 @@ var serveCmd = &cobra.Command{
|
|||
Wallet: cmdArgs[2],
|
||||
LogOutput: true,
|
||||
}
|
||||
miner, err := mgr.StartMiner(minerType, config)
|
||||
miner, err := mgr.StartMiner(context.Background(), minerType, config)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error starting miner: %v\n", err)
|
||||
} else {
|
||||
|
|
@ -137,7 +137,7 @@ var serveCmd = &cobra.Command{
|
|||
fmt.Println("Error: stop command requires miner name (e.g., 'stop xmrig')")
|
||||
} else {
|
||||
minerName := cmdArgs[0]
|
||||
err := mgr.StopMiner(minerName)
|
||||
err := mgr.StopMiner(context.Background(), minerName)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error stopping miner: %v\n", err)
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ Available presets:
|
|||
|
||||
// Stop all simulated miners
|
||||
for _, miner := range mgr.ListMiners() {
|
||||
mgr.StopMiner(miner.GetName())
|
||||
mgr.StopMiner(context.Background(), miner.GetName())
|
||||
}
|
||||
|
||||
fmt.Println("Simulation stopped.")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/Snider/Mining/pkg/mining"
|
||||
|
|
@ -25,7 +26,7 @@ var startCmd = &cobra.Command{
|
|||
Wallet: minerWallet,
|
||||
}
|
||||
|
||||
miner, err := getManager().StartMiner(minerType, config)
|
||||
miner, err := getManager().StartMiner(context.Background(), minerType, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start miner: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
|
@ -16,7 +17,7 @@ var stopCmd = &cobra.Command{
|
|||
minerName := args[0]
|
||||
mgr := getManager()
|
||||
|
||||
if err := mgr.StopMiner(minerName); err != nil {
|
||||
if err := mgr.StopMiner(context.Background(), minerName); err != nil {
|
||||
return fmt.Errorf("failed to stop miner: %w", err)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
|
@ -17,7 +18,7 @@ var uninstallCmd = &cobra.Command{
|
|||
manager := getManager() // Assuming getManager() provides the singleton manager instance
|
||||
|
||||
fmt.Printf("Uninstalling %s...\n", minerType)
|
||||
if err := manager.UninstallMiner(minerType); err != nil {
|
||||
if err := manager.UninstallMiner(context.Background(), minerType); err != nil {
|
||||
return fmt.Errorf("failed to uninstall miner: %w", err)
|
||||
}
|
||||
|
||||
|
|
|
|||
247
pkg/mining/errors.go
Normal file
247
pkg/mining/errors.go
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
package mining
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Error codes for the mining package
|
||||
const (
|
||||
ErrCodeMinerNotFound = "MINER_NOT_FOUND"
|
||||
ErrCodeMinerExists = "MINER_EXISTS"
|
||||
ErrCodeMinerNotRunning = "MINER_NOT_RUNNING"
|
||||
ErrCodeInstallFailed = "INSTALL_FAILED"
|
||||
ErrCodeStartFailed = "START_FAILED"
|
||||
ErrCodeStopFailed = "STOP_FAILED"
|
||||
ErrCodeInvalidConfig = "INVALID_CONFIG"
|
||||
ErrCodeInvalidInput = "INVALID_INPUT"
|
||||
ErrCodeUnsupportedMiner = "UNSUPPORTED_MINER"
|
||||
ErrCodeNotSupported = "NOT_SUPPORTED"
|
||||
ErrCodeConnectionFailed = "CONNECTION_FAILED"
|
||||
ErrCodeServiceUnavailable = "SERVICE_UNAVAILABLE"
|
||||
ErrCodeTimeout = "TIMEOUT"
|
||||
ErrCodeDatabaseError = "DATABASE_ERROR"
|
||||
ErrCodeProfileNotFound = "PROFILE_NOT_FOUND"
|
||||
ErrCodeProfileExists = "PROFILE_EXISTS"
|
||||
ErrCodeInternalError = "INTERNAL_ERROR"
|
||||
)
|
||||
|
||||
// MiningError is a structured error type for the mining package
|
||||
type MiningError struct {
|
||||
Code string // Machine-readable error code
|
||||
Message string // Human-readable message
|
||||
Details string // Technical details (for debugging)
|
||||
Suggestion string // What to do next
|
||||
Retryable bool // Can the client retry?
|
||||
HTTPStatus int // HTTP status code to return
|
||||
Cause error // Underlying error
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *MiningError) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("%s: %s (%v)", e.Code, e.Message, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error
|
||||
func (e *MiningError) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// WithCause adds an underlying error
|
||||
func (e *MiningError) WithCause(err error) *MiningError {
|
||||
e.Cause = err
|
||||
return e
|
||||
}
|
||||
|
||||
// WithDetails adds technical details
|
||||
func (e *MiningError) WithDetails(details string) *MiningError {
|
||||
e.Details = details
|
||||
return e
|
||||
}
|
||||
|
||||
// WithSuggestion adds a suggestion for the user
|
||||
func (e *MiningError) WithSuggestion(suggestion string) *MiningError {
|
||||
e.Suggestion = suggestion
|
||||
return e
|
||||
}
|
||||
|
||||
// IsRetryable returns whether the error is retryable
|
||||
func (e *MiningError) IsRetryable() bool {
|
||||
return e.Retryable
|
||||
}
|
||||
|
||||
// StatusCode returns the HTTP status code for this error
|
||||
func (e *MiningError) StatusCode() int {
|
||||
if e.HTTPStatus == 0 {
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
return e.HTTPStatus
|
||||
}
|
||||
|
||||
// NewMiningError creates a new MiningError
|
||||
func NewMiningError(code, message string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
|
||||
// Predefined error constructors for common errors
|
||||
|
||||
// ErrMinerNotFound creates a miner not found error
|
||||
func ErrMinerNotFound(name string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: ErrCodeMinerNotFound,
|
||||
Message: fmt.Sprintf("miner '%s' not found", name),
|
||||
Suggestion: "Check that the miner name is correct and that it is running",
|
||||
Retryable: false,
|
||||
HTTPStatus: http.StatusNotFound,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrMinerExists creates a miner already exists error
|
||||
func ErrMinerExists(name string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: ErrCodeMinerExists,
|
||||
Message: fmt.Sprintf("miner '%s' is already running", name),
|
||||
Suggestion: "Stop the existing miner first or use a different configuration",
|
||||
Retryable: false,
|
||||
HTTPStatus: http.StatusConflict,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrMinerNotRunning creates a miner not running error
|
||||
func ErrMinerNotRunning(name string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: ErrCodeMinerNotRunning,
|
||||
Message: fmt.Sprintf("miner '%s' is not running", name),
|
||||
Suggestion: "Start the miner first before performing this operation",
|
||||
Retryable: false,
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrInstallFailed creates an installation failed error
|
||||
func ErrInstallFailed(minerType string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: ErrCodeInstallFailed,
|
||||
Message: fmt.Sprintf("failed to install %s", minerType),
|
||||
Suggestion: "Check your internet connection and try again",
|
||||
Retryable: true,
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrStartFailed creates a start failed error
|
||||
func ErrStartFailed(name string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: ErrCodeStartFailed,
|
||||
Message: fmt.Sprintf("failed to start miner '%s'", name),
|
||||
Suggestion: "Check the miner configuration and logs for details",
|
||||
Retryable: true,
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrStopFailed creates a stop failed error
|
||||
func ErrStopFailed(name string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: ErrCodeStopFailed,
|
||||
Message: fmt.Sprintf("failed to stop miner '%s'", name),
|
||||
Suggestion: "The miner process may need to be terminated manually",
|
||||
Retryable: true,
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrInvalidConfig creates an invalid configuration error
|
||||
func ErrInvalidConfig(reason string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: ErrCodeInvalidConfig,
|
||||
Message: fmt.Sprintf("invalid configuration: %s", reason),
|
||||
Suggestion: "Review the configuration and ensure all required fields are provided",
|
||||
Retryable: false,
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrUnsupportedMiner creates an unsupported miner type error
|
||||
func ErrUnsupportedMiner(minerType string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: ErrCodeUnsupportedMiner,
|
||||
Message: fmt.Sprintf("unsupported miner type: %s", minerType),
|
||||
Suggestion: "Use one of the supported miner types: xmrig, tt-miner",
|
||||
Retryable: false,
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrConnectionFailed creates a connection failed error
|
||||
func ErrConnectionFailed(target string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: ErrCodeConnectionFailed,
|
||||
Message: fmt.Sprintf("failed to connect to %s", target),
|
||||
Suggestion: "Check network connectivity and try again",
|
||||
Retryable: true,
|
||||
HTTPStatus: http.StatusServiceUnavailable,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrTimeout creates a timeout error
|
||||
func ErrTimeout(operation string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: ErrCodeTimeout,
|
||||
Message: fmt.Sprintf("operation timed out: %s", operation),
|
||||
Suggestion: "The operation is taking longer than expected, try again later",
|
||||
Retryable: true,
|
||||
HTTPStatus: http.StatusGatewayTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrDatabaseError creates a database error
|
||||
func ErrDatabaseError(operation string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: ErrCodeDatabaseError,
|
||||
Message: fmt.Sprintf("database error during %s", operation),
|
||||
Suggestion: "This may be a temporary issue, try again",
|
||||
Retryable: true,
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrProfileNotFound creates a profile not found error
|
||||
func ErrProfileNotFound(id string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: ErrCodeProfileNotFound,
|
||||
Message: fmt.Sprintf("profile '%s' not found", id),
|
||||
Suggestion: "Check that the profile ID is correct",
|
||||
Retryable: false,
|
||||
HTTPStatus: http.StatusNotFound,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrProfileExists creates a profile already exists error
|
||||
func ErrProfileExists(name string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: ErrCodeProfileExists,
|
||||
Message: fmt.Sprintf("profile '%s' already exists", name),
|
||||
Suggestion: "Use a different name or update the existing profile",
|
||||
Retryable: false,
|
||||
HTTPStatus: http.StatusConflict,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrInternal creates a generic internal error
|
||||
func ErrInternal(message string) *MiningError {
|
||||
return &MiningError{
|
||||
Code: ErrCodeInternalError,
|
||||
Message: message,
|
||||
Suggestion: "Please report this issue if it persists",
|
||||
Retryable: true,
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
151
pkg/mining/errors_test.go
Normal file
151
pkg/mining/errors_test.go
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
package mining
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMiningError_Error(t *testing.T) {
|
||||
err := NewMiningError(ErrCodeMinerNotFound, "miner not found")
|
||||
expected := "MINER_NOT_FOUND: miner not found"
|
||||
if err.Error() != expected {
|
||||
t.Errorf("Expected %q, got %q", expected, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiningError_ErrorWithCause(t *testing.T) {
|
||||
cause := errors.New("underlying error")
|
||||
err := NewMiningError(ErrCodeStartFailed, "failed to start").WithCause(cause)
|
||||
|
||||
// Should include cause in error message
|
||||
if err.Cause != cause {
|
||||
t.Error("Cause was not set")
|
||||
}
|
||||
|
||||
// Should be unwrappable
|
||||
if errors.Unwrap(err) != cause {
|
||||
t.Error("Unwrap did not return cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiningError_WithDetails(t *testing.T) {
|
||||
err := NewMiningError(ErrCodeInvalidConfig, "invalid config").
|
||||
WithDetails("port must be between 1024 and 65535")
|
||||
|
||||
if err.Details != "port must be between 1024 and 65535" {
|
||||
t.Errorf("Details not set correctly: %s", err.Details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiningError_WithSuggestion(t *testing.T) {
|
||||
err := NewMiningError(ErrCodeConnectionFailed, "connection failed").
|
||||
WithSuggestion("check your network")
|
||||
|
||||
if err.Suggestion != "check your network" {
|
||||
t.Errorf("Suggestion not set correctly: %s", err.Suggestion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiningError_StatusCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *MiningError
|
||||
expected int
|
||||
}{
|
||||
{"default", NewMiningError("TEST", "test"), http.StatusInternalServerError},
|
||||
{"not found", ErrMinerNotFound("test"), http.StatusNotFound},
|
||||
{"conflict", ErrMinerExists("test"), http.StatusConflict},
|
||||
{"bad request", ErrInvalidConfig("bad"), http.StatusBadRequest},
|
||||
{"service unavailable", ErrConnectionFailed("pool"), http.StatusServiceUnavailable},
|
||||
{"timeout", ErrTimeout("operation"), http.StatusGatewayTimeout},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.err.StatusCode() != tt.expected {
|
||||
t.Errorf("Expected status %d, got %d", tt.expected, tt.err.StatusCode())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiningError_IsRetryable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *MiningError
|
||||
retryable bool
|
||||
}{
|
||||
{"not found", ErrMinerNotFound("test"), false},
|
||||
{"exists", ErrMinerExists("test"), false},
|
||||
{"invalid config", ErrInvalidConfig("bad"), false},
|
||||
{"install failed", ErrInstallFailed("xmrig"), true},
|
||||
{"start failed", ErrStartFailed("test"), true},
|
||||
{"connection failed", ErrConnectionFailed("pool"), true},
|
||||
{"timeout", ErrTimeout("operation"), true},
|
||||
{"database error", ErrDatabaseError("query"), true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.err.IsRetryable() != tt.retryable {
|
||||
t.Errorf("Expected retryable=%v, got %v", tt.retryable, tt.err.IsRetryable())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPredefinedErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *MiningError
|
||||
code string
|
||||
}{
|
||||
{"ErrMinerNotFound", ErrMinerNotFound("test"), ErrCodeMinerNotFound},
|
||||
{"ErrMinerExists", ErrMinerExists("test"), ErrCodeMinerExists},
|
||||
{"ErrMinerNotRunning", ErrMinerNotRunning("test"), ErrCodeMinerNotRunning},
|
||||
{"ErrInstallFailed", ErrInstallFailed("xmrig"), ErrCodeInstallFailed},
|
||||
{"ErrStartFailed", ErrStartFailed("test"), ErrCodeStartFailed},
|
||||
{"ErrStopFailed", ErrStopFailed("test"), ErrCodeStopFailed},
|
||||
{"ErrInvalidConfig", ErrInvalidConfig("bad port"), ErrCodeInvalidConfig},
|
||||
{"ErrUnsupportedMiner", ErrUnsupportedMiner("unknown"), ErrCodeUnsupportedMiner},
|
||||
{"ErrConnectionFailed", ErrConnectionFailed("pool:3333"), ErrCodeConnectionFailed},
|
||||
{"ErrTimeout", ErrTimeout("GetStats"), ErrCodeTimeout},
|
||||
{"ErrDatabaseError", ErrDatabaseError("insert"), ErrCodeDatabaseError},
|
||||
{"ErrProfileNotFound", ErrProfileNotFound("abc123"), ErrCodeProfileNotFound},
|
||||
{"ErrProfileExists", ErrProfileExists("My Profile"), ErrCodeProfileExists},
|
||||
{"ErrInternal", ErrInternal("unexpected error"), ErrCodeInternalError},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.err.Code != tt.code {
|
||||
t.Errorf("Expected code %s, got %s", tt.code, tt.err.Code)
|
||||
}
|
||||
if tt.err.Message == "" {
|
||||
t.Error("Message should not be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiningError_Chaining(t *testing.T) {
|
||||
cause := errors.New("network timeout")
|
||||
err := ErrConnectionFailed("pool:3333").
|
||||
WithCause(cause).
|
||||
WithDetails("timeout after 30s").
|
||||
WithSuggestion("check firewall settings")
|
||||
|
||||
if err.Code != ErrCodeConnectionFailed {
|
||||
t.Errorf("Code changed: %s", err.Code)
|
||||
}
|
||||
if err.Cause != cause {
|
||||
t.Error("Cause not set")
|
||||
}
|
||||
if err.Details != "timeout after 30s" {
|
||||
t.Errorf("Details not set: %s", err.Details)
|
||||
}
|
||||
if err.Suggestion != "check firewall settings" {
|
||||
t.Errorf("Suggestion not set: %s", err.Suggestion)
|
||||
}
|
||||
}
|
||||
119
pkg/mining/ratelimiter.go
Normal file
119
pkg/mining/ratelimiter.go
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
package mining
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RateLimiter provides token bucket rate limiting per IP address
|
||||
type RateLimiter struct {
|
||||
requestsPerSecond int
|
||||
burst int
|
||||
clients map[string]*rateLimitClient
|
||||
mu sync.RWMutex
|
||||
stopChan chan struct{}
|
||||
stopped bool
|
||||
}
|
||||
|
||||
type rateLimitClient struct {
|
||||
tokens float64
|
||||
lastCheck time.Time
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter with the specified limits
|
||||
func NewRateLimiter(requestsPerSecond, burst int) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
requestsPerSecond: requestsPerSecond,
|
||||
burst: burst,
|
||||
clients: make(map[string]*rateLimitClient),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go rl.cleanupLoop()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// cleanupLoop removes stale clients periodically
|
||||
func (rl *RateLimiter) cleanupLoop() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-rl.stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
rl.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes clients that haven't made requests in 5 minutes
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
for ip, c := range rl.clients {
|
||||
if time.Since(c.lastCheck) > 5*time.Minute {
|
||||
delete(rl.clients, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the rate limiter's cleanup goroutine
|
||||
func (rl *RateLimiter) Stop() {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
if !rl.stopped {
|
||||
close(rl.stopChan)
|
||||
rl.stopped = true
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware returns a Gin middleware handler for rate limiting
|
||||
func (rl *RateLimiter) Middleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
|
||||
rl.mu.Lock()
|
||||
cl, exists := rl.clients[ip]
|
||||
if !exists {
|
||||
cl = &rateLimitClient{tokens: float64(rl.burst), lastCheck: time.Now()}
|
||||
rl.clients[ip] = cl
|
||||
}
|
||||
|
||||
// Token bucket algorithm
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(cl.lastCheck).Seconds()
|
||||
cl.tokens += elapsed * float64(rl.requestsPerSecond)
|
||||
if cl.tokens > float64(rl.burst) {
|
||||
cl.tokens = float64(rl.burst)
|
||||
}
|
||||
cl.lastCheck = now
|
||||
|
||||
if cl.tokens < 1 {
|
||||
rl.mu.Unlock()
|
||||
respondWithError(c, http.StatusTooManyRequests, "RATE_LIMITED",
|
||||
"too many requests", "rate limit exceeded")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
cl.tokens--
|
||||
rl.mu.Unlock()
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ClientCount returns the number of tracked clients (for testing/monitoring)
|
||||
func (rl *RateLimiter) ClientCount() int {
|
||||
rl.mu.RLock()
|
||||
defer rl.mu.RUnlock()
|
||||
return len(rl.clients)
|
||||
}
|
||||
194
pkg/mining/ratelimiter_test.go
Normal file
194
pkg/mining/ratelimiter_test.go
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
package mining
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestNewRateLimiter(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 20)
|
||||
if rl == nil {
|
||||
t.Fatal("NewRateLimiter returned nil")
|
||||
}
|
||||
defer rl.Stop()
|
||||
|
||||
if rl.requestsPerSecond != 10 {
|
||||
t.Errorf("Expected requestsPerSecond 10, got %d", rl.requestsPerSecond)
|
||||
}
|
||||
if rl.burst != 20 {
|
||||
t.Errorf("Expected burst 20, got %d", rl.burst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterStop(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 20)
|
||||
|
||||
// Stop should not panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Stop panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
rl.Stop()
|
||||
|
||||
// Calling Stop again should not panic (idempotent)
|
||||
rl.Stop()
|
||||
}
|
||||
|
||||
func TestRateLimiterMiddleware(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rl := NewRateLimiter(10, 5) // 10 req/s, burst of 5
|
||||
defer rl.Stop()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(rl.Middleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
// First 5 requests should succeed (burst)
|
||||
for i := 0; i < 5; i++ {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Request %d: expected 200, got %d", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// 6th request should be rate limited
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("Expected 429 Too Many Requests, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterDifferentIPs(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rl := NewRateLimiter(10, 2) // 10 req/s, burst of 2
|
||||
defer rl.Stop()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(rl.Middleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
// Exhaust rate limit for IP1
|
||||
for i := 0; i < 2; i++ {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// IP1 should be rate limited
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("IP1 should be rate limited, got %d", w.Code)
|
||||
}
|
||||
|
||||
// IP2 should still be able to make requests
|
||||
req = httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.2:12345"
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("IP2 should not be rate limited, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterClientCount(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 5)
|
||||
defer rl.Stop()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(rl.Middleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
// Initial count should be 0
|
||||
if count := rl.ClientCount(); count != 0 {
|
||||
t.Errorf("Expected 0 clients, got %d", count)
|
||||
}
|
||||
|
||||
// Make a request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Should have 1 client now
|
||||
if count := rl.ClientCount(); count != 1 {
|
||||
t.Errorf("Expected 1 client, got %d", count)
|
||||
}
|
||||
|
||||
// Make request from different IP
|
||||
req = httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.2:12345"
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Should have 2 clients now
|
||||
if count := rl.ClientCount(); count != 2 {
|
||||
t.Errorf("Expected 2 clients, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterTokenRefill(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rl := NewRateLimiter(100, 1) // 100 req/s, burst of 1 (refills quickly)
|
||||
defer rl.Stop()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(rl.Middleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
// First request succeeds
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("First request should succeed, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Second request should fail (burst exhausted)
|
||||
req = httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("Second request should be rate limited, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Wait for token to refill (at 100 req/s, 1 token takes 10ms)
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Third request should succeed (token refilled)
|
||||
req = httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Third request should succeed after refill, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
|
@ -12,7 +12,6 @@ import (
|
|||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
|
|
@ -40,6 +39,7 @@ type Service struct {
|
|||
SwaggerInstanceName string
|
||||
APIBasePath string
|
||||
SwaggerUIPath string
|
||||
rateLimiter *RateLimiter
|
||||
}
|
||||
|
||||
// APIError represents a structured error response for the API
|
||||
|
|
@ -51,18 +51,7 @@ type APIError struct {
|
|||
Retryable bool `json:"retryable"` // Can the client retry?
|
||||
}
|
||||
|
||||
// Error codes for API responses
|
||||
const (
|
||||
ErrCodeMinerNotFound = "MINER_NOT_FOUND"
|
||||
ErrCodeProfileNotFound = "PROFILE_NOT_FOUND"
|
||||
ErrCodeInstallFailed = "INSTALL_FAILED"
|
||||
ErrCodeStartFailed = "START_FAILED"
|
||||
ErrCodeStopFailed = "STOP_FAILED"
|
||||
ErrCodeInvalidInput = "INVALID_INPUT"
|
||||
ErrCodeInternalError = "INTERNAL_ERROR"
|
||||
ErrCodeNotSupported = "NOT_SUPPORTED"
|
||||
ErrCodeServiceUnavailable = "SERVICE_UNAVAILABLE"
|
||||
)
|
||||
// Error codes are defined in errors.go
|
||||
|
||||
// respondWithError sends a structured error response
|
||||
func respondWithError(c *gin.Context, status int, code string, message string, details string) {
|
||||
|
|
@ -126,64 +115,6 @@ func generateRequestID() string {
|
|||
return fmt.Sprintf("%d-%x", time.Now().UnixMilli(), b[:4])
|
||||
}
|
||||
|
||||
// rateLimitMiddleware provides basic rate limiting per IP address
|
||||
func rateLimitMiddleware(requestsPerSecond int, burst int) gin.HandlerFunc {
|
||||
type client struct {
|
||||
tokens float64
|
||||
lastCheck time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
clients = make(map[string]*client)
|
||||
mu = &sync.RWMutex{}
|
||||
)
|
||||
|
||||
// Cleanup old clients every minute
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(time.Minute)
|
||||
mu.Lock()
|
||||
for ip, c := range clients {
|
||||
if time.Since(c.lastCheck) > 5*time.Minute {
|
||||
delete(clients, ip)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
|
||||
mu.Lock()
|
||||
cl, exists := clients[ip]
|
||||
if !exists {
|
||||
cl = &client{tokens: float64(burst), lastCheck: time.Now()}
|
||||
clients[ip] = cl
|
||||
}
|
||||
|
||||
// Token bucket algorithm
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(cl.lastCheck).Seconds()
|
||||
cl.tokens += elapsed * float64(requestsPerSecond)
|
||||
if cl.tokens > float64(burst) {
|
||||
cl.tokens = float64(burst)
|
||||
}
|
||||
cl.lastCheck = now
|
||||
|
||||
if cl.tokens < 1 {
|
||||
mu.Unlock()
|
||||
respondWithError(c, http.StatusTooManyRequests, "RATE_LIMITED",
|
||||
"too many requests", "rate limit exceeded")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
cl.tokens--
|
||||
mu.Unlock()
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// WebSocket upgrader for the events endpoint
|
||||
var wsUpgrader = websocket.Upgrader{
|
||||
|
|
@ -324,11 +255,22 @@ func (s *Service) InitRouter() {
|
|||
s.Router.Use(requestIDMiddleware())
|
||||
|
||||
// Add rate limiting (10 requests/second with burst of 20)
|
||||
s.Router.Use(rateLimitMiddleware(10, 20))
|
||||
s.rateLimiter = NewRateLimiter(10, 20)
|
||||
s.Router.Use(s.rateLimiter.Middleware())
|
||||
|
||||
s.SetupRoutes()
|
||||
}
|
||||
|
||||
// Stop gracefully stops the service and cleans up resources
|
||||
func (s *Service) Stop() {
|
||||
if s.rateLimiter != nil {
|
||||
s.rateLimiter.Stop()
|
||||
}
|
||||
if s.EventHub != nil {
|
||||
s.EventHub.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// ServiceStartup initializes the router and starts the HTTP server.
|
||||
// For embedding without a standalone server, use InitRouter() instead.
|
||||
func (s *Service) ServiceStartup(ctx context.Context) error {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue