refactor: Add reliability fixes and architecture improvements

Reliability fixes:
- Fix HTTP response body drainage in xmrig, ttminer, miner
- Fix database init race condition (nil before close)
- Fix empty minerType bug in P2P StartMinerPayload
- Add context timeout to InsertHashratePoint (5s default)

Architecture improvements:
- Extract HashrateStore interface with DefaultStore/NopStore
- Create ServiceContainer for centralized initialization
- Extract protocol response handler (ValidateResponse, ParseResponse)
- Create generic FileRepository[T] with atomic writes

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
snider 2025-12-31 12:43:46 +00:00
parent 89f74aebff
commit 34f860309f
25 changed files with 1553 additions and 88 deletions

View file

@ -82,10 +82,11 @@ var remoteStartCmd = &cobra.Command{
Long: `Start a miner on a remote peer using a profile.`,
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
profileID, _ := cmd.Flags().GetString("profile")
if profileID == "" {
return fmt.Errorf("--profile is required")
minerType, _ := cmd.Flags().GetString("type")
if minerType == "" {
return fmt.Errorf("--type is required (e.g., xmrig, tt-miner)")
}
profileID, _ := cmd.Flags().GetString("profile")
peerID := args[0]
peer := findPeerByPartialID(peerID)
@ -98,8 +99,8 @@ var remoteStartCmd = &cobra.Command{
return err
}
fmt.Printf("Starting miner on %s with profile %s...\n", peer.Name, profileID)
if err := ctrl.StartRemoteMiner(peer.ID, profileID, nil); err != nil {
fmt.Printf("Starting %s miner on %s with profile %s...\n", minerType, peer.Name, profileID)
if err := ctrl.StartRemoteMiner(peer.ID, minerType, profileID, nil); err != nil {
return fmt.Errorf("failed to start miner: %w", err)
}
@ -298,6 +299,7 @@ func init() {
// remote start
remoteCmd.AddCommand(remoteStartCmd)
remoteStartCmd.Flags().StringP("profile", "p", "", "Profile ID to use for starting the miner")
remoteStartCmd.Flags().StringP("type", "t", "", "Miner type (e.g., xmrig, tt-miner)")
// remote stop
remoteCmd.AddCommand(remoteStopCmd)

View file

@ -77,8 +77,10 @@ func Initialize(cfg Config) error {
// Create tables
if err := createTables(); err != nil {
db.Close()
// Nil out global before closing to prevent use of closed connection
closingDB := db
db = nil
closingDB.Close()
return fmt.Errorf("failed to create tables: %w", err)
}

View file

@ -50,7 +50,7 @@ func TestConcurrentHashrateInserts(t *testing.T) {
Timestamp: time.Now().Add(time.Duration(-j) * time.Second),
Hashrate: 1000 + minerIndex*100 + j,
}
err := InsertHashratePoint(minerName, minerType, point, ResolutionHigh)
err := InsertHashratePoint(nil, minerName, minerType, point, ResolutionHigh)
if err != nil {
t.Errorf("Insert error for %s: %v", minerName, err)
}
@ -95,7 +95,7 @@ func TestConcurrentInsertAndQuery(t *testing.T) {
Timestamp: time.Now(),
Hashrate: 1000 + i,
}
InsertHashratePoint("concurrent-test", "xmrig", point, ResolutionHigh)
InsertHashratePoint(nil, "concurrent-test", "xmrig", point, ResolutionHigh)
time.Sleep(time.Millisecond)
}
}
@ -149,13 +149,13 @@ func TestConcurrentInsertAndCleanup(t *testing.T) {
Timestamp: time.Now().AddDate(0, 0, -10), // 10 days old
Hashrate: 500 + i,
}
InsertHashratePoint("cleanup-test", "xmrig", oldPoint, ResolutionHigh)
InsertHashratePoint(nil, "cleanup-test", "xmrig", oldPoint, ResolutionHigh)
newPoint := HashratePoint{
Timestamp: time.Now(),
Hashrate: 1000 + i,
}
InsertHashratePoint("cleanup-test", "xmrig", newPoint, ResolutionHigh)
InsertHashratePoint(nil, "cleanup-test", "xmrig", newPoint, ResolutionHigh)
time.Sleep(time.Millisecond)
}
}
@ -197,7 +197,7 @@ func TestConcurrentStats(t *testing.T) {
Timestamp: time.Now().Add(time.Duration(-i) * time.Second),
Hashrate: 1000 + i*10,
}
InsertHashratePoint(minerName, "xmrig", point, ResolutionHigh)
InsertHashratePoint(nil, minerName, "xmrig", point, ResolutionHigh)
}
var wg sync.WaitGroup
@ -238,7 +238,7 @@ func TestConcurrentGetAllStats(t *testing.T) {
Timestamp: time.Now().Add(time.Duration(-i) * time.Second),
Hashrate: 1000 + m*100 + i,
}
InsertHashratePoint(minerName, "xmrig", point, ResolutionHigh)
InsertHashratePoint(nil, minerName, "xmrig", point, ResolutionHigh)
}
}
@ -267,7 +267,7 @@ func TestConcurrentGetAllStats(t *testing.T) {
Timestamp: time.Now(),
Hashrate: 2000 + i,
}
InsertHashratePoint("all-stats-new", "xmrig", point, ResolutionHigh)
InsertHashratePoint(nil, "all-stats-new", "xmrig", point, ResolutionHigh)
}
}()

View file

@ -77,7 +77,7 @@ func TestHashrateStorage(t *testing.T) {
}
for _, p := range points {
if err := InsertHashratePoint(minerName, minerType, p, ResolutionHigh); err != nil {
if err := InsertHashratePoint(nil, minerName, minerType, p, ResolutionHigh); err != nil {
t.Fatalf("Failed to store hashrate point: %v", err)
}
}
@ -109,7 +109,7 @@ func TestGetHashrateStats(t *testing.T) {
}
for _, p := range points {
if err := InsertHashratePoint(minerName, minerType, p, ResolutionHigh); err != nil {
if err := InsertHashratePoint(nil, minerName, minerType, p, ResolutionHigh); err != nil {
t.Fatalf("Failed to store point: %v", err)
}
}
@ -175,13 +175,13 @@ func TestCleanupRetention(t *testing.T) {
}
// Insert all points
if err := InsertHashratePoint(minerName, minerType, oldPoint, ResolutionHigh); err != nil {
if err := InsertHashratePoint(nil, minerName, minerType, oldPoint, ResolutionHigh); err != nil {
t.Fatalf("Failed to insert old point: %v", err)
}
if err := InsertHashratePoint(minerName, minerType, midPoint, ResolutionHigh); err != nil {
if err := InsertHashratePoint(nil, minerName, minerType, midPoint, ResolutionHigh); err != nil {
t.Fatalf("Failed to insert mid point: %v", err)
}
if err := InsertHashratePoint(minerName, minerType, newPoint, ResolutionHigh); err != nil {
if err := InsertHashratePoint(nil, minerName, minerType, newPoint, ResolutionHigh); err != nil {
t.Fatalf("Failed to insert new point: %v", err)
}
@ -238,7 +238,7 @@ func TestGetHashrateHistoryTimeRange(t *testing.T) {
Timestamp: now.Add(offset),
Hashrate: 1000 + i*100,
}
if err := InsertHashratePoint(minerName, minerType, point, ResolutionHigh); err != nil {
if err := InsertHashratePoint(nil, minerName, minerType, point, ResolutionHigh); err != nil {
t.Fatalf("Failed to insert point: %v", err)
}
}
@ -291,7 +291,7 @@ func TestMultipleMinerStats(t *testing.T) {
Timestamp: now.Add(time.Duration(-i) * time.Minute),
Hashrate: hr,
}
if err := InsertHashratePoint(m.name, "xmrig", point, ResolutionHigh); err != nil {
if err := InsertHashratePoint(nil, m.name, "xmrig", point, ResolutionHigh); err != nil {
t.Fatalf("Failed to insert point for %s: %v", m.name, err)
}
}

View file

@ -1,6 +1,7 @@
package database
import (
"context"
"fmt"
"time"
@ -47,8 +48,12 @@ type HashratePoint struct {
Hashrate int `json:"hashrate"`
}
// InsertHashratePoint stores a hashrate measurement in the database
func InsertHashratePoint(minerName, minerType string, point HashratePoint, resolution Resolution) error {
// dbInsertTimeout is the maximum time to wait for a database insert operation
const dbInsertTimeout = 5 * time.Second
// InsertHashratePoint stores a hashrate measurement in the database.
// If ctx is nil, a default timeout context will be used.
func InsertHashratePoint(ctx context.Context, minerName, minerType string, point HashratePoint, resolution Resolution) error {
dbMu.RLock()
defer dbMu.RUnlock()
@ -56,7 +61,14 @@ func InsertHashratePoint(minerName, minerType string, point HashratePoint, resol
return nil // DB not enabled, silently skip
}
_, err := db.Exec(`
// Use provided context or create one with default timeout
if ctx == nil {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), dbInsertTimeout)
defer cancel()
}
_, err := db.ExecContext(ctx, `
INSERT INTO hashrate_history (miner_name, miner_type, timestamp, hashrate, resolution)
VALUES (?, ?, ?, ?, ?)
`, minerName, minerType, point.Timestamp, point.Hashrate, string(resolution))

95
pkg/database/interface.go Normal file
View file

@ -0,0 +1,95 @@
package database
import (
"context"
"time"
)
// HashrateStore defines the interface for hashrate data persistence.
// This interface allows for dependency injection and easier testing.
type HashrateStore interface {
// InsertHashratePoint stores a hashrate measurement.
// If ctx is nil, a default timeout will be used.
InsertHashratePoint(ctx context.Context, minerName, minerType string, point HashratePoint, resolution Resolution) error
// GetHashrateHistory retrieves hashrate history for a miner within a time range.
GetHashrateHistory(minerName string, resolution Resolution, since, until time.Time) ([]HashratePoint, error)
// GetHashrateStats retrieves aggregated statistics for a specific miner.
GetHashrateStats(minerName string) (*HashrateStats, error)
// GetAllMinerStats retrieves statistics for all miners.
GetAllMinerStats() ([]HashrateStats, error)
// Cleanup removes old data based on retention settings.
Cleanup(retentionDays int) error
// Close closes the store and releases resources.
Close() error
}
// defaultStore implements HashrateStore using the global database connection.
// This provides backward compatibility while allowing interface-based usage.
type defaultStore struct{}
// DefaultStore returns a HashrateStore that uses the global database connection.
// This is useful for gradual migration from package-level functions to interface-based usage.
func DefaultStore() HashrateStore {
return &defaultStore{}
}
func (s *defaultStore) InsertHashratePoint(ctx context.Context, minerName, minerType string, point HashratePoint, resolution Resolution) error {
return InsertHashratePoint(ctx, minerName, minerType, point, resolution)
}
func (s *defaultStore) GetHashrateHistory(minerName string, resolution Resolution, since, until time.Time) ([]HashratePoint, error) {
return GetHashrateHistory(minerName, resolution, since, until)
}
func (s *defaultStore) GetHashrateStats(minerName string) (*HashrateStats, error) {
return GetHashrateStats(minerName)
}
func (s *defaultStore) GetAllMinerStats() ([]HashrateStats, error) {
return GetAllMinerStats()
}
func (s *defaultStore) Cleanup(retentionDays int) error {
return Cleanup(retentionDays)
}
func (s *defaultStore) Close() error {
return Close()
}
// NopStore returns a HashrateStore that does nothing.
// Useful for testing or when database is disabled.
func NopStore() HashrateStore {
return &nopStore{}
}
type nopStore struct{}
func (s *nopStore) InsertHashratePoint(ctx context.Context, minerName, minerType string, point HashratePoint, resolution Resolution) error {
return nil
}
func (s *nopStore) GetHashrateHistory(minerName string, resolution Resolution, since, until time.Time) ([]HashratePoint, error) {
return nil, nil
}
func (s *nopStore) GetHashrateStats(minerName string) (*HashrateStats, error) {
return nil, nil
}
func (s *nopStore) GetAllMinerStats() ([]HashrateStats, error) {
return nil, nil
}
func (s *nopStore) Cleanup(retentionDays int) error {
return nil
}
func (s *nopStore) Close() error {
return nil
}

View file

@ -0,0 +1,135 @@
package database
import (
"context"
"testing"
"time"
)
func TestDefaultStore(t *testing.T) {
cleanup := setupTestDB(t)
defer cleanup()
store := DefaultStore()
// Test InsertHashratePoint
point := HashratePoint{
Timestamp: time.Now(),
Hashrate: 1500,
}
if err := store.InsertHashratePoint(nil, "interface-test", "xmrig", point, ResolutionHigh); err != nil {
t.Fatalf("InsertHashratePoint failed: %v", err)
}
// Test GetHashrateHistory
history, err := store.GetHashrateHistory("interface-test", ResolutionHigh, time.Now().Add(-time.Hour), time.Now().Add(time.Hour))
if err != nil {
t.Fatalf("GetHashrateHistory failed: %v", err)
}
if len(history) != 1 {
t.Errorf("Expected 1 point, got %d", len(history))
}
// Test GetHashrateStats
stats, err := store.GetHashrateStats("interface-test")
if err != nil {
t.Fatalf("GetHashrateStats failed: %v", err)
}
if stats == nil {
t.Fatal("Expected non-nil stats")
}
if stats.TotalPoints != 1 {
t.Errorf("Expected 1 total point, got %d", stats.TotalPoints)
}
// Test GetAllMinerStats
allStats, err := store.GetAllMinerStats()
if err != nil {
t.Fatalf("GetAllMinerStats failed: %v", err)
}
if len(allStats) != 1 {
t.Errorf("Expected 1 miner in stats, got %d", len(allStats))
}
// Test Cleanup
if err := store.Cleanup(30); err != nil {
t.Fatalf("Cleanup failed: %v", err)
}
}
func TestDefaultStore_WithContext(t *testing.T) {
cleanup := setupTestDB(t)
defer cleanup()
store := DefaultStore()
ctx := context.Background()
point := HashratePoint{
Timestamp: time.Now(),
Hashrate: 2000,
}
if err := store.InsertHashratePoint(ctx, "ctx-test", "xmrig", point, ResolutionHigh); err != nil {
t.Fatalf("InsertHashratePoint with context failed: %v", err)
}
history, err := store.GetHashrateHistory("ctx-test", ResolutionHigh, time.Now().Add(-time.Hour), time.Now().Add(time.Hour))
if err != nil {
t.Fatalf("GetHashrateHistory failed: %v", err)
}
if len(history) != 1 {
t.Errorf("Expected 1 point, got %d", len(history))
}
}
func TestNopStore(t *testing.T) {
store := NopStore()
// All operations should succeed without error
point := HashratePoint{
Timestamp: time.Now(),
Hashrate: 1000,
}
if err := store.InsertHashratePoint(nil, "test", "xmrig", point, ResolutionHigh); err != nil {
t.Errorf("NopStore InsertHashratePoint should not error: %v", err)
}
history, err := store.GetHashrateHistory("test", ResolutionHigh, time.Now().Add(-time.Hour), time.Now())
if err != nil {
t.Errorf("NopStore GetHashrateHistory should not error: %v", err)
}
if history != nil {
t.Errorf("NopStore GetHashrateHistory should return nil, got %v", history)
}
stats, err := store.GetHashrateStats("test")
if err != nil {
t.Errorf("NopStore GetHashrateStats should not error: %v", err)
}
if stats != nil {
t.Errorf("NopStore GetHashrateStats should return nil, got %v", stats)
}
allStats, err := store.GetAllMinerStats()
if err != nil {
t.Errorf("NopStore GetAllMinerStats should not error: %v", err)
}
if allStats != nil {
t.Errorf("NopStore GetAllMinerStats should return nil, got %v", allStats)
}
if err := store.Cleanup(30); err != nil {
t.Errorf("NopStore Cleanup should not error: %v", err)
}
if err := store.Close(); err != nil {
t.Errorf("NopStore Close should not error: %v", err)
}
}
// TestInterfaceCompatibility ensures all implementations satisfy HashrateStore
func TestInterfaceCompatibility(t *testing.T) {
var _ HashrateStore = DefaultStore()
var _ HashrateStore = NopStore()
var _ HashrateStore = &defaultStore{}
var _ HashrateStore = &nopStore{}
}

260
pkg/mining/container.go Normal file
View file

@ -0,0 +1,260 @@
package mining
import (
"context"
"fmt"
"sync"
"github.com/Snider/Mining/pkg/database"
"github.com/Snider/Mining/pkg/logging"
)
// ContainerConfig holds configuration for the service container.
type ContainerConfig struct {
// Database configuration
Database database.Config
// ListenAddr is the address to listen on (e.g., ":9090")
ListenAddr string
// DisplayAddr is the address shown in Swagger docs
DisplayAddr string
// SwaggerNamespace is the API path prefix
SwaggerNamespace string
// SimulationMode enables simulation mode for testing
SimulationMode bool
}
// DefaultContainerConfig returns sensible defaults for the container.
func DefaultContainerConfig() ContainerConfig {
return ContainerConfig{
Database: database.Config{
Enabled: true,
RetentionDays: 30,
},
ListenAddr: ":9090",
DisplayAddr: "localhost:9090",
SwaggerNamespace: "/api/v1/mining",
SimulationMode: false,
}
}
// Container manages the lifecycle of all services.
// It provides centralized initialization, dependency injection, and graceful shutdown.
type Container struct {
config ContainerConfig
mu sync.RWMutex
// Core services
manager ManagerInterface
profileManager *ProfileManager
nodeService *NodeService
eventHub *EventHub
service *Service
// Database store (interface for testing)
hashrateStore database.HashrateStore
// Initialization state
initialized bool
transportStarted bool
shutdownCh chan struct{}
}
// NewContainer creates a new service container with the given configuration.
func NewContainer(config ContainerConfig) *Container {
return &Container{
config: config,
shutdownCh: make(chan struct{}),
}
}
// Initialize sets up all services in the correct order.
// This should be called before Start().
func (c *Container) Initialize(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.initialized {
return fmt.Errorf("container already initialized")
}
// 1. Initialize database (optional)
if c.config.Database.Enabled {
if err := database.Initialize(c.config.Database); err != nil {
return fmt.Errorf("failed to initialize database: %w", err)
}
c.hashrateStore = database.DefaultStore()
logging.Info("database initialized", logging.Fields{"retention_days": c.config.Database.RetentionDays})
} else {
c.hashrateStore = database.NopStore()
logging.Info("database disabled, using no-op store", nil)
}
// 2. Initialize profile manager
var err error
c.profileManager, err = NewProfileManager()
if err != nil {
return fmt.Errorf("failed to initialize profile manager: %w", err)
}
// 3. Initialize miner manager
if c.config.SimulationMode {
c.manager = NewManagerForSimulation()
} else {
c.manager = NewManager()
}
// 4. Initialize node service (optional - P2P features)
c.nodeService, err = NewNodeService()
if err != nil {
logging.Warn("node service unavailable", logging.Fields{"error": err})
// Continue without node service - P2P features will be unavailable
}
// 5. Initialize event hub for WebSocket
c.eventHub = NewEventHub()
// Wire up event hub to manager
if mgr, ok := c.manager.(*Manager); ok {
mgr.SetEventHub(c.eventHub)
}
c.initialized = true
logging.Info("service container initialized", nil)
return nil
}
// Start begins all background services.
func (c *Container) Start(ctx context.Context) error {
c.mu.RLock()
defer c.mu.RUnlock()
if !c.initialized {
return fmt.Errorf("container not initialized")
}
// Start event hub
go c.eventHub.Run()
// Start node transport if available
if c.nodeService != nil {
if err := c.nodeService.StartTransport(); err != nil {
logging.Warn("failed to start node transport", logging.Fields{"error": err})
} else {
c.transportStarted = true
}
}
logging.Info("service container started", nil)
return nil
}
// Shutdown gracefully stops all services in reverse order.
func (c *Container) Shutdown(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.initialized {
return nil
}
logging.Info("shutting down service container", nil)
var errs []error
// 1. Stop service (HTTP server)
if c.service != nil {
// Service shutdown is handled externally
}
// 2. Stop node transport (only if it was started)
if c.nodeService != nil && c.transportStarted {
if err := c.nodeService.StopTransport(); err != nil {
errs = append(errs, fmt.Errorf("node transport: %w", err))
}
c.transportStarted = false
}
// 3. Stop event hub
if c.eventHub != nil {
c.eventHub.Stop()
}
// 4. Stop miner manager
if mgr, ok := c.manager.(*Manager); ok {
mgr.Stop()
}
// 5. Close database
if err := database.Close(); err != nil {
errs = append(errs, fmt.Errorf("database: %w", err))
}
c.initialized = false
close(c.shutdownCh)
if len(errs) > 0 {
return fmt.Errorf("shutdown errors: %v", errs)
}
logging.Info("service container shutdown complete", nil)
return nil
}
// Manager returns the miner manager.
func (c *Container) Manager() ManagerInterface {
c.mu.RLock()
defer c.mu.RUnlock()
return c.manager
}
// ProfileManager returns the profile manager.
func (c *Container) ProfileManager() *ProfileManager {
c.mu.RLock()
defer c.mu.RUnlock()
return c.profileManager
}
// NodeService returns the node service (may be nil if P2P is unavailable).
func (c *Container) NodeService() *NodeService {
c.mu.RLock()
defer c.mu.RUnlock()
return c.nodeService
}
// EventHub returns the event hub for WebSocket connections.
func (c *Container) EventHub() *EventHub {
c.mu.RLock()
defer c.mu.RUnlock()
return c.eventHub
}
// HashrateStore returns the hashrate store interface.
func (c *Container) HashrateStore() database.HashrateStore {
c.mu.RLock()
defer c.mu.RUnlock()
return c.hashrateStore
}
// SetHashrateStore allows injecting a custom hashrate store (useful for testing).
func (c *Container) SetHashrateStore(store database.HashrateStore) {
c.mu.Lock()
defer c.mu.Unlock()
c.hashrateStore = store
}
// ShutdownCh returns a channel that's closed when shutdown is complete.
func (c *Container) ShutdownCh() <-chan struct{} {
return c.shutdownCh
}
// IsInitialized returns true if the container has been initialized.
func (c *Container) IsInitialized() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.initialized
}

View file

@ -0,0 +1,238 @@
package mining
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/Snider/Mining/pkg/database"
)
func setupContainerTestEnv(t *testing.T) func() {
tmpDir := t.TempDir()
os.Setenv("XDG_CONFIG_HOME", filepath.Join(tmpDir, "config"))
os.Setenv("XDG_DATA_HOME", filepath.Join(tmpDir, "data"))
return func() {
os.Unsetenv("XDG_CONFIG_HOME")
os.Unsetenv("XDG_DATA_HOME")
}
}
func TestNewContainer(t *testing.T) {
config := DefaultContainerConfig()
container := NewContainer(config)
if container == nil {
t.Fatal("NewContainer returned nil")
}
if container.IsInitialized() {
t.Error("Container should not be initialized before Initialize() is called")
}
}
func TestDefaultContainerConfig(t *testing.T) {
config := DefaultContainerConfig()
if !config.Database.Enabled {
t.Error("Database should be enabled by default")
}
if config.Database.RetentionDays != 30 {
t.Errorf("Expected 30 retention days, got %d", config.Database.RetentionDays)
}
if config.ListenAddr != ":9090" {
t.Errorf("Expected :9090, got %s", config.ListenAddr)
}
if config.SimulationMode {
t.Error("SimulationMode should be false by default")
}
}
func TestContainer_Initialize(t *testing.T) {
cleanup := setupContainerTestEnv(t)
defer cleanup()
config := DefaultContainerConfig()
config.Database.Enabled = true
config.Database.Path = filepath.Join(t.TempDir(), "test.db")
config.SimulationMode = true // Use simulation mode for faster tests
container := NewContainer(config)
ctx := context.Background()
if err := container.Initialize(ctx); err != nil {
t.Fatalf("Initialize failed: %v", err)
}
if !container.IsInitialized() {
t.Error("Container should be initialized after Initialize()")
}
// Verify services are available
if container.Manager() == nil {
t.Error("Manager should not be nil after initialization")
}
if container.ProfileManager() == nil {
t.Error("ProfileManager should not be nil after initialization")
}
if container.EventHub() == nil {
t.Error("EventHub should not be nil after initialization")
}
if container.HashrateStore() == nil {
t.Error("HashrateStore should not be nil after initialization")
}
// Cleanup
if err := container.Shutdown(ctx); err != nil {
t.Errorf("Shutdown failed: %v", err)
}
}
func TestContainer_InitializeTwice(t *testing.T) {
cleanup := setupContainerTestEnv(t)
defer cleanup()
config := DefaultContainerConfig()
config.Database.Enabled = false
config.SimulationMode = true
container := NewContainer(config)
ctx := context.Background()
if err := container.Initialize(ctx); err != nil {
t.Fatalf("First Initialize failed: %v", err)
}
// Second initialization should fail
if err := container.Initialize(ctx); err == nil {
t.Error("Second Initialize should fail")
}
container.Shutdown(ctx)
}
func TestContainer_DatabaseDisabled(t *testing.T) {
cleanup := setupContainerTestEnv(t)
defer cleanup()
config := DefaultContainerConfig()
config.Database.Enabled = false
config.SimulationMode = true
container := NewContainer(config)
ctx := context.Background()
if err := container.Initialize(ctx); err != nil {
t.Fatalf("Initialize failed: %v", err)
}
// Should use NopStore when database is disabled
store := container.HashrateStore()
if store == nil {
t.Fatal("HashrateStore should not be nil")
}
// NopStore should accept inserts without error
point := database.HashratePoint{
Timestamp: time.Now(),
Hashrate: 1000,
}
if err := store.InsertHashratePoint(nil, "test", "xmrig", point, database.ResolutionHigh); err != nil {
t.Errorf("NopStore insert should not fail: %v", err)
}
container.Shutdown(ctx)
}
func TestContainer_SetHashrateStore(t *testing.T) {
cleanup := setupContainerTestEnv(t)
defer cleanup()
config := DefaultContainerConfig()
config.Database.Enabled = false
config.SimulationMode = true
container := NewContainer(config)
ctx := context.Background()
if err := container.Initialize(ctx); err != nil {
t.Fatalf("Initialize failed: %v", err)
}
// Inject custom store
customStore := database.NopStore()
container.SetHashrateStore(customStore)
if container.HashrateStore() != customStore {
t.Error("SetHashrateStore should update the store")
}
container.Shutdown(ctx)
}
func TestContainer_StartWithoutInitialize(t *testing.T) {
config := DefaultContainerConfig()
container := NewContainer(config)
ctx := context.Background()
if err := container.Start(ctx); err == nil {
t.Error("Start should fail if Initialize was not called")
}
}
func TestContainer_ShutdownWithoutInitialize(t *testing.T) {
config := DefaultContainerConfig()
container := NewContainer(config)
ctx := context.Background()
// Shutdown on uninitialized container should not error
if err := container.Shutdown(ctx); err != nil {
t.Errorf("Shutdown on uninitialized container should not error: %v", err)
}
}
func TestContainer_ShutdownChannel(t *testing.T) {
cleanup := setupContainerTestEnv(t)
defer cleanup()
config := DefaultContainerConfig()
config.Database.Enabled = false
config.SimulationMode = true
container := NewContainer(config)
ctx := context.Background()
if err := container.Initialize(ctx); err != nil {
t.Fatalf("Initialize failed: %v", err)
}
shutdownCh := container.ShutdownCh()
// Channel should be open before shutdown
select {
case <-shutdownCh:
t.Error("ShutdownCh should not be closed before Shutdown()")
default:
// Expected
}
if err := container.Shutdown(ctx); err != nil {
t.Errorf("Shutdown failed: %v", err)
}
// Channel should be closed after shutdown
select {
case <-shutdownCh:
// Expected
case <-time.After(time.Second):
t.Error("ShutdownCh should be closed after Shutdown()")
}
}

View file

@ -614,7 +614,8 @@ func (m *Manager) collectSingleMinerStats(miner Miner, minerType string, now tim
Timestamp: point.Timestamp,
Hashrate: point.Hashrate,
}
if err := database.InsertHashratePoint(minerName, minerType, dbPoint, database.ResolutionHigh); err != nil {
// Use nil context to let InsertHashratePoint use its default timeout
if err := database.InsertHashratePoint(nil, minerName, minerType, dbPoint, database.ResolutionHigh); err != nil {
logging.Warn("failed to persist hashrate", logging.Fields{"miner": minerName, "error": err})
}
}

View file

@ -246,6 +246,7 @@ func (b *BaseMiner) InstallFromURL(url string) error {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse
return fmt.Errorf("failed to download release: unexpected status code %d", resp.StatusCode)
}

View file

@ -334,7 +334,8 @@ func (ns *NodeService) handlePeerStats(c *gin.Context) {
// RemoteStartRequest is the request body for starting a remote miner.
type RemoteStartRequest struct {
ProfileID string `json:"profileId" binding:"required"`
MinerType string `json:"minerType" binding:"required"`
ProfileID string `json:"profileId,omitempty"`
Config json.RawMessage `json:"config,omitempty"`
}
@ -356,7 +357,7 @@ func (ns *NodeService) handleRemoteStart(c *gin.Context) {
return
}
if err := ns.controller.StartRemoteMiner(peerID, req.ProfileID, req.Config); err != nil {
if err := ns.controller.StartRemoteMiner(peerID, req.MinerType, req.ProfileID, req.Config); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

196
pkg/mining/repository.go Normal file
View file

@ -0,0 +1,196 @@
package mining
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
)
// Repository defines a generic interface for data persistence.
// Implementations can store data in files, databases, etc.
type Repository[T any] interface {
// Load reads data from the repository
Load() (T, error)
// Save writes data to the repository
Save(data T) error
// Update atomically loads, modifies, and saves data
Update(fn func(*T) error) error
}
// FileRepository provides atomic file-based persistence for JSON data.
// It uses atomic writes (temp file + rename) to prevent corruption.
type FileRepository[T any] struct {
mu sync.RWMutex
path string
defaults func() T
}
// FileRepositoryOption configures a FileRepository.
type FileRepositoryOption[T any] func(*FileRepository[T])
// WithDefaults sets the default value factory for when the file doesn't exist.
func WithDefaults[T any](fn func() T) FileRepositoryOption[T] {
return func(r *FileRepository[T]) {
r.defaults = fn
}
}
// NewFileRepository creates a new file-based repository.
func NewFileRepository[T any](path string, opts ...FileRepositoryOption[T]) *FileRepository[T] {
r := &FileRepository[T]{
path: path,
}
for _, opt := range opts {
opt(r)
}
return r
}
// Load reads and deserializes data from the file.
// Returns defaults if file doesn't exist.
func (r *FileRepository[T]) Load() (T, error) {
r.mu.RLock()
defer r.mu.RUnlock()
var result T
data, err := os.ReadFile(r.path)
if err != nil {
if os.IsNotExist(err) {
if r.defaults != nil {
return r.defaults(), nil
}
return result, nil
}
return result, fmt.Errorf("failed to read file: %w", err)
}
if err := json.Unmarshal(data, &result); err != nil {
return result, fmt.Errorf("failed to unmarshal data: %w", err)
}
return result, nil
}
// Save serializes and writes data to the file atomically.
func (r *FileRepository[T]) Save(data T) error {
r.mu.Lock()
defer r.mu.Unlock()
return r.saveUnlocked(data)
}
// saveUnlocked saves data without acquiring the lock (caller must hold lock).
func (r *FileRepository[T]) saveUnlocked(data T) error {
dir := filepath.Dir(r.path)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
jsonData, err := json.MarshalIndent(data, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal data: %w", err)
}
// Atomic write: write to temp file, sync, then rename
tmpFile, err := os.CreateTemp(dir, "repo-*.tmp")
if err != nil {
return fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
// Clean up temp file on error
success := false
defer func() {
if !success {
os.Remove(tmpPath)
}
}()
if _, err := tmpFile.Write(jsonData); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to write temp file: %w", err)
}
if err := tmpFile.Sync(); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to sync temp file: %w", err)
}
if err := tmpFile.Close(); err != nil {
return fmt.Errorf("failed to close temp file: %w", err)
}
if err := os.Chmod(tmpPath, 0600); err != nil {
return fmt.Errorf("failed to set file permissions: %w", err)
}
if err := os.Rename(tmpPath, r.path); err != nil {
return fmt.Errorf("failed to rename temp file: %w", err)
}
success = true
return nil
}
// Update atomically loads, modifies, and saves data.
// The modification function receives a pointer to the data.
func (r *FileRepository[T]) Update(fn func(*T) error) error {
r.mu.Lock()
defer r.mu.Unlock()
// Load current data
var data T
fileData, err := os.ReadFile(r.path)
if err != nil {
if os.IsNotExist(err) {
if r.defaults != nil {
data = r.defaults()
}
} else {
return fmt.Errorf("failed to read file: %w", err)
}
} else {
if err := json.Unmarshal(fileData, &data); err != nil {
return fmt.Errorf("failed to unmarshal data: %w", err)
}
}
// Apply modification
if err := fn(&data); err != nil {
return err
}
// Save atomically
return r.saveUnlocked(data)
}
// Path returns the file path of this repository.
func (r *FileRepository[T]) Path() string {
return r.path
}
// Exists returns true if the repository file exists.
func (r *FileRepository[T]) Exists() bool {
r.mu.RLock()
defer r.mu.RUnlock()
_, err := os.Stat(r.path)
return err == nil
}
// Delete removes the repository file.
func (r *FileRepository[T]) Delete() error {
r.mu.Lock()
defer r.mu.Unlock()
err := os.Remove(r.path)
if os.IsNotExist(err) {
return nil
}
return err
}

View file

@ -0,0 +1,300 @@
package mining
import (
"errors"
"os"
"path/filepath"
"testing"
)
type testData struct {
Name string `json:"name"`
Value int `json:"value"`
}
func TestFileRepository_Load(t *testing.T) {
t.Run("NonExistentFile", func(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "nonexistent.json")
repo := NewFileRepository[testData](path)
data, err := repo.Load()
if err != nil {
t.Fatalf("Load should not error for non-existent file: %v", err)
}
if data.Name != "" || data.Value != 0 {
t.Error("Expected zero value for non-existent file")
}
})
t.Run("NonExistentFileWithDefaults", func(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "nonexistent.json")
repo := NewFileRepository[testData](path, WithDefaults(func() testData {
return testData{Name: "default", Value: 42}
}))
data, err := repo.Load()
if err != nil {
t.Fatalf("Load should not error: %v", err)
}
if data.Name != "default" || data.Value != 42 {
t.Errorf("Expected default values, got %+v", data)
}
})
t.Run("ExistingFile", func(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "test.json")
// Write test data
if err := os.WriteFile(path, []byte(`{"name":"test","value":123}`), 0600); err != nil {
t.Fatalf("Failed to write test file: %v", err)
}
repo := NewFileRepository[testData](path)
data, err := repo.Load()
if err != nil {
t.Fatalf("Load failed: %v", err)
}
if data.Name != "test" || data.Value != 123 {
t.Errorf("Unexpected data: %+v", data)
}
})
t.Run("InvalidJSON", func(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "invalid.json")
if err := os.WriteFile(path, []byte(`{invalid json}`), 0600); err != nil {
t.Fatalf("Failed to write test file: %v", err)
}
repo := NewFileRepository[testData](path)
_, err := repo.Load()
if err == nil {
t.Error("Expected error for invalid JSON")
}
})
}
func TestFileRepository_Save(t *testing.T) {
t.Run("NewFile", func(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "subdir", "new.json")
repo := NewFileRepository[testData](path)
data := testData{Name: "saved", Value: 456}
if err := repo.Save(data); err != nil {
t.Fatalf("Save failed: %v", err)
}
// Verify file was created
if !repo.Exists() {
t.Error("File should exist after save")
}
// Verify content
loaded, err := repo.Load()
if err != nil {
t.Fatalf("Load after save failed: %v", err)
}
if loaded.Name != "saved" || loaded.Value != 456 {
t.Errorf("Unexpected loaded data: %+v", loaded)
}
})
t.Run("OverwriteExisting", func(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "existing.json")
repo := NewFileRepository[testData](path)
// Save initial data
if err := repo.Save(testData{Name: "first", Value: 1}); err != nil {
t.Fatalf("First save failed: %v", err)
}
// Overwrite
if err := repo.Save(testData{Name: "second", Value: 2}); err != nil {
t.Fatalf("Second save failed: %v", err)
}
// Verify overwrite
loaded, err := repo.Load()
if err != nil {
t.Fatalf("Load failed: %v", err)
}
if loaded.Name != "second" || loaded.Value != 2 {
t.Errorf("Expected overwritten data, got: %+v", loaded)
}
})
}
func TestFileRepository_Update(t *testing.T) {
t.Run("UpdateExisting", func(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "update.json")
repo := NewFileRepository[testData](path)
// Save initial data
if err := repo.Save(testData{Name: "initial", Value: 10}); err != nil {
t.Fatalf("Initial save failed: %v", err)
}
// Update
err := repo.Update(func(data *testData) error {
data.Value += 5
return nil
})
if err != nil {
t.Fatalf("Update failed: %v", err)
}
// Verify update
loaded, err := repo.Load()
if err != nil {
t.Fatalf("Load failed: %v", err)
}
if loaded.Value != 15 {
t.Errorf("Expected value 15, got %d", loaded.Value)
}
})
t.Run("UpdateNonExistentWithDefaults", func(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "new.json")
repo := NewFileRepository[testData](path, WithDefaults(func() testData {
return testData{Name: "default", Value: 100}
}))
err := repo.Update(func(data *testData) error {
data.Value *= 2
return nil
})
if err != nil {
t.Fatalf("Update failed: %v", err)
}
// Verify update started from defaults
loaded, err := repo.Load()
if err != nil {
t.Fatalf("Load failed: %v", err)
}
if loaded.Value != 200 {
t.Errorf("Expected value 200, got %d", loaded.Value)
}
})
t.Run("UpdateWithError", func(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "error.json")
repo := NewFileRepository[testData](path)
if err := repo.Save(testData{Name: "test", Value: 1}); err != nil {
t.Fatalf("Initial save failed: %v", err)
}
// Update that returns error
testErr := errors.New("update error")
err := repo.Update(func(data *testData) error {
data.Value = 999 // This change should not be saved
return testErr
})
if err != testErr {
t.Errorf("Expected test error, got: %v", err)
}
// Verify original data unchanged
loaded, err := repo.Load()
if err != nil {
t.Fatalf("Load failed: %v", err)
}
if loaded.Value != 1 {
t.Errorf("Expected value 1 (unchanged), got %d", loaded.Value)
}
})
}
func TestFileRepository_Delete(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "delete.json")
repo := NewFileRepository[testData](path)
// Save data
if err := repo.Save(testData{Name: "temp", Value: 1}); err != nil {
t.Fatalf("Save failed: %v", err)
}
if !repo.Exists() {
t.Error("File should exist after save")
}
// Delete
if err := repo.Delete(); err != nil {
t.Fatalf("Delete failed: %v", err)
}
if repo.Exists() {
t.Error("File should not exist after delete")
}
// Delete non-existent should not error
if err := repo.Delete(); err != nil {
t.Errorf("Delete non-existent should not error: %v", err)
}
}
func TestFileRepository_Path(t *testing.T) {
path := "/some/path/config.json"
repo := NewFileRepository[testData](path)
if repo.Path() != path {
t.Errorf("Expected path %s, got %s", path, repo.Path())
}
}
// Test with slice data
func TestFileRepository_SliceData(t *testing.T) {
type item struct {
ID string `json:"id"`
Name string `json:"name"`
}
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "items.json")
repo := NewFileRepository[[]item](path, WithDefaults(func() []item {
return []item{}
}))
// Save slice
items := []item{
{ID: "1", Name: "First"},
{ID: "2", Name: "Second"},
}
if err := repo.Save(items); err != nil {
t.Fatalf("Save failed: %v", err)
}
// Load and verify
loaded, err := repo.Load()
if err != nil {
t.Fatalf("Load failed: %v", err)
}
if len(loaded) != 2 {
t.Errorf("Expected 2 items, got %d", len(loaded))
}
// Update slice
err = repo.Update(func(data *[]item) error {
*data = append(*data, item{ID: "3", Name: "Third"})
return nil
})
if err != nil {
t.Fatalf("Update failed: %v", err)
}
loaded, _ = repo.Load()
if len(loaded) != 3 {
t.Errorf("Expected 3 items after update, got %d", len(loaded))
}
}

View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"os/exec"
@ -92,6 +93,7 @@ func (m *TTMiner) GetLatestVersion() (string, error) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse
return "", fmt.Errorf("failed to get latest release: unexpected status code %d", resp.StatusCode)
}

View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
)
@ -41,6 +42,7 @@ func (m *TTMiner) GetStats(ctx context.Context) (*PerformanceMetrics, error) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse
return nil, fmt.Errorf("failed to get stats: unexpected status code %d", resp.StatusCode)
}

View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"os/exec"
@ -93,6 +94,7 @@ func (m *XMRigMiner) GetLatestVersion() (string, error) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse
return "", fmt.Errorf("failed to get latest release: unexpected status code %d", resp.StatusCode)
}

View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
)
@ -45,6 +46,7 @@ func (m *XMRigMiner) GetStats(ctx context.Context) (*PerformanceMetrics, error)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse
return nil, fmt.Errorf("failed to get stats: unexpected status code %d", resp.StatusCode)
}

View file

@ -125,34 +125,27 @@ func (c *Controller) GetRemoteStats(peerID string) (*StatsPayload, error) {
return nil, err
}
if resp.Type == MsgError {
var errPayload ErrorPayload
if err := resp.ParsePayload(&errPayload); err != nil {
return nil, fmt.Errorf("remote error (unable to parse)")
}
return nil, fmt.Errorf("remote error: %s", errPayload.Message)
}
if resp.Type != MsgStats {
return nil, fmt.Errorf("unexpected response type: %s", resp.Type)
}
var stats StatsPayload
if err := resp.ParsePayload(&stats); err != nil {
return nil, fmt.Errorf("failed to parse stats: %w", err)
if err := ParseResponse(resp, MsgStats, &stats); err != nil {
return nil, err
}
return &stats, nil
}
// StartRemoteMiner requests a remote peer to start a miner with a given profile.
func (c *Controller) StartRemoteMiner(peerID, profileID string, configOverride json.RawMessage) error {
func (c *Controller) StartRemoteMiner(peerID, minerType, profileID string, configOverride json.RawMessage) error {
identity := c.node.GetIdentity()
if identity == nil {
return fmt.Errorf("node identity not initialized")
}
if minerType == "" {
return fmt.Errorf("miner type is required")
}
payload := StartMinerPayload{
MinerType: minerType,
ProfileID: profileID,
Config: configOverride,
}
@ -167,21 +160,9 @@ func (c *Controller) StartRemoteMiner(peerID, profileID string, configOverride j
return err
}
if resp.Type == MsgError {
var errPayload ErrorPayload
if err := resp.ParsePayload(&errPayload); err != nil {
return fmt.Errorf("remote error (unable to parse)")
}
return fmt.Errorf("remote error: %s", errPayload.Message)
}
if resp.Type != MsgMinerAck {
return fmt.Errorf("unexpected response type: %s", resp.Type)
}
var ack MinerAckPayload
if err := resp.ParsePayload(&ack); err != nil {
return fmt.Errorf("failed to parse ack: %w", err)
if err := ParseResponse(resp, MsgMinerAck, &ack); err != nil {
return err
}
if !ack.Success {
@ -212,21 +193,9 @@ func (c *Controller) StopRemoteMiner(peerID, minerName string) error {
return err
}
if resp.Type == MsgError {
var errPayload ErrorPayload
if err := resp.ParsePayload(&errPayload); err != nil {
return fmt.Errorf("remote error (unable to parse)")
}
return fmt.Errorf("remote error: %s", errPayload.Message)
}
if resp.Type != MsgMinerAck {
return fmt.Errorf("unexpected response type: %s", resp.Type)
}
var ack MinerAckPayload
if err := resp.ParsePayload(&ack); err != nil {
return fmt.Errorf("failed to parse ack: %w", err)
if err := ParseResponse(resp, MsgMinerAck, &ack); err != nil {
return err
}
if !ack.Success {
@ -258,21 +227,9 @@ func (c *Controller) GetRemoteLogs(peerID, minerName string, lines int) ([]strin
return nil, err
}
if resp.Type == MsgError {
var errPayload ErrorPayload
if err := resp.ParsePayload(&errPayload); err != nil {
return nil, fmt.Errorf("remote error (unable to parse)")
}
return nil, fmt.Errorf("remote error: %s", errPayload.Message)
}
if resp.Type != MsgLogs {
return nil, fmt.Errorf("unexpected response type: %s", resp.Type)
}
var logs LogsPayload
if err := resp.ParsePayload(&logs); err != nil {
return nil, fmt.Errorf("failed to parse logs: %w", err)
if err := ParseResponse(resp, MsgLogs, &logs); err != nil {
return nil, err
}
return logs.Lines, nil
@ -325,8 +282,8 @@ func (c *Controller) PingPeer(peerID string) (float64, error) {
return 0, err
}
if resp.Type != MsgPong {
return 0, fmt.Errorf("unexpected response type: %s", resp.Type)
if err := ValidateResponse(resp, MsgPong); err != nil {
return 0, err
}
// Calculate round-trip time

View file

@ -117,7 +117,8 @@ type PongPayload struct {
// StartMinerPayload requests starting a miner.
type StartMinerPayload struct {
ProfileID string `json:"profileId"`
MinerType string `json:"minerType"` // Required: miner type (e.g., "xmrig", "tt-miner")
ProfileID string `json:"profileId,omitempty"`
Config json.RawMessage `json:"config,omitempty"` // Override profile config
}

View file

@ -92,6 +92,7 @@ func TestMessageReply(t *testing.T) {
func TestParsePayload(t *testing.T) {
t.Run("ValidPayload", func(t *testing.T) {
payload := StartMinerPayload{
MinerType: "xmrig",
ProfileID: "test-profile",
}
@ -190,6 +191,7 @@ func TestNewErrorMessage(t *testing.T) {
func TestMessageSerialization(t *testing.T) {
original, _ := NewMessage(MsgStartMiner, "ctrl", "worker", StartMinerPayload{
MinerType: "xmrig",
ProfileID: "my-profile",
})

88
pkg/node/protocol.go Normal file
View file

@ -0,0 +1,88 @@
package node
import (
"fmt"
)
// ProtocolError represents an error from the remote peer.
type ProtocolError struct {
Code int
Message string
}
func (e *ProtocolError) Error() string {
return fmt.Sprintf("remote error (%d): %s", e.Code, e.Message)
}
// ResponseHandler provides helpers for handling protocol responses.
type ResponseHandler struct{}
// ValidateResponse checks if the response is valid and returns a parsed error if it's an error response.
// It checks:
// 1. If response is nil (returns error)
// 2. If response is an error message (returns ProtocolError)
// 3. If response type matches expected (returns error if not)
func (h *ResponseHandler) ValidateResponse(resp *Message, expectedType MessageType) error {
if resp == nil {
return fmt.Errorf("nil response")
}
// Check for error response
if resp.Type == MsgError {
var errPayload ErrorPayload
if err := resp.ParsePayload(&errPayload); err != nil {
return &ProtocolError{Code: ErrCodeUnknown, Message: "unable to parse error response"}
}
return &ProtocolError{Code: errPayload.Code, Message: errPayload.Message}
}
// Check expected type
if resp.Type != expectedType {
return fmt.Errorf("unexpected response type: expected %s, got %s", expectedType, resp.Type)
}
return nil
}
// ParseResponse validates the response and parses the payload into the target.
// This combines ValidateResponse and ParsePayload into a single call.
func (h *ResponseHandler) ParseResponse(resp *Message, expectedType MessageType, target interface{}) error {
if err := h.ValidateResponse(resp, expectedType); err != nil {
return err
}
if target != nil {
if err := resp.ParsePayload(target); err != nil {
return fmt.Errorf("failed to parse %s payload: %w", expectedType, err)
}
}
return nil
}
// DefaultResponseHandler is the default response handler instance.
var DefaultResponseHandler = &ResponseHandler{}
// ValidateResponse is a convenience function using the default handler.
func ValidateResponse(resp *Message, expectedType MessageType) error {
return DefaultResponseHandler.ValidateResponse(resp, expectedType)
}
// ParseResponse is a convenience function using the default handler.
func ParseResponse(resp *Message, expectedType MessageType, target interface{}) error {
return DefaultResponseHandler.ParseResponse(resp, expectedType, target)
}
// IsProtocolError returns true if the error is a ProtocolError.
func IsProtocolError(err error) bool {
_, ok := err.(*ProtocolError)
return ok
}
// GetProtocolErrorCode returns the error code if err is a ProtocolError, otherwise returns 0.
func GetProtocolErrorCode(err error) int {
if pe, ok := err.(*ProtocolError); ok {
return pe.Code
}
return 0
}

161
pkg/node/protocol_test.go Normal file
View file

@ -0,0 +1,161 @@
package node
import (
"fmt"
"testing"
)
func TestResponseHandler_ValidateResponse(t *testing.T) {
handler := &ResponseHandler{}
t.Run("NilResponse", func(t *testing.T) {
err := handler.ValidateResponse(nil, MsgStats)
if err == nil {
t.Error("Expected error for nil response")
}
})
t.Run("ErrorResponse", func(t *testing.T) {
errMsg, _ := NewErrorMessage("sender", "receiver", ErrCodeOperationFailed, "operation failed", "")
err := handler.ValidateResponse(errMsg, MsgStats)
if err == nil {
t.Fatal("Expected error for error response")
}
if !IsProtocolError(err) {
t.Errorf("Expected ProtocolError, got %T", err)
}
if GetProtocolErrorCode(err) != ErrCodeOperationFailed {
t.Errorf("Expected code %d, got %d", ErrCodeOperationFailed, GetProtocolErrorCode(err))
}
})
t.Run("WrongType", func(t *testing.T) {
msg, _ := NewMessage(MsgPong, "sender", "receiver", nil)
err := handler.ValidateResponse(msg, MsgStats)
if err == nil {
t.Error("Expected error for wrong type")
}
if IsProtocolError(err) {
t.Error("Should not be a ProtocolError for type mismatch")
}
})
t.Run("ValidResponse", func(t *testing.T) {
msg, _ := NewMessage(MsgStats, "sender", "receiver", StatsPayload{NodeID: "test"})
err := handler.ValidateResponse(msg, MsgStats)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
})
}
func TestResponseHandler_ParseResponse(t *testing.T) {
handler := &ResponseHandler{}
t.Run("ParseStats", func(t *testing.T) {
payload := StatsPayload{
NodeID: "node-123",
NodeName: "Test Node",
Uptime: 3600,
}
msg, _ := NewMessage(MsgStats, "sender", "receiver", payload)
var parsed StatsPayload
err := handler.ParseResponse(msg, MsgStats, &parsed)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if parsed.NodeID != "node-123" {
t.Errorf("Expected NodeID 'node-123', got '%s'", parsed.NodeID)
}
if parsed.Uptime != 3600 {
t.Errorf("Expected Uptime 3600, got %d", parsed.Uptime)
}
})
t.Run("ParseMinerAck", func(t *testing.T) {
payload := MinerAckPayload{
Success: true,
MinerName: "xmrig-1",
}
msg, _ := NewMessage(MsgMinerAck, "sender", "receiver", payload)
var parsed MinerAckPayload
err := handler.ParseResponse(msg, MsgMinerAck, &parsed)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if !parsed.Success {
t.Error("Expected Success to be true")
}
if parsed.MinerName != "xmrig-1" {
t.Errorf("Expected MinerName 'xmrig-1', got '%s'", parsed.MinerName)
}
})
t.Run("ErrorResponse", func(t *testing.T) {
errMsg, _ := NewErrorMessage("sender", "receiver", ErrCodeNotFound, "not found", "")
var parsed StatsPayload
err := handler.ParseResponse(errMsg, MsgStats, &parsed)
if err == nil {
t.Error("Expected error for error response")
}
if !IsProtocolError(err) {
t.Errorf("Expected ProtocolError, got %T", err)
}
})
t.Run("NilTarget", func(t *testing.T) {
msg, _ := NewMessage(MsgPong, "sender", "receiver", nil)
err := handler.ParseResponse(msg, MsgPong, nil)
if err != nil {
t.Errorf("Unexpected error with nil target: %v", err)
}
})
}
func TestProtocolError(t *testing.T) {
err := &ProtocolError{Code: 1001, Message: "test error"}
if err.Error() != "remote error (1001): test error" {
t.Errorf("Unexpected error message: %s", err.Error())
}
if !IsProtocolError(err) {
t.Error("IsProtocolError should return true")
}
if GetProtocolErrorCode(err) != 1001 {
t.Errorf("Expected code 1001, got %d", GetProtocolErrorCode(err))
}
}
func TestConvenienceFunctions(t *testing.T) {
msg, _ := NewMessage(MsgStats, "sender", "receiver", StatsPayload{NodeID: "test"})
// Test ValidateResponse
if err := ValidateResponse(msg, MsgStats); err != nil {
t.Errorf("ValidateResponse failed: %v", err)
}
// Test ParseResponse
var parsed StatsPayload
if err := ParseResponse(msg, MsgStats, &parsed); err != nil {
t.Errorf("ParseResponse failed: %v", err)
}
if parsed.NodeID != "test" {
t.Errorf("Expected NodeID 'test', got '%s'", parsed.NodeID)
}
}
func TestGetProtocolErrorCode_NonProtocolError(t *testing.T) {
err := fmt.Errorf("regular error")
if GetProtocolErrorCode(err) != 0 {
t.Error("Expected 0 for non-ProtocolError")
}
}

View file

@ -198,6 +198,11 @@ func (w *Worker) handleStartMiner(msg *Message) (*Message, error) {
return nil, fmt.Errorf("invalid start miner payload: %w", err)
}
// Validate miner type is provided
if payload.MinerType == "" {
return nil, fmt.Errorf("miner type is required")
}
// Get the config from the profile or use the override
var config interface{}
if payload.Config != nil {
@ -213,7 +218,7 @@ func (w *Worker) handleStartMiner(msg *Message) (*Message, error) {
}
// Start the miner
miner, err := w.minerManager.StartMiner("", config)
miner, err := w.minerManager.StartMiner(payload.MinerType, config)
if err != nil {
ack := MinerAckPayload{
Success: false,

View file

@ -247,7 +247,7 @@ func TestWorker_HandleStartMiner_NoManager(t *testing.T) {
if identity == nil {
t.Fatal("expected identity to be generated")
}
payload := StartMinerPayload{ProfileID: "test-profile"}
payload := StartMinerPayload{MinerType: "xmrig", ProfileID: "test-profile"}
msg, err := NewMessage(MsgStartMiner, "sender-id", identity.ID, payload)
if err != nil {
t.Fatalf("failed to create start_miner message: %v", err)