From 34f860309facdbf4472cb15fd3f7952d4faf336c Mon Sep 17 00:00:00 2001 From: snider Date: Wed, 31 Dec 2025 12:43:46 +0000 Subject: [PATCH] refactor: Add reliability fixes and architecture improvements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reliability fixes: - Fix HTTP response body drainage in xmrig, ttminer, miner - Fix database init race condition (nil before close) - Fix empty minerType bug in P2P StartMinerPayload - Add context timeout to InsertHashratePoint (5s default) Architecture improvements: - Extract HashrateStore interface with DefaultStore/NopStore - Create ServiceContainer for centralized initialization - Extract protocol response handler (ValidateResponse, ParseResponse) - Create generic FileRepository[T] with atomic writes 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- cmd/mining/cmd/remote.go | 12 +- pkg/database/database.go | 4 +- pkg/database/database_race_test.go | 14 +- pkg/database/database_test.go | 14 +- pkg/database/hashrate.go | 18 +- pkg/database/interface.go | 95 +++++++++ pkg/database/interface_test.go | 135 +++++++++++++ pkg/mining/container.go | 260 +++++++++++++++++++++++++ pkg/mining/container_test.go | 238 +++++++++++++++++++++++ pkg/mining/manager.go | 3 +- pkg/mining/miner.go | 1 + pkg/mining/node_service.go | 5 +- pkg/mining/repository.go | 196 +++++++++++++++++++ pkg/mining/repository_test.go | 300 +++++++++++++++++++++++++++++ pkg/mining/ttminer.go | 2 + pkg/mining/ttminer_stats.go | 2 + pkg/mining/xmrig.go | 2 + pkg/mining/xmrig_stats.go | 2 + pkg/node/controller.go | 75 ++------ pkg/node/message.go | 3 +- pkg/node/message_test.go | 2 + pkg/node/protocol.go | 88 +++++++++ pkg/node/protocol_test.go | 161 ++++++++++++++++ pkg/node/worker.go | 7 +- pkg/node/worker_test.go | 2 +- 25 files changed, 1553 insertions(+), 88 deletions(-) create mode 100644 pkg/database/interface.go create mode 100644 pkg/database/interface_test.go create mode 100644 pkg/mining/container.go create mode 100644 pkg/mining/container_test.go create mode 100644 pkg/mining/repository.go create mode 100644 pkg/mining/repository_test.go create mode 100644 pkg/node/protocol.go create mode 100644 pkg/node/protocol_test.go diff --git a/cmd/mining/cmd/remote.go b/cmd/mining/cmd/remote.go index fa664f6..b970300 100644 --- a/cmd/mining/cmd/remote.go +++ b/cmd/mining/cmd/remote.go @@ -82,10 +82,11 @@ var remoteStartCmd = &cobra.Command{ Long: `Start a miner on a remote peer using a profile.`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - profileID, _ := cmd.Flags().GetString("profile") - if profileID == "" { - return fmt.Errorf("--profile is required") + minerType, _ := cmd.Flags().GetString("type") + if minerType == "" { + return fmt.Errorf("--type is required (e.g., xmrig, tt-miner)") } + profileID, _ := cmd.Flags().GetString("profile") peerID := args[0] peer := findPeerByPartialID(peerID) @@ -98,8 +99,8 @@ var remoteStartCmd = &cobra.Command{ return err } - fmt.Printf("Starting miner on %s with profile %s...\n", peer.Name, profileID) - if err := ctrl.StartRemoteMiner(peer.ID, profileID, nil); err != nil { + fmt.Printf("Starting %s miner on %s with profile %s...\n", minerType, peer.Name, profileID) + if err := ctrl.StartRemoteMiner(peer.ID, minerType, profileID, nil); err != nil { return fmt.Errorf("failed to start miner: %w", err) } @@ -298,6 +299,7 @@ func init() { // remote start remoteCmd.AddCommand(remoteStartCmd) remoteStartCmd.Flags().StringP("profile", "p", "", "Profile ID to use for starting the miner") + remoteStartCmd.Flags().StringP("type", "t", "", "Miner type (e.g., xmrig, tt-miner)") // remote stop remoteCmd.AddCommand(remoteStopCmd) diff --git a/pkg/database/database.go b/pkg/database/database.go index 34f2b40..8251e7d 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -77,8 +77,10 @@ func Initialize(cfg Config) error { // Create tables if err := createTables(); err != nil { - db.Close() + // Nil out global before closing to prevent use of closed connection + closingDB := db db = nil + closingDB.Close() return fmt.Errorf("failed to create tables: %w", err) } diff --git a/pkg/database/database_race_test.go b/pkg/database/database_race_test.go index 9815d71..02ae0ce 100644 --- a/pkg/database/database_race_test.go +++ b/pkg/database/database_race_test.go @@ -50,7 +50,7 @@ func TestConcurrentHashrateInserts(t *testing.T) { Timestamp: time.Now().Add(time.Duration(-j) * time.Second), Hashrate: 1000 + minerIndex*100 + j, } - err := InsertHashratePoint(minerName, minerType, point, ResolutionHigh) + err := InsertHashratePoint(nil, minerName, minerType, point, ResolutionHigh) if err != nil { t.Errorf("Insert error for %s: %v", minerName, err) } @@ -95,7 +95,7 @@ func TestConcurrentInsertAndQuery(t *testing.T) { Timestamp: time.Now(), Hashrate: 1000 + i, } - InsertHashratePoint("concurrent-test", "xmrig", point, ResolutionHigh) + InsertHashratePoint(nil, "concurrent-test", "xmrig", point, ResolutionHigh) time.Sleep(time.Millisecond) } } @@ -149,13 +149,13 @@ func TestConcurrentInsertAndCleanup(t *testing.T) { Timestamp: time.Now().AddDate(0, 0, -10), // 10 days old Hashrate: 500 + i, } - InsertHashratePoint("cleanup-test", "xmrig", oldPoint, ResolutionHigh) + InsertHashratePoint(nil, "cleanup-test", "xmrig", oldPoint, ResolutionHigh) newPoint := HashratePoint{ Timestamp: time.Now(), Hashrate: 1000 + i, } - InsertHashratePoint("cleanup-test", "xmrig", newPoint, ResolutionHigh) + InsertHashratePoint(nil, "cleanup-test", "xmrig", newPoint, ResolutionHigh) time.Sleep(time.Millisecond) } } @@ -197,7 +197,7 @@ func TestConcurrentStats(t *testing.T) { Timestamp: time.Now().Add(time.Duration(-i) * time.Second), Hashrate: 1000 + i*10, } - InsertHashratePoint(minerName, "xmrig", point, ResolutionHigh) + InsertHashratePoint(nil, minerName, "xmrig", point, ResolutionHigh) } var wg sync.WaitGroup @@ -238,7 +238,7 @@ func TestConcurrentGetAllStats(t *testing.T) { Timestamp: time.Now().Add(time.Duration(-i) * time.Second), Hashrate: 1000 + m*100 + i, } - InsertHashratePoint(minerName, "xmrig", point, ResolutionHigh) + InsertHashratePoint(nil, minerName, "xmrig", point, ResolutionHigh) } } @@ -267,7 +267,7 @@ func TestConcurrentGetAllStats(t *testing.T) { Timestamp: time.Now(), Hashrate: 2000 + i, } - InsertHashratePoint("all-stats-new", "xmrig", point, ResolutionHigh) + InsertHashratePoint(nil, "all-stats-new", "xmrig", point, ResolutionHigh) } }() diff --git a/pkg/database/database_test.go b/pkg/database/database_test.go index 3ad4a87..2bb3f7c 100644 --- a/pkg/database/database_test.go +++ b/pkg/database/database_test.go @@ -77,7 +77,7 @@ func TestHashrateStorage(t *testing.T) { } for _, p := range points { - if err := InsertHashratePoint(minerName, minerType, p, ResolutionHigh); err != nil { + if err := InsertHashratePoint(nil, minerName, minerType, p, ResolutionHigh); err != nil { t.Fatalf("Failed to store hashrate point: %v", err) } } @@ -109,7 +109,7 @@ func TestGetHashrateStats(t *testing.T) { } for _, p := range points { - if err := InsertHashratePoint(minerName, minerType, p, ResolutionHigh); err != nil { + if err := InsertHashratePoint(nil, minerName, minerType, p, ResolutionHigh); err != nil { t.Fatalf("Failed to store point: %v", err) } } @@ -175,13 +175,13 @@ func TestCleanupRetention(t *testing.T) { } // Insert all points - if err := InsertHashratePoint(minerName, minerType, oldPoint, ResolutionHigh); err != nil { + if err := InsertHashratePoint(nil, minerName, minerType, oldPoint, ResolutionHigh); err != nil { t.Fatalf("Failed to insert old point: %v", err) } - if err := InsertHashratePoint(minerName, minerType, midPoint, ResolutionHigh); err != nil { + if err := InsertHashratePoint(nil, minerName, minerType, midPoint, ResolutionHigh); err != nil { t.Fatalf("Failed to insert mid point: %v", err) } - if err := InsertHashratePoint(minerName, minerType, newPoint, ResolutionHigh); err != nil { + if err := InsertHashratePoint(nil, minerName, minerType, newPoint, ResolutionHigh); err != nil { t.Fatalf("Failed to insert new point: %v", err) } @@ -238,7 +238,7 @@ func TestGetHashrateHistoryTimeRange(t *testing.T) { Timestamp: now.Add(offset), Hashrate: 1000 + i*100, } - if err := InsertHashratePoint(minerName, minerType, point, ResolutionHigh); err != nil { + if err := InsertHashratePoint(nil, minerName, minerType, point, ResolutionHigh); err != nil { t.Fatalf("Failed to insert point: %v", err) } } @@ -291,7 +291,7 @@ func TestMultipleMinerStats(t *testing.T) { Timestamp: now.Add(time.Duration(-i) * time.Minute), Hashrate: hr, } - if err := InsertHashratePoint(m.name, "xmrig", point, ResolutionHigh); err != nil { + if err := InsertHashratePoint(nil, m.name, "xmrig", point, ResolutionHigh); err != nil { t.Fatalf("Failed to insert point for %s: %v", m.name, err) } } diff --git a/pkg/database/hashrate.go b/pkg/database/hashrate.go index ed2a497..d1c8450 100644 --- a/pkg/database/hashrate.go +++ b/pkg/database/hashrate.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "time" @@ -47,8 +48,12 @@ type HashratePoint struct { Hashrate int `json:"hashrate"` } -// InsertHashratePoint stores a hashrate measurement in the database -func InsertHashratePoint(minerName, minerType string, point HashratePoint, resolution Resolution) error { +// dbInsertTimeout is the maximum time to wait for a database insert operation +const dbInsertTimeout = 5 * time.Second + +// InsertHashratePoint stores a hashrate measurement in the database. +// If ctx is nil, a default timeout context will be used. +func InsertHashratePoint(ctx context.Context, minerName, minerType string, point HashratePoint, resolution Resolution) error { dbMu.RLock() defer dbMu.RUnlock() @@ -56,7 +61,14 @@ func InsertHashratePoint(minerName, minerType string, point HashratePoint, resol return nil // DB not enabled, silently skip } - _, err := db.Exec(` + // Use provided context or create one with default timeout + if ctx == nil { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), dbInsertTimeout) + defer cancel() + } + + _, err := db.ExecContext(ctx, ` INSERT INTO hashrate_history (miner_name, miner_type, timestamp, hashrate, resolution) VALUES (?, ?, ?, ?, ?) `, minerName, minerType, point.Timestamp, point.Hashrate, string(resolution)) diff --git a/pkg/database/interface.go b/pkg/database/interface.go new file mode 100644 index 0000000..0312687 --- /dev/null +++ b/pkg/database/interface.go @@ -0,0 +1,95 @@ +package database + +import ( + "context" + "time" +) + +// HashrateStore defines the interface for hashrate data persistence. +// This interface allows for dependency injection and easier testing. +type HashrateStore interface { + // InsertHashratePoint stores a hashrate measurement. + // If ctx is nil, a default timeout will be used. + InsertHashratePoint(ctx context.Context, minerName, minerType string, point HashratePoint, resolution Resolution) error + + // GetHashrateHistory retrieves hashrate history for a miner within a time range. + GetHashrateHistory(minerName string, resolution Resolution, since, until time.Time) ([]HashratePoint, error) + + // GetHashrateStats retrieves aggregated statistics for a specific miner. + GetHashrateStats(minerName string) (*HashrateStats, error) + + // GetAllMinerStats retrieves statistics for all miners. + GetAllMinerStats() ([]HashrateStats, error) + + // Cleanup removes old data based on retention settings. + Cleanup(retentionDays int) error + + // Close closes the store and releases resources. + Close() error +} + +// defaultStore implements HashrateStore using the global database connection. +// This provides backward compatibility while allowing interface-based usage. +type defaultStore struct{} + +// DefaultStore returns a HashrateStore that uses the global database connection. +// This is useful for gradual migration from package-level functions to interface-based usage. +func DefaultStore() HashrateStore { + return &defaultStore{} +} + +func (s *defaultStore) InsertHashratePoint(ctx context.Context, minerName, minerType string, point HashratePoint, resolution Resolution) error { + return InsertHashratePoint(ctx, minerName, minerType, point, resolution) +} + +func (s *defaultStore) GetHashrateHistory(minerName string, resolution Resolution, since, until time.Time) ([]HashratePoint, error) { + return GetHashrateHistory(minerName, resolution, since, until) +} + +func (s *defaultStore) GetHashrateStats(minerName string) (*HashrateStats, error) { + return GetHashrateStats(minerName) +} + +func (s *defaultStore) GetAllMinerStats() ([]HashrateStats, error) { + return GetAllMinerStats() +} + +func (s *defaultStore) Cleanup(retentionDays int) error { + return Cleanup(retentionDays) +} + +func (s *defaultStore) Close() error { + return Close() +} + +// NopStore returns a HashrateStore that does nothing. +// Useful for testing or when database is disabled. +func NopStore() HashrateStore { + return &nopStore{} +} + +type nopStore struct{} + +func (s *nopStore) InsertHashratePoint(ctx context.Context, minerName, minerType string, point HashratePoint, resolution Resolution) error { + return nil +} + +func (s *nopStore) GetHashrateHistory(minerName string, resolution Resolution, since, until time.Time) ([]HashratePoint, error) { + return nil, nil +} + +func (s *nopStore) GetHashrateStats(minerName string) (*HashrateStats, error) { + return nil, nil +} + +func (s *nopStore) GetAllMinerStats() ([]HashrateStats, error) { + return nil, nil +} + +func (s *nopStore) Cleanup(retentionDays int) error { + return nil +} + +func (s *nopStore) Close() error { + return nil +} diff --git a/pkg/database/interface_test.go b/pkg/database/interface_test.go new file mode 100644 index 0000000..272fa43 --- /dev/null +++ b/pkg/database/interface_test.go @@ -0,0 +1,135 @@ +package database + +import ( + "context" + "testing" + "time" +) + +func TestDefaultStore(t *testing.T) { + cleanup := setupTestDB(t) + defer cleanup() + + store := DefaultStore() + + // Test InsertHashratePoint + point := HashratePoint{ + Timestamp: time.Now(), + Hashrate: 1500, + } + if err := store.InsertHashratePoint(nil, "interface-test", "xmrig", point, ResolutionHigh); err != nil { + t.Fatalf("InsertHashratePoint failed: %v", err) + } + + // Test GetHashrateHistory + history, err := store.GetHashrateHistory("interface-test", ResolutionHigh, time.Now().Add(-time.Hour), time.Now().Add(time.Hour)) + if err != nil { + t.Fatalf("GetHashrateHistory failed: %v", err) + } + if len(history) != 1 { + t.Errorf("Expected 1 point, got %d", len(history)) + } + + // Test GetHashrateStats + stats, err := store.GetHashrateStats("interface-test") + if err != nil { + t.Fatalf("GetHashrateStats failed: %v", err) + } + if stats == nil { + t.Fatal("Expected non-nil stats") + } + if stats.TotalPoints != 1 { + t.Errorf("Expected 1 total point, got %d", stats.TotalPoints) + } + + // Test GetAllMinerStats + allStats, err := store.GetAllMinerStats() + if err != nil { + t.Fatalf("GetAllMinerStats failed: %v", err) + } + if len(allStats) != 1 { + t.Errorf("Expected 1 miner in stats, got %d", len(allStats)) + } + + // Test Cleanup + if err := store.Cleanup(30); err != nil { + t.Fatalf("Cleanup failed: %v", err) + } +} + +func TestDefaultStore_WithContext(t *testing.T) { + cleanup := setupTestDB(t) + defer cleanup() + + store := DefaultStore() + ctx := context.Background() + + point := HashratePoint{ + Timestamp: time.Now(), + Hashrate: 2000, + } + if err := store.InsertHashratePoint(ctx, "ctx-test", "xmrig", point, ResolutionHigh); err != nil { + t.Fatalf("InsertHashratePoint with context failed: %v", err) + } + + history, err := store.GetHashrateHistory("ctx-test", ResolutionHigh, time.Now().Add(-time.Hour), time.Now().Add(time.Hour)) + if err != nil { + t.Fatalf("GetHashrateHistory failed: %v", err) + } + if len(history) != 1 { + t.Errorf("Expected 1 point, got %d", len(history)) + } +} + +func TestNopStore(t *testing.T) { + store := NopStore() + + // All operations should succeed without error + point := HashratePoint{ + Timestamp: time.Now(), + Hashrate: 1000, + } + if err := store.InsertHashratePoint(nil, "test", "xmrig", point, ResolutionHigh); err != nil { + t.Errorf("NopStore InsertHashratePoint should not error: %v", err) + } + + history, err := store.GetHashrateHistory("test", ResolutionHigh, time.Now().Add(-time.Hour), time.Now()) + if err != nil { + t.Errorf("NopStore GetHashrateHistory should not error: %v", err) + } + if history != nil { + t.Errorf("NopStore GetHashrateHistory should return nil, got %v", history) + } + + stats, err := store.GetHashrateStats("test") + if err != nil { + t.Errorf("NopStore GetHashrateStats should not error: %v", err) + } + if stats != nil { + t.Errorf("NopStore GetHashrateStats should return nil, got %v", stats) + } + + allStats, err := store.GetAllMinerStats() + if err != nil { + t.Errorf("NopStore GetAllMinerStats should not error: %v", err) + } + if allStats != nil { + t.Errorf("NopStore GetAllMinerStats should return nil, got %v", allStats) + } + + if err := store.Cleanup(30); err != nil { + t.Errorf("NopStore Cleanup should not error: %v", err) + } + + if err := store.Close(); err != nil { + t.Errorf("NopStore Close should not error: %v", err) + } +} + +// TestInterfaceCompatibility ensures all implementations satisfy HashrateStore +func TestInterfaceCompatibility(t *testing.T) { + var _ HashrateStore = DefaultStore() + var _ HashrateStore = NopStore() + var _ HashrateStore = &defaultStore{} + var _ HashrateStore = &nopStore{} +} diff --git a/pkg/mining/container.go b/pkg/mining/container.go new file mode 100644 index 0000000..14c150b --- /dev/null +++ b/pkg/mining/container.go @@ -0,0 +1,260 @@ +package mining + +import ( + "context" + "fmt" + "sync" + + "github.com/Snider/Mining/pkg/database" + "github.com/Snider/Mining/pkg/logging" +) + +// ContainerConfig holds configuration for the service container. +type ContainerConfig struct { + // Database configuration + Database database.Config + + // ListenAddr is the address to listen on (e.g., ":9090") + ListenAddr string + + // DisplayAddr is the address shown in Swagger docs + DisplayAddr string + + // SwaggerNamespace is the API path prefix + SwaggerNamespace string + + // SimulationMode enables simulation mode for testing + SimulationMode bool +} + +// DefaultContainerConfig returns sensible defaults for the container. +func DefaultContainerConfig() ContainerConfig { + return ContainerConfig{ + Database: database.Config{ + Enabled: true, + RetentionDays: 30, + }, + ListenAddr: ":9090", + DisplayAddr: "localhost:9090", + SwaggerNamespace: "/api/v1/mining", + SimulationMode: false, + } +} + +// Container manages the lifecycle of all services. +// It provides centralized initialization, dependency injection, and graceful shutdown. +type Container struct { + config ContainerConfig + mu sync.RWMutex + + // Core services + manager ManagerInterface + profileManager *ProfileManager + nodeService *NodeService + eventHub *EventHub + service *Service + + // Database store (interface for testing) + hashrateStore database.HashrateStore + + // Initialization state + initialized bool + transportStarted bool + shutdownCh chan struct{} +} + +// NewContainer creates a new service container with the given configuration. +func NewContainer(config ContainerConfig) *Container { + return &Container{ + config: config, + shutdownCh: make(chan struct{}), + } +} + +// Initialize sets up all services in the correct order. +// This should be called before Start(). +func (c *Container) Initialize(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.initialized { + return fmt.Errorf("container already initialized") + } + + // 1. Initialize database (optional) + if c.config.Database.Enabled { + if err := database.Initialize(c.config.Database); err != nil { + return fmt.Errorf("failed to initialize database: %w", err) + } + c.hashrateStore = database.DefaultStore() + logging.Info("database initialized", logging.Fields{"retention_days": c.config.Database.RetentionDays}) + } else { + c.hashrateStore = database.NopStore() + logging.Info("database disabled, using no-op store", nil) + } + + // 2. Initialize profile manager + var err error + c.profileManager, err = NewProfileManager() + if err != nil { + return fmt.Errorf("failed to initialize profile manager: %w", err) + } + + // 3. Initialize miner manager + if c.config.SimulationMode { + c.manager = NewManagerForSimulation() + } else { + c.manager = NewManager() + } + + // 4. Initialize node service (optional - P2P features) + c.nodeService, err = NewNodeService() + if err != nil { + logging.Warn("node service unavailable", logging.Fields{"error": err}) + // Continue without node service - P2P features will be unavailable + } + + // 5. Initialize event hub for WebSocket + c.eventHub = NewEventHub() + + // Wire up event hub to manager + if mgr, ok := c.manager.(*Manager); ok { + mgr.SetEventHub(c.eventHub) + } + + c.initialized = true + logging.Info("service container initialized", nil) + return nil +} + +// Start begins all background services. +func (c *Container) Start(ctx context.Context) error { + c.mu.RLock() + defer c.mu.RUnlock() + + if !c.initialized { + return fmt.Errorf("container not initialized") + } + + // Start event hub + go c.eventHub.Run() + + // Start node transport if available + if c.nodeService != nil { + if err := c.nodeService.StartTransport(); err != nil { + logging.Warn("failed to start node transport", logging.Fields{"error": err}) + } else { + c.transportStarted = true + } + } + + logging.Info("service container started", nil) + return nil +} + +// Shutdown gracefully stops all services in reverse order. +func (c *Container) Shutdown(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.initialized { + return nil + } + + logging.Info("shutting down service container", nil) + + var errs []error + + // 1. Stop service (HTTP server) + if c.service != nil { + // Service shutdown is handled externally + } + + // 2. Stop node transport (only if it was started) + if c.nodeService != nil && c.transportStarted { + if err := c.nodeService.StopTransport(); err != nil { + errs = append(errs, fmt.Errorf("node transport: %w", err)) + } + c.transportStarted = false + } + + // 3. Stop event hub + if c.eventHub != nil { + c.eventHub.Stop() + } + + // 4. Stop miner manager + if mgr, ok := c.manager.(*Manager); ok { + mgr.Stop() + } + + // 5. Close database + if err := database.Close(); err != nil { + errs = append(errs, fmt.Errorf("database: %w", err)) + } + + c.initialized = false + close(c.shutdownCh) + + if len(errs) > 0 { + return fmt.Errorf("shutdown errors: %v", errs) + } + + logging.Info("service container shutdown complete", nil) + return nil +} + +// Manager returns the miner manager. +func (c *Container) Manager() ManagerInterface { + c.mu.RLock() + defer c.mu.RUnlock() + return c.manager +} + +// ProfileManager returns the profile manager. +func (c *Container) ProfileManager() *ProfileManager { + c.mu.RLock() + defer c.mu.RUnlock() + return c.profileManager +} + + +// NodeService returns the node service (may be nil if P2P is unavailable). +func (c *Container) NodeService() *NodeService { + c.mu.RLock() + defer c.mu.RUnlock() + return c.nodeService +} + +// EventHub returns the event hub for WebSocket connections. +func (c *Container) EventHub() *EventHub { + c.mu.RLock() + defer c.mu.RUnlock() + return c.eventHub +} + +// HashrateStore returns the hashrate store interface. +func (c *Container) HashrateStore() database.HashrateStore { + c.mu.RLock() + defer c.mu.RUnlock() + return c.hashrateStore +} + +// SetHashrateStore allows injecting a custom hashrate store (useful for testing). +func (c *Container) SetHashrateStore(store database.HashrateStore) { + c.mu.Lock() + defer c.mu.Unlock() + c.hashrateStore = store +} + +// ShutdownCh returns a channel that's closed when shutdown is complete. +func (c *Container) ShutdownCh() <-chan struct{} { + return c.shutdownCh +} + +// IsInitialized returns true if the container has been initialized. +func (c *Container) IsInitialized() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.initialized +} diff --git a/pkg/mining/container_test.go b/pkg/mining/container_test.go new file mode 100644 index 0000000..0815b64 --- /dev/null +++ b/pkg/mining/container_test.go @@ -0,0 +1,238 @@ +package mining + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/Snider/Mining/pkg/database" +) + +func setupContainerTestEnv(t *testing.T) func() { + tmpDir := t.TempDir() + os.Setenv("XDG_CONFIG_HOME", filepath.Join(tmpDir, "config")) + os.Setenv("XDG_DATA_HOME", filepath.Join(tmpDir, "data")) + return func() { + os.Unsetenv("XDG_CONFIG_HOME") + os.Unsetenv("XDG_DATA_HOME") + } +} + +func TestNewContainer(t *testing.T) { + config := DefaultContainerConfig() + container := NewContainer(config) + + if container == nil { + t.Fatal("NewContainer returned nil") + } + + if container.IsInitialized() { + t.Error("Container should not be initialized before Initialize() is called") + } +} + +func TestDefaultContainerConfig(t *testing.T) { + config := DefaultContainerConfig() + + if !config.Database.Enabled { + t.Error("Database should be enabled by default") + } + + if config.Database.RetentionDays != 30 { + t.Errorf("Expected 30 retention days, got %d", config.Database.RetentionDays) + } + + if config.ListenAddr != ":9090" { + t.Errorf("Expected :9090, got %s", config.ListenAddr) + } + + if config.SimulationMode { + t.Error("SimulationMode should be false by default") + } +} + +func TestContainer_Initialize(t *testing.T) { + cleanup := setupContainerTestEnv(t) + defer cleanup() + + config := DefaultContainerConfig() + config.Database.Enabled = true + config.Database.Path = filepath.Join(t.TempDir(), "test.db") + config.SimulationMode = true // Use simulation mode for faster tests + + container := NewContainer(config) + ctx := context.Background() + + if err := container.Initialize(ctx); err != nil { + t.Fatalf("Initialize failed: %v", err) + } + + if !container.IsInitialized() { + t.Error("Container should be initialized after Initialize()") + } + + // Verify services are available + if container.Manager() == nil { + t.Error("Manager should not be nil after initialization") + } + + if container.ProfileManager() == nil { + t.Error("ProfileManager should not be nil after initialization") + } + + if container.EventHub() == nil { + t.Error("EventHub should not be nil after initialization") + } + + if container.HashrateStore() == nil { + t.Error("HashrateStore should not be nil after initialization") + } + + // Cleanup + if err := container.Shutdown(ctx); err != nil { + t.Errorf("Shutdown failed: %v", err) + } +} + +func TestContainer_InitializeTwice(t *testing.T) { + cleanup := setupContainerTestEnv(t) + defer cleanup() + + config := DefaultContainerConfig() + config.Database.Enabled = false + config.SimulationMode = true + + container := NewContainer(config) + ctx := context.Background() + + if err := container.Initialize(ctx); err != nil { + t.Fatalf("First Initialize failed: %v", err) + } + + // Second initialization should fail + if err := container.Initialize(ctx); err == nil { + t.Error("Second Initialize should fail") + } + + container.Shutdown(ctx) +} + +func TestContainer_DatabaseDisabled(t *testing.T) { + cleanup := setupContainerTestEnv(t) + defer cleanup() + + config := DefaultContainerConfig() + config.Database.Enabled = false + config.SimulationMode = true + + container := NewContainer(config) + ctx := context.Background() + + if err := container.Initialize(ctx); err != nil { + t.Fatalf("Initialize failed: %v", err) + } + + // Should use NopStore when database is disabled + store := container.HashrateStore() + if store == nil { + t.Fatal("HashrateStore should not be nil") + } + + // NopStore should accept inserts without error + point := database.HashratePoint{ + Timestamp: time.Now(), + Hashrate: 1000, + } + if err := store.InsertHashratePoint(nil, "test", "xmrig", point, database.ResolutionHigh); err != nil { + t.Errorf("NopStore insert should not fail: %v", err) + } + + container.Shutdown(ctx) +} + +func TestContainer_SetHashrateStore(t *testing.T) { + cleanup := setupContainerTestEnv(t) + defer cleanup() + + config := DefaultContainerConfig() + config.Database.Enabled = false + config.SimulationMode = true + + container := NewContainer(config) + ctx := context.Background() + + if err := container.Initialize(ctx); err != nil { + t.Fatalf("Initialize failed: %v", err) + } + + // Inject custom store + customStore := database.NopStore() + container.SetHashrateStore(customStore) + + if container.HashrateStore() != customStore { + t.Error("SetHashrateStore should update the store") + } + + container.Shutdown(ctx) +} + +func TestContainer_StartWithoutInitialize(t *testing.T) { + config := DefaultContainerConfig() + container := NewContainer(config) + ctx := context.Background() + + if err := container.Start(ctx); err == nil { + t.Error("Start should fail if Initialize was not called") + } +} + +func TestContainer_ShutdownWithoutInitialize(t *testing.T) { + config := DefaultContainerConfig() + container := NewContainer(config) + ctx := context.Background() + + // Shutdown on uninitialized container should not error + if err := container.Shutdown(ctx); err != nil { + t.Errorf("Shutdown on uninitialized container should not error: %v", err) + } +} + +func TestContainer_ShutdownChannel(t *testing.T) { + cleanup := setupContainerTestEnv(t) + defer cleanup() + + config := DefaultContainerConfig() + config.Database.Enabled = false + config.SimulationMode = true + + container := NewContainer(config) + ctx := context.Background() + + if err := container.Initialize(ctx); err != nil { + t.Fatalf("Initialize failed: %v", err) + } + + shutdownCh := container.ShutdownCh() + + // Channel should be open before shutdown + select { + case <-shutdownCh: + t.Error("ShutdownCh should not be closed before Shutdown()") + default: + // Expected + } + + if err := container.Shutdown(ctx); err != nil { + t.Errorf("Shutdown failed: %v", err) + } + + // Channel should be closed after shutdown + select { + case <-shutdownCh: + // Expected + case <-time.After(time.Second): + t.Error("ShutdownCh should be closed after Shutdown()") + } +} diff --git a/pkg/mining/manager.go b/pkg/mining/manager.go index 3af4baa..ebb747e 100644 --- a/pkg/mining/manager.go +++ b/pkg/mining/manager.go @@ -614,7 +614,8 @@ func (m *Manager) collectSingleMinerStats(miner Miner, minerType string, now tim Timestamp: point.Timestamp, Hashrate: point.Hashrate, } - if err := database.InsertHashratePoint(minerName, minerType, dbPoint, database.ResolutionHigh); err != nil { + // Use nil context to let InsertHashratePoint use its default timeout + if err := database.InsertHashratePoint(nil, minerName, minerType, dbPoint, database.ResolutionHigh); err != nil { logging.Warn("failed to persist hashrate", logging.Fields{"miner": minerName, "error": err}) } } diff --git a/pkg/mining/miner.go b/pkg/mining/miner.go index 44c1d05..e3c4144 100644 --- a/pkg/mining/miner.go +++ b/pkg/mining/miner.go @@ -246,6 +246,7 @@ func (b *BaseMiner) InstallFromURL(url string) error { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { + io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse return fmt.Errorf("failed to download release: unexpected status code %d", resp.StatusCode) } diff --git a/pkg/mining/node_service.go b/pkg/mining/node_service.go index d2b06e6..a85e718 100644 --- a/pkg/mining/node_service.go +++ b/pkg/mining/node_service.go @@ -334,7 +334,8 @@ func (ns *NodeService) handlePeerStats(c *gin.Context) { // RemoteStartRequest is the request body for starting a remote miner. type RemoteStartRequest struct { - ProfileID string `json:"profileId" binding:"required"` + MinerType string `json:"minerType" binding:"required"` + ProfileID string `json:"profileId,omitempty"` Config json.RawMessage `json:"config,omitempty"` } @@ -356,7 +357,7 @@ func (ns *NodeService) handleRemoteStart(c *gin.Context) { return } - if err := ns.controller.StartRemoteMiner(peerID, req.ProfileID, req.Config); err != nil { + if err := ns.controller.StartRemoteMiner(peerID, req.MinerType, req.ProfileID, req.Config); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } diff --git a/pkg/mining/repository.go b/pkg/mining/repository.go new file mode 100644 index 0000000..7f405c1 --- /dev/null +++ b/pkg/mining/repository.go @@ -0,0 +1,196 @@ +package mining + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" +) + +// Repository defines a generic interface for data persistence. +// Implementations can store data in files, databases, etc. +type Repository[T any] interface { + // Load reads data from the repository + Load() (T, error) + + // Save writes data to the repository + Save(data T) error + + // Update atomically loads, modifies, and saves data + Update(fn func(*T) error) error +} + +// FileRepository provides atomic file-based persistence for JSON data. +// It uses atomic writes (temp file + rename) to prevent corruption. +type FileRepository[T any] struct { + mu sync.RWMutex + path string + defaults func() T +} + +// FileRepositoryOption configures a FileRepository. +type FileRepositoryOption[T any] func(*FileRepository[T]) + +// WithDefaults sets the default value factory for when the file doesn't exist. +func WithDefaults[T any](fn func() T) FileRepositoryOption[T] { + return func(r *FileRepository[T]) { + r.defaults = fn + } +} + +// NewFileRepository creates a new file-based repository. +func NewFileRepository[T any](path string, opts ...FileRepositoryOption[T]) *FileRepository[T] { + r := &FileRepository[T]{ + path: path, + } + for _, opt := range opts { + opt(r) + } + return r +} + +// Load reads and deserializes data from the file. +// Returns defaults if file doesn't exist. +func (r *FileRepository[T]) Load() (T, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + var result T + + data, err := os.ReadFile(r.path) + if err != nil { + if os.IsNotExist(err) { + if r.defaults != nil { + return r.defaults(), nil + } + return result, nil + } + return result, fmt.Errorf("failed to read file: %w", err) + } + + if err := json.Unmarshal(data, &result); err != nil { + return result, fmt.Errorf("failed to unmarshal data: %w", err) + } + + return result, nil +} + +// Save serializes and writes data to the file atomically. +func (r *FileRepository[T]) Save(data T) error { + r.mu.Lock() + defer r.mu.Unlock() + + return r.saveUnlocked(data) +} + +// saveUnlocked saves data without acquiring the lock (caller must hold lock). +func (r *FileRepository[T]) saveUnlocked(data T) error { + dir := filepath.Dir(r.path) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + jsonData, err := json.MarshalIndent(data, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal data: %w", err) + } + + // Atomic write: write to temp file, sync, then rename + tmpFile, err := os.CreateTemp(dir, "repo-*.tmp") + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + tmpPath := tmpFile.Name() + + // Clean up temp file on error + success := false + defer func() { + if !success { + os.Remove(tmpPath) + } + }() + + if _, err := tmpFile.Write(jsonData); err != nil { + tmpFile.Close() + return fmt.Errorf("failed to write temp file: %w", err) + } + + if err := tmpFile.Sync(); err != nil { + tmpFile.Close() + return fmt.Errorf("failed to sync temp file: %w", err) + } + + if err := tmpFile.Close(); err != nil { + return fmt.Errorf("failed to close temp file: %w", err) + } + + if err := os.Chmod(tmpPath, 0600); err != nil { + return fmt.Errorf("failed to set file permissions: %w", err) + } + + if err := os.Rename(tmpPath, r.path); err != nil { + return fmt.Errorf("failed to rename temp file: %w", err) + } + + success = true + return nil +} + +// Update atomically loads, modifies, and saves data. +// The modification function receives a pointer to the data. +func (r *FileRepository[T]) Update(fn func(*T) error) error { + r.mu.Lock() + defer r.mu.Unlock() + + // Load current data + var data T + fileData, err := os.ReadFile(r.path) + if err != nil { + if os.IsNotExist(err) { + if r.defaults != nil { + data = r.defaults() + } + } else { + return fmt.Errorf("failed to read file: %w", err) + } + } else { + if err := json.Unmarshal(fileData, &data); err != nil { + return fmt.Errorf("failed to unmarshal data: %w", err) + } + } + + // Apply modification + if err := fn(&data); err != nil { + return err + } + + // Save atomically + return r.saveUnlocked(data) +} + +// Path returns the file path of this repository. +func (r *FileRepository[T]) Path() string { + return r.path +} + +// Exists returns true if the repository file exists. +func (r *FileRepository[T]) Exists() bool { + r.mu.RLock() + defer r.mu.RUnlock() + + _, err := os.Stat(r.path) + return err == nil +} + +// Delete removes the repository file. +func (r *FileRepository[T]) Delete() error { + r.mu.Lock() + defer r.mu.Unlock() + + err := os.Remove(r.path) + if os.IsNotExist(err) { + return nil + } + return err +} diff --git a/pkg/mining/repository_test.go b/pkg/mining/repository_test.go new file mode 100644 index 0000000..a15a475 --- /dev/null +++ b/pkg/mining/repository_test.go @@ -0,0 +1,300 @@ +package mining + +import ( + "errors" + "os" + "path/filepath" + "testing" +) + +type testData struct { + Name string `json:"name"` + Value int `json:"value"` +} + +func TestFileRepository_Load(t *testing.T) { + t.Run("NonExistentFile", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "nonexistent.json") + repo := NewFileRepository[testData](path) + + data, err := repo.Load() + if err != nil { + t.Fatalf("Load should not error for non-existent file: %v", err) + } + if data.Name != "" || data.Value != 0 { + t.Error("Expected zero value for non-existent file") + } + }) + + t.Run("NonExistentFileWithDefaults", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "nonexistent.json") + repo := NewFileRepository[testData](path, WithDefaults(func() testData { + return testData{Name: "default", Value: 42} + })) + + data, err := repo.Load() + if err != nil { + t.Fatalf("Load should not error: %v", err) + } + if data.Name != "default" || data.Value != 42 { + t.Errorf("Expected default values, got %+v", data) + } + }) + + t.Run("ExistingFile", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "test.json") + + // Write test data + if err := os.WriteFile(path, []byte(`{"name":"test","value":123}`), 0600); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + repo := NewFileRepository[testData](path) + data, err := repo.Load() + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if data.Name != "test" || data.Value != 123 { + t.Errorf("Unexpected data: %+v", data) + } + }) + + t.Run("InvalidJSON", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "invalid.json") + + if err := os.WriteFile(path, []byte(`{invalid json}`), 0600); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + repo := NewFileRepository[testData](path) + _, err := repo.Load() + if err == nil { + t.Error("Expected error for invalid JSON") + } + }) +} + +func TestFileRepository_Save(t *testing.T) { + t.Run("NewFile", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "subdir", "new.json") + repo := NewFileRepository[testData](path) + + data := testData{Name: "saved", Value: 456} + if err := repo.Save(data); err != nil { + t.Fatalf("Save failed: %v", err) + } + + // Verify file was created + if !repo.Exists() { + t.Error("File should exist after save") + } + + // Verify content + loaded, err := repo.Load() + if err != nil { + t.Fatalf("Load after save failed: %v", err) + } + if loaded.Name != "saved" || loaded.Value != 456 { + t.Errorf("Unexpected loaded data: %+v", loaded) + } + }) + + t.Run("OverwriteExisting", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "existing.json") + repo := NewFileRepository[testData](path) + + // Save initial data + if err := repo.Save(testData{Name: "first", Value: 1}); err != nil { + t.Fatalf("First save failed: %v", err) + } + + // Overwrite + if err := repo.Save(testData{Name: "second", Value: 2}); err != nil { + t.Fatalf("Second save failed: %v", err) + } + + // Verify overwrite + loaded, err := repo.Load() + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if loaded.Name != "second" || loaded.Value != 2 { + t.Errorf("Expected overwritten data, got: %+v", loaded) + } + }) +} + +func TestFileRepository_Update(t *testing.T) { + t.Run("UpdateExisting", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "update.json") + repo := NewFileRepository[testData](path) + + // Save initial data + if err := repo.Save(testData{Name: "initial", Value: 10}); err != nil { + t.Fatalf("Initial save failed: %v", err) + } + + // Update + err := repo.Update(func(data *testData) error { + data.Value += 5 + return nil + }) + if err != nil { + t.Fatalf("Update failed: %v", err) + } + + // Verify update + loaded, err := repo.Load() + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if loaded.Value != 15 { + t.Errorf("Expected value 15, got %d", loaded.Value) + } + }) + + t.Run("UpdateNonExistentWithDefaults", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "new.json") + repo := NewFileRepository[testData](path, WithDefaults(func() testData { + return testData{Name: "default", Value: 100} + })) + + err := repo.Update(func(data *testData) error { + data.Value *= 2 + return nil + }) + if err != nil { + t.Fatalf("Update failed: %v", err) + } + + // Verify update started from defaults + loaded, err := repo.Load() + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if loaded.Value != 200 { + t.Errorf("Expected value 200, got %d", loaded.Value) + } + }) + + t.Run("UpdateWithError", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "error.json") + repo := NewFileRepository[testData](path) + + if err := repo.Save(testData{Name: "test", Value: 1}); err != nil { + t.Fatalf("Initial save failed: %v", err) + } + + // Update that returns error + testErr := errors.New("update error") + err := repo.Update(func(data *testData) error { + data.Value = 999 // This change should not be saved + return testErr + }) + if err != testErr { + t.Errorf("Expected test error, got: %v", err) + } + + // Verify original data unchanged + loaded, err := repo.Load() + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if loaded.Value != 1 { + t.Errorf("Expected value 1 (unchanged), got %d", loaded.Value) + } + }) +} + +func TestFileRepository_Delete(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "delete.json") + repo := NewFileRepository[testData](path) + + // Save data + if err := repo.Save(testData{Name: "temp", Value: 1}); err != nil { + t.Fatalf("Save failed: %v", err) + } + + if !repo.Exists() { + t.Error("File should exist after save") + } + + // Delete + if err := repo.Delete(); err != nil { + t.Fatalf("Delete failed: %v", err) + } + + if repo.Exists() { + t.Error("File should not exist after delete") + } + + // Delete non-existent should not error + if err := repo.Delete(); err != nil { + t.Errorf("Delete non-existent should not error: %v", err) + } +} + +func TestFileRepository_Path(t *testing.T) { + path := "/some/path/config.json" + repo := NewFileRepository[testData](path) + + if repo.Path() != path { + t.Errorf("Expected path %s, got %s", path, repo.Path()) + } +} + +// Test with slice data +func TestFileRepository_SliceData(t *testing.T) { + type item struct { + ID string `json:"id"` + Name string `json:"name"` + } + + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "items.json") + repo := NewFileRepository[[]item](path, WithDefaults(func() []item { + return []item{} + })) + + // Save slice + items := []item{ + {ID: "1", Name: "First"}, + {ID: "2", Name: "Second"}, + } + if err := repo.Save(items); err != nil { + t.Fatalf("Save failed: %v", err) + } + + // Load and verify + loaded, err := repo.Load() + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if len(loaded) != 2 { + t.Errorf("Expected 2 items, got %d", len(loaded)) + } + + // Update slice + err = repo.Update(func(data *[]item) error { + *data = append(*data, item{ID: "3", Name: "Third"}) + return nil + }) + if err != nil { + t.Fatalf("Update failed: %v", err) + } + + loaded, _ = repo.Load() + if len(loaded) != 3 { + t.Errorf("Expected 3 items after update, got %d", len(loaded)) + } +} diff --git a/pkg/mining/ttminer.go b/pkg/mining/ttminer.go index f42aba3..9523c1e 100644 --- a/pkg/mining/ttminer.go +++ b/pkg/mining/ttminer.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "os" "os/exec" @@ -92,6 +93,7 @@ func (m *TTMiner) GetLatestVersion() (string, error) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { + io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse return "", fmt.Errorf("failed to get latest release: unexpected status code %d", resp.StatusCode) } diff --git a/pkg/mining/ttminer_stats.go b/pkg/mining/ttminer_stats.go index 6fc863d..810adb9 100644 --- a/pkg/mining/ttminer_stats.go +++ b/pkg/mining/ttminer_stats.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" ) @@ -41,6 +42,7 @@ func (m *TTMiner) GetStats(ctx context.Context) (*PerformanceMetrics, error) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { + io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse return nil, fmt.Errorf("failed to get stats: unexpected status code %d", resp.StatusCode) } diff --git a/pkg/mining/xmrig.go b/pkg/mining/xmrig.go index 7be552b..73ceda4 100644 --- a/pkg/mining/xmrig.go +++ b/pkg/mining/xmrig.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "os" "os/exec" @@ -93,6 +94,7 @@ func (m *XMRigMiner) GetLatestVersion() (string, error) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { + io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse return "", fmt.Errorf("failed to get latest release: unexpected status code %d", resp.StatusCode) } diff --git a/pkg/mining/xmrig_stats.go b/pkg/mining/xmrig_stats.go index 9e9d2a1..4fe4e2b 100644 --- a/pkg/mining/xmrig_stats.go +++ b/pkg/mining/xmrig_stats.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "time" ) @@ -45,6 +46,7 @@ func (m *XMRigMiner) GetStats(ctx context.Context) (*PerformanceMetrics, error) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { + io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse return nil, fmt.Errorf("failed to get stats: unexpected status code %d", resp.StatusCode) } diff --git a/pkg/node/controller.go b/pkg/node/controller.go index 3914801..c744583 100644 --- a/pkg/node/controller.go +++ b/pkg/node/controller.go @@ -125,34 +125,27 @@ func (c *Controller) GetRemoteStats(peerID string) (*StatsPayload, error) { return nil, err } - if resp.Type == MsgError { - var errPayload ErrorPayload - if err := resp.ParsePayload(&errPayload); err != nil { - return nil, fmt.Errorf("remote error (unable to parse)") - } - return nil, fmt.Errorf("remote error: %s", errPayload.Message) - } - - if resp.Type != MsgStats { - return nil, fmt.Errorf("unexpected response type: %s", resp.Type) - } - var stats StatsPayload - if err := resp.ParsePayload(&stats); err != nil { - return nil, fmt.Errorf("failed to parse stats: %w", err) + if err := ParseResponse(resp, MsgStats, &stats); err != nil { + return nil, err } return &stats, nil } // StartRemoteMiner requests a remote peer to start a miner with a given profile. -func (c *Controller) StartRemoteMiner(peerID, profileID string, configOverride json.RawMessage) error { +func (c *Controller) StartRemoteMiner(peerID, minerType, profileID string, configOverride json.RawMessage) error { identity := c.node.GetIdentity() if identity == nil { return fmt.Errorf("node identity not initialized") } + if minerType == "" { + return fmt.Errorf("miner type is required") + } + payload := StartMinerPayload{ + MinerType: minerType, ProfileID: profileID, Config: configOverride, } @@ -167,21 +160,9 @@ func (c *Controller) StartRemoteMiner(peerID, profileID string, configOverride j return err } - if resp.Type == MsgError { - var errPayload ErrorPayload - if err := resp.ParsePayload(&errPayload); err != nil { - return fmt.Errorf("remote error (unable to parse)") - } - return fmt.Errorf("remote error: %s", errPayload.Message) - } - - if resp.Type != MsgMinerAck { - return fmt.Errorf("unexpected response type: %s", resp.Type) - } - var ack MinerAckPayload - if err := resp.ParsePayload(&ack); err != nil { - return fmt.Errorf("failed to parse ack: %w", err) + if err := ParseResponse(resp, MsgMinerAck, &ack); err != nil { + return err } if !ack.Success { @@ -212,21 +193,9 @@ func (c *Controller) StopRemoteMiner(peerID, minerName string) error { return err } - if resp.Type == MsgError { - var errPayload ErrorPayload - if err := resp.ParsePayload(&errPayload); err != nil { - return fmt.Errorf("remote error (unable to parse)") - } - return fmt.Errorf("remote error: %s", errPayload.Message) - } - - if resp.Type != MsgMinerAck { - return fmt.Errorf("unexpected response type: %s", resp.Type) - } - var ack MinerAckPayload - if err := resp.ParsePayload(&ack); err != nil { - return fmt.Errorf("failed to parse ack: %w", err) + if err := ParseResponse(resp, MsgMinerAck, &ack); err != nil { + return err } if !ack.Success { @@ -258,21 +227,9 @@ func (c *Controller) GetRemoteLogs(peerID, minerName string, lines int) ([]strin return nil, err } - if resp.Type == MsgError { - var errPayload ErrorPayload - if err := resp.ParsePayload(&errPayload); err != nil { - return nil, fmt.Errorf("remote error (unable to parse)") - } - return nil, fmt.Errorf("remote error: %s", errPayload.Message) - } - - if resp.Type != MsgLogs { - return nil, fmt.Errorf("unexpected response type: %s", resp.Type) - } - var logs LogsPayload - if err := resp.ParsePayload(&logs); err != nil { - return nil, fmt.Errorf("failed to parse logs: %w", err) + if err := ParseResponse(resp, MsgLogs, &logs); err != nil { + return nil, err } return logs.Lines, nil @@ -325,8 +282,8 @@ func (c *Controller) PingPeer(peerID string) (float64, error) { return 0, err } - if resp.Type != MsgPong { - return 0, fmt.Errorf("unexpected response type: %s", resp.Type) + if err := ValidateResponse(resp, MsgPong); err != nil { + return 0, err } // Calculate round-trip time diff --git a/pkg/node/message.go b/pkg/node/message.go index 8599798..58b9f24 100644 --- a/pkg/node/message.go +++ b/pkg/node/message.go @@ -117,7 +117,8 @@ type PongPayload struct { // StartMinerPayload requests starting a miner. type StartMinerPayload struct { - ProfileID string `json:"profileId"` + MinerType string `json:"minerType"` // Required: miner type (e.g., "xmrig", "tt-miner") + ProfileID string `json:"profileId,omitempty"` Config json.RawMessage `json:"config,omitempty"` // Override profile config } diff --git a/pkg/node/message_test.go b/pkg/node/message_test.go index 6cc03ab..6f68ffc 100644 --- a/pkg/node/message_test.go +++ b/pkg/node/message_test.go @@ -92,6 +92,7 @@ func TestMessageReply(t *testing.T) { func TestParsePayload(t *testing.T) { t.Run("ValidPayload", func(t *testing.T) { payload := StartMinerPayload{ + MinerType: "xmrig", ProfileID: "test-profile", } @@ -190,6 +191,7 @@ func TestNewErrorMessage(t *testing.T) { func TestMessageSerialization(t *testing.T) { original, _ := NewMessage(MsgStartMiner, "ctrl", "worker", StartMinerPayload{ + MinerType: "xmrig", ProfileID: "my-profile", }) diff --git a/pkg/node/protocol.go b/pkg/node/protocol.go new file mode 100644 index 0000000..197d5e4 --- /dev/null +++ b/pkg/node/protocol.go @@ -0,0 +1,88 @@ +package node + +import ( + "fmt" +) + +// ProtocolError represents an error from the remote peer. +type ProtocolError struct { + Code int + Message string +} + +func (e *ProtocolError) Error() string { + return fmt.Sprintf("remote error (%d): %s", e.Code, e.Message) +} + +// ResponseHandler provides helpers for handling protocol responses. +type ResponseHandler struct{} + +// ValidateResponse checks if the response is valid and returns a parsed error if it's an error response. +// It checks: +// 1. If response is nil (returns error) +// 2. If response is an error message (returns ProtocolError) +// 3. If response type matches expected (returns error if not) +func (h *ResponseHandler) ValidateResponse(resp *Message, expectedType MessageType) error { + if resp == nil { + return fmt.Errorf("nil response") + } + + // Check for error response + if resp.Type == MsgError { + var errPayload ErrorPayload + if err := resp.ParsePayload(&errPayload); err != nil { + return &ProtocolError{Code: ErrCodeUnknown, Message: "unable to parse error response"} + } + return &ProtocolError{Code: errPayload.Code, Message: errPayload.Message} + } + + // Check expected type + if resp.Type != expectedType { + return fmt.Errorf("unexpected response type: expected %s, got %s", expectedType, resp.Type) + } + + return nil +} + +// ParseResponse validates the response and parses the payload into the target. +// This combines ValidateResponse and ParsePayload into a single call. +func (h *ResponseHandler) ParseResponse(resp *Message, expectedType MessageType, target interface{}) error { + if err := h.ValidateResponse(resp, expectedType); err != nil { + return err + } + + if target != nil { + if err := resp.ParsePayload(target); err != nil { + return fmt.Errorf("failed to parse %s payload: %w", expectedType, err) + } + } + + return nil +} + +// DefaultResponseHandler is the default response handler instance. +var DefaultResponseHandler = &ResponseHandler{} + +// ValidateResponse is a convenience function using the default handler. +func ValidateResponse(resp *Message, expectedType MessageType) error { + return DefaultResponseHandler.ValidateResponse(resp, expectedType) +} + +// ParseResponse is a convenience function using the default handler. +func ParseResponse(resp *Message, expectedType MessageType, target interface{}) error { + return DefaultResponseHandler.ParseResponse(resp, expectedType, target) +} + +// IsProtocolError returns true if the error is a ProtocolError. +func IsProtocolError(err error) bool { + _, ok := err.(*ProtocolError) + return ok +} + +// GetProtocolErrorCode returns the error code if err is a ProtocolError, otherwise returns 0. +func GetProtocolErrorCode(err error) int { + if pe, ok := err.(*ProtocolError); ok { + return pe.Code + } + return 0 +} diff --git a/pkg/node/protocol_test.go b/pkg/node/protocol_test.go new file mode 100644 index 0000000..1d728a4 --- /dev/null +++ b/pkg/node/protocol_test.go @@ -0,0 +1,161 @@ +package node + +import ( + "fmt" + "testing" +) + +func TestResponseHandler_ValidateResponse(t *testing.T) { + handler := &ResponseHandler{} + + t.Run("NilResponse", func(t *testing.T) { + err := handler.ValidateResponse(nil, MsgStats) + if err == nil { + t.Error("Expected error for nil response") + } + }) + + t.Run("ErrorResponse", func(t *testing.T) { + errMsg, _ := NewErrorMessage("sender", "receiver", ErrCodeOperationFailed, "operation failed", "") + err := handler.ValidateResponse(errMsg, MsgStats) + if err == nil { + t.Fatal("Expected error for error response") + } + + if !IsProtocolError(err) { + t.Errorf("Expected ProtocolError, got %T", err) + } + + if GetProtocolErrorCode(err) != ErrCodeOperationFailed { + t.Errorf("Expected code %d, got %d", ErrCodeOperationFailed, GetProtocolErrorCode(err)) + } + }) + + t.Run("WrongType", func(t *testing.T) { + msg, _ := NewMessage(MsgPong, "sender", "receiver", nil) + err := handler.ValidateResponse(msg, MsgStats) + if err == nil { + t.Error("Expected error for wrong type") + } + if IsProtocolError(err) { + t.Error("Should not be a ProtocolError for type mismatch") + } + }) + + t.Run("ValidResponse", func(t *testing.T) { + msg, _ := NewMessage(MsgStats, "sender", "receiver", StatsPayload{NodeID: "test"}) + err := handler.ValidateResponse(msg, MsgStats) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) +} + +func TestResponseHandler_ParseResponse(t *testing.T) { + handler := &ResponseHandler{} + + t.Run("ParseStats", func(t *testing.T) { + payload := StatsPayload{ + NodeID: "node-123", + NodeName: "Test Node", + Uptime: 3600, + } + msg, _ := NewMessage(MsgStats, "sender", "receiver", payload) + + var parsed StatsPayload + err := handler.ParseResponse(msg, MsgStats, &parsed) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if parsed.NodeID != "node-123" { + t.Errorf("Expected NodeID 'node-123', got '%s'", parsed.NodeID) + } + if parsed.Uptime != 3600 { + t.Errorf("Expected Uptime 3600, got %d", parsed.Uptime) + } + }) + + t.Run("ParseMinerAck", func(t *testing.T) { + payload := MinerAckPayload{ + Success: true, + MinerName: "xmrig-1", + } + msg, _ := NewMessage(MsgMinerAck, "sender", "receiver", payload) + + var parsed MinerAckPayload + err := handler.ParseResponse(msg, MsgMinerAck, &parsed) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if !parsed.Success { + t.Error("Expected Success to be true") + } + if parsed.MinerName != "xmrig-1" { + t.Errorf("Expected MinerName 'xmrig-1', got '%s'", parsed.MinerName) + } + }) + + t.Run("ErrorResponse", func(t *testing.T) { + errMsg, _ := NewErrorMessage("sender", "receiver", ErrCodeNotFound, "not found", "") + + var parsed StatsPayload + err := handler.ParseResponse(errMsg, MsgStats, &parsed) + if err == nil { + t.Error("Expected error for error response") + } + if !IsProtocolError(err) { + t.Errorf("Expected ProtocolError, got %T", err) + } + }) + + t.Run("NilTarget", func(t *testing.T) { + msg, _ := NewMessage(MsgPong, "sender", "receiver", nil) + err := handler.ParseResponse(msg, MsgPong, nil) + if err != nil { + t.Errorf("Unexpected error with nil target: %v", err) + } + }) +} + +func TestProtocolError(t *testing.T) { + err := &ProtocolError{Code: 1001, Message: "test error"} + + if err.Error() != "remote error (1001): test error" { + t.Errorf("Unexpected error message: %s", err.Error()) + } + + if !IsProtocolError(err) { + t.Error("IsProtocolError should return true") + } + + if GetProtocolErrorCode(err) != 1001 { + t.Errorf("Expected code 1001, got %d", GetProtocolErrorCode(err)) + } +} + +func TestConvenienceFunctions(t *testing.T) { + msg, _ := NewMessage(MsgStats, "sender", "receiver", StatsPayload{NodeID: "test"}) + + // Test ValidateResponse + if err := ValidateResponse(msg, MsgStats); err != nil { + t.Errorf("ValidateResponse failed: %v", err) + } + + // Test ParseResponse + var parsed StatsPayload + if err := ParseResponse(msg, MsgStats, &parsed); err != nil { + t.Errorf("ParseResponse failed: %v", err) + } + if parsed.NodeID != "test" { + t.Errorf("Expected NodeID 'test', got '%s'", parsed.NodeID) + } +} + +func TestGetProtocolErrorCode_NonProtocolError(t *testing.T) { + err := fmt.Errorf("regular error") + if GetProtocolErrorCode(err) != 0 { + t.Error("Expected 0 for non-ProtocolError") + } +} diff --git a/pkg/node/worker.go b/pkg/node/worker.go index 518f252..c4d82a3 100644 --- a/pkg/node/worker.go +++ b/pkg/node/worker.go @@ -198,6 +198,11 @@ func (w *Worker) handleStartMiner(msg *Message) (*Message, error) { return nil, fmt.Errorf("invalid start miner payload: %w", err) } + // Validate miner type is provided + if payload.MinerType == "" { + return nil, fmt.Errorf("miner type is required") + } + // Get the config from the profile or use the override var config interface{} if payload.Config != nil { @@ -213,7 +218,7 @@ func (w *Worker) handleStartMiner(msg *Message) (*Message, error) { } // Start the miner - miner, err := w.minerManager.StartMiner("", config) + miner, err := w.minerManager.StartMiner(payload.MinerType, config) if err != nil { ack := MinerAckPayload{ Success: false, diff --git a/pkg/node/worker_test.go b/pkg/node/worker_test.go index f9d0303..27f67a4 100644 --- a/pkg/node/worker_test.go +++ b/pkg/node/worker_test.go @@ -247,7 +247,7 @@ func TestWorker_HandleStartMiner_NoManager(t *testing.T) { if identity == nil { t.Fatal("expected identity to be generated") } - payload := StartMinerPayload{ProfileID: "test-profile"} + payload := StartMinerPayload{MinerType: "xmrig", ProfileID: "test-profile"} msg, err := NewMessage(MsgStartMiner, "sender-id", identity.ID, payload) if err != nil { t.Fatalf("failed to create start_miner message: %v", err)