refactor: Add reliability fixes and architecture improvements
Reliability fixes:
- Fix HTTP response body drainage in xmrig, ttminer, miner
- Fix database init race condition (nil before close)
- Fix empty minerType bug in P2P StartMinerPayload
- Add context timeout to InsertHashratePoint (5s default)

Architecture improvements:
- Extract HashrateStore interface with DefaultStore/NopStore
- Create ServiceContainer for centralized initialization
- Extract protocol response handler (ValidateResponse, ParseResponse)
- Create generic FileRepository[T] with atomic writes

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
parent 89f74aebff
commit 34f860309f
25 changed files with 1553 additions and 88 deletions
@@ -82,10 +82,11 @@ var remoteStartCmd = &cobra.Command{
 	Long: `Start a miner on a remote peer using a profile.`,
 	Args: cobra.ExactArgs(1),
 	RunE: func(cmd *cobra.Command, args []string) error {
-		profileID, _ := cmd.Flags().GetString("profile")
-		if profileID == "" {
-			return fmt.Errorf("--profile is required")
+		minerType, _ := cmd.Flags().GetString("type")
+		if minerType == "" {
+			return fmt.Errorf("--type is required (e.g., xmrig, tt-miner)")
 		}
+		profileID, _ := cmd.Flags().GetString("profile")
 
 		peerID := args[0]
 		peer := findPeerByPartialID(peerID)
@@ -98,8 +99,8 @@ var remoteStartCmd = &cobra.Command{
 			return err
 		}
 
-		fmt.Printf("Starting miner on %s with profile %s...\n", peer.Name, profileID)
-		if err := ctrl.StartRemoteMiner(peer.ID, profileID, nil); err != nil {
+		fmt.Printf("Starting %s miner on %s with profile %s...\n", minerType, peer.Name, profileID)
+		if err := ctrl.StartRemoteMiner(peer.ID, minerType, profileID, nil); err != nil {
 			return fmt.Errorf("failed to start miner: %w", err)
 		}
 
@@ -298,6 +299,7 @@ func init() {
 	// remote start
 	remoteCmd.AddCommand(remoteStartCmd)
 	remoteStartCmd.Flags().StringP("profile", "p", "", "Profile ID to use for starting the miner")
+	remoteStartCmd.Flags().StringP("type", "t", "", "Miner type (e.g., xmrig, tt-miner)")
 
 	// remote stop
 	remoteCmd.AddCommand(remoteStopCmd)
@@ -77,8 +77,10 @@ func Initialize(cfg Config) error {
 
 	// Create tables
 	if err := createTables(); err != nil {
-		db.Close()
+		// Nil out global before closing to prevent use of closed connection
+		closingDB := db
+		db = nil
+		closingDB.Close()
 		return fmt.Errorf("failed to create tables: %w", err)
 	}
 
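Note on the race fixed above: the old code closed the handle while the package-level db still pointed at it, so any caller that only checks db != nil could use a closed connection. Below is a minimal sketch of the corrected ordering in isolation; the locking shown here is illustrative and not part of this hunk.

// Illustrative sketch only; mirrors the nil-before-close ordering applied in Initialize.
package database

import (
	"database/sql"
	"sync"
)

var (
	dbMu sync.RWMutex
	db   *sql.DB
)

// closeGlobalDB detaches the global handle before closing it, so concurrent
// readers that observe a non-nil db never operate on a closed connection.
func closeGlobalDB() error {
	dbMu.Lock()
	defer dbMu.Unlock()

	if db == nil {
		return nil
	}
	closingDB := db
	db = nil // readers now treat the database as disabled
	return closingDB.Close()
}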
@ -50,7 +50,7 @@ func TestConcurrentHashrateInserts(t *testing.T) {
|
|||
Timestamp: time.Now().Add(time.Duration(-j) * time.Second),
|
||||
Hashrate: 1000 + minerIndex*100 + j,
|
||||
}
|
||||
err := InsertHashratePoint(minerName, minerType, point, ResolutionHigh)
|
||||
err := InsertHashratePoint(nil, minerName, minerType, point, ResolutionHigh)
|
||||
if err != nil {
|
||||
t.Errorf("Insert error for %s: %v", minerName, err)
|
||||
}
|
||||
|
|
@ -95,7 +95,7 @@ func TestConcurrentInsertAndQuery(t *testing.T) {
|
|||
Timestamp: time.Now(),
|
||||
Hashrate: 1000 + i,
|
||||
}
|
||||
InsertHashratePoint("concurrent-test", "xmrig", point, ResolutionHigh)
|
||||
InsertHashratePoint(nil, "concurrent-test", "xmrig", point, ResolutionHigh)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
|
@ -149,13 +149,13 @@ func TestConcurrentInsertAndCleanup(t *testing.T) {
|
|||
Timestamp: time.Now().AddDate(0, 0, -10), // 10 days old
|
||||
Hashrate: 500 + i,
|
||||
}
|
||||
InsertHashratePoint("cleanup-test", "xmrig", oldPoint, ResolutionHigh)
|
||||
InsertHashratePoint(nil, "cleanup-test", "xmrig", oldPoint, ResolutionHigh)
|
||||
|
||||
newPoint := HashratePoint{
|
||||
Timestamp: time.Now(),
|
||||
Hashrate: 1000 + i,
|
||||
}
|
||||
InsertHashratePoint("cleanup-test", "xmrig", newPoint, ResolutionHigh)
|
||||
InsertHashratePoint(nil, "cleanup-test", "xmrig", newPoint, ResolutionHigh)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
|
@ -197,7 +197,7 @@ func TestConcurrentStats(t *testing.T) {
|
|||
Timestamp: time.Now().Add(time.Duration(-i) * time.Second),
|
||||
Hashrate: 1000 + i*10,
|
||||
}
|
||||
InsertHashratePoint(minerName, "xmrig", point, ResolutionHigh)
|
||||
InsertHashratePoint(nil, minerName, "xmrig", point, ResolutionHigh)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
|
@ -238,7 +238,7 @@ func TestConcurrentGetAllStats(t *testing.T) {
|
|||
Timestamp: time.Now().Add(time.Duration(-i) * time.Second),
|
||||
Hashrate: 1000 + m*100 + i,
|
||||
}
|
||||
InsertHashratePoint(minerName, "xmrig", point, ResolutionHigh)
|
||||
InsertHashratePoint(nil, minerName, "xmrig", point, ResolutionHigh)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -267,7 +267,7 @@ func TestConcurrentGetAllStats(t *testing.T) {
|
|||
Timestamp: time.Now(),
|
||||
Hashrate: 2000 + i,
|
||||
}
|
||||
InsertHashratePoint("all-stats-new", "xmrig", point, ResolutionHigh)
|
||||
InsertHashratePoint(nil, "all-stats-new", "xmrig", point, ResolutionHigh)
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ func TestHashrateStorage(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, p := range points {
|
||||
if err := InsertHashratePoint(minerName, minerType, p, ResolutionHigh); err != nil {
|
||||
if err := InsertHashratePoint(nil, minerName, minerType, p, ResolutionHigh); err != nil {
|
||||
t.Fatalf("Failed to store hashrate point: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -109,7 +109,7 @@ func TestGetHashrateStats(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, p := range points {
|
||||
if err := InsertHashratePoint(minerName, minerType, p, ResolutionHigh); err != nil {
|
||||
if err := InsertHashratePoint(nil, minerName, minerType, p, ResolutionHigh); err != nil {
|
||||
t.Fatalf("Failed to store point: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -175,13 +175,13 @@ func TestCleanupRetention(t *testing.T) {
|
|||
}
|
||||
|
||||
// Insert all points
|
||||
if err := InsertHashratePoint(minerName, minerType, oldPoint, ResolutionHigh); err != nil {
|
||||
if err := InsertHashratePoint(nil, minerName, minerType, oldPoint, ResolutionHigh); err != nil {
|
||||
t.Fatalf("Failed to insert old point: %v", err)
|
||||
}
|
||||
if err := InsertHashratePoint(minerName, minerType, midPoint, ResolutionHigh); err != nil {
|
||||
if err := InsertHashratePoint(nil, minerName, minerType, midPoint, ResolutionHigh); err != nil {
|
||||
t.Fatalf("Failed to insert mid point: %v", err)
|
||||
}
|
||||
if err := InsertHashratePoint(minerName, minerType, newPoint, ResolutionHigh); err != nil {
|
||||
if err := InsertHashratePoint(nil, minerName, minerType, newPoint, ResolutionHigh); err != nil {
|
||||
t.Fatalf("Failed to insert new point: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -238,7 +238,7 @@ func TestGetHashrateHistoryTimeRange(t *testing.T) {
|
|||
Timestamp: now.Add(offset),
|
||||
Hashrate: 1000 + i*100,
|
||||
}
|
||||
if err := InsertHashratePoint(minerName, minerType, point, ResolutionHigh); err != nil {
|
||||
if err := InsertHashratePoint(nil, minerName, minerType, point, ResolutionHigh); err != nil {
|
||||
t.Fatalf("Failed to insert point: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -291,7 +291,7 @@ func TestMultipleMinerStats(t *testing.T) {
|
|||
Timestamp: now.Add(time.Duration(-i) * time.Minute),
|
||||
Hashrate: hr,
|
||||
}
|
||||
if err := InsertHashratePoint(m.name, "xmrig", point, ResolutionHigh); err != nil {
|
||||
if err := InsertHashratePoint(nil, m.name, "xmrig", point, ResolutionHigh); err != nil {
|
||||
t.Fatalf("Failed to insert point for %s: %v", m.name, err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,7 @@
 package database
 
 import (
+	"context"
 	"fmt"
 	"time"
 
@@ -47,8 +48,12 @@ type HashratePoint struct {
 	Hashrate int `json:"hashrate"`
 }
 
-// InsertHashratePoint stores a hashrate measurement in the database
-func InsertHashratePoint(minerName, minerType string, point HashratePoint, resolution Resolution) error {
+// dbInsertTimeout is the maximum time to wait for a database insert operation
+const dbInsertTimeout = 5 * time.Second
+
+// InsertHashratePoint stores a hashrate measurement in the database.
+// If ctx is nil, a default timeout context will be used.
+func InsertHashratePoint(ctx context.Context, minerName, minerType string, point HashratePoint, resolution Resolution) error {
 	dbMu.RLock()
 	defer dbMu.RUnlock()
 
@@ -56,7 +61,14 @@ func InsertHashratePoint(minerName, minerType string, point HashratePoint, resolution Resolution) error {
 		return nil // DB not enabled, silently skip
 	}
 
-	_, err := db.Exec(`
+	// Use provided context or create one with default timeout
+	if ctx == nil {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(context.Background(), dbInsertTimeout)
+		defer cancel()
+	}
+
+	_, err := db.ExecContext(ctx, `
 		INSERT INTO hashrate_history (miner_name, miner_type, timestamp, hashrate, resolution)
 		VALUES (?, ?, ?, ?, ?)
 	`, minerName, minerType, point.Timestamp, point.Hashrate, string(resolution))
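Call sites can now either pass their own context or pass nil and rely on the 5-second dbInsertTimeout. A short sketch of both styles; HashratePoint and ResolutionHigh are the names shown above, while "rig-01" is a made-up miner name.

// Illustrative call sites for the new signature.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

point := HashratePoint{Timestamp: time.Now(), Hashrate: 1200}

// Caller-controlled deadline:
err := InsertHashratePoint(ctx, "rig-01", "xmrig", point, ResolutionHigh)

// nil context: the function falls back to its internal dbInsertTimeout (5s):
err = InsertHashratePoint(nil, "rig-01", "xmrig", point, ResolutionHigh)
_ = err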
95 pkg/database/interface.go Normal file
@@ -0,0 +1,95 @@
package database

import (
	"context"
	"time"
)

// HashrateStore defines the interface for hashrate data persistence.
// This interface allows for dependency injection and easier testing.
type HashrateStore interface {
	// InsertHashratePoint stores a hashrate measurement.
	// If ctx is nil, a default timeout will be used.
	InsertHashratePoint(ctx context.Context, minerName, minerType string, point HashratePoint, resolution Resolution) error

	// GetHashrateHistory retrieves hashrate history for a miner within a time range.
	GetHashrateHistory(minerName string, resolution Resolution, since, until time.Time) ([]HashratePoint, error)

	// GetHashrateStats retrieves aggregated statistics for a specific miner.
	GetHashrateStats(minerName string) (*HashrateStats, error)

	// GetAllMinerStats retrieves statistics for all miners.
	GetAllMinerStats() ([]HashrateStats, error)

	// Cleanup removes old data based on retention settings.
	Cleanup(retentionDays int) error

	// Close closes the store and releases resources.
	Close() error
}

// defaultStore implements HashrateStore using the global database connection.
// This provides backward compatibility while allowing interface-based usage.
type defaultStore struct{}

// DefaultStore returns a HashrateStore that uses the global database connection.
// This is useful for gradual migration from package-level functions to interface-based usage.
func DefaultStore() HashrateStore {
	return &defaultStore{}
}

func (s *defaultStore) InsertHashratePoint(ctx context.Context, minerName, minerType string, point HashratePoint, resolution Resolution) error {
	return InsertHashratePoint(ctx, minerName, minerType, point, resolution)
}

func (s *defaultStore) GetHashrateHistory(minerName string, resolution Resolution, since, until time.Time) ([]HashratePoint, error) {
	return GetHashrateHistory(minerName, resolution, since, until)
}

func (s *defaultStore) GetHashrateStats(minerName string) (*HashrateStats, error) {
	return GetHashrateStats(minerName)
}

func (s *defaultStore) GetAllMinerStats() ([]HashrateStats, error) {
	return GetAllMinerStats()
}

func (s *defaultStore) Cleanup(retentionDays int) error {
	return Cleanup(retentionDays)
}

func (s *defaultStore) Close() error {
	return Close()
}

// NopStore returns a HashrateStore that does nothing.
// Useful for testing or when database is disabled.
func NopStore() HashrateStore {
	return &nopStore{}
}

type nopStore struct{}

func (s *nopStore) InsertHashratePoint(ctx context.Context, minerName, minerType string, point HashratePoint, resolution Resolution) error {
	return nil
}

func (s *nopStore) GetHashrateHistory(minerName string, resolution Resolution, since, until time.Time) ([]HashratePoint, error) {
	return nil, nil
}

func (s *nopStore) GetHashrateStats(minerName string) (*HashrateStats, error) {
	return nil, nil
}

func (s *nopStore) GetAllMinerStats() ([]HashrateStats, error) {
	return nil, nil
}

func (s *nopStore) Cleanup(retentionDays int) error {
	return nil
}

func (s *nopStore) Close() error {
	return nil
}
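One intended consumption pattern for the new interface is constructor injection: production code takes DefaultStore() while tests or database-disabled builds take NopStore(). A sketch follows; the StatsRecorder type is hypothetical and not part of this commit, only database.HashrateStore, DefaultStore and NopStore are real.

// Hypothetical consumer living outside package database.
type StatsRecorder struct {
	store database.HashrateStore
}

func NewStatsRecorder(store database.HashrateStore) *StatsRecorder {
	if store == nil {
		store = database.NopStore() // safe default when persistence is disabled
	}
	return &StatsRecorder{store: store}
}

func (r *StatsRecorder) Record(minerName, minerType string, hashrate int) error {
	point := database.HashratePoint{Timestamp: time.Now(), Hashrate: hashrate}
	// nil ctx lets the store apply its default insert timeout
	return r.store.InsertHashratePoint(nil, minerName, minerType, point, database.ResolutionHigh)
}

// Production wiring: NewStatsRecorder(database.DefaultStore())
// Test wiring:       NewStatsRecorder(database.NopStore())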
135 pkg/database/interface_test.go Normal file
@ -0,0 +1,135 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDefaultStore(t *testing.T) {
|
||||
cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
store := DefaultStore()
|
||||
|
||||
// Test InsertHashratePoint
|
||||
point := HashratePoint{
|
||||
Timestamp: time.Now(),
|
||||
Hashrate: 1500,
|
||||
}
|
||||
if err := store.InsertHashratePoint(nil, "interface-test", "xmrig", point, ResolutionHigh); err != nil {
|
||||
t.Fatalf("InsertHashratePoint failed: %v", err)
|
||||
}
|
||||
|
||||
// Test GetHashrateHistory
|
||||
history, err := store.GetHashrateHistory("interface-test", ResolutionHigh, time.Now().Add(-time.Hour), time.Now().Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("GetHashrateHistory failed: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Errorf("Expected 1 point, got %d", len(history))
|
||||
}
|
||||
|
||||
// Test GetHashrateStats
|
||||
stats, err := store.GetHashrateStats("interface-test")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHashrateStats failed: %v", err)
|
||||
}
|
||||
if stats == nil {
|
||||
t.Fatal("Expected non-nil stats")
|
||||
}
|
||||
if stats.TotalPoints != 1 {
|
||||
t.Errorf("Expected 1 total point, got %d", stats.TotalPoints)
|
||||
}
|
||||
|
||||
// Test GetAllMinerStats
|
||||
allStats, err := store.GetAllMinerStats()
|
||||
if err != nil {
|
||||
t.Fatalf("GetAllMinerStats failed: %v", err)
|
||||
}
|
||||
if len(allStats) != 1 {
|
||||
t.Errorf("Expected 1 miner in stats, got %d", len(allStats))
|
||||
}
|
||||
|
||||
// Test Cleanup
|
||||
if err := store.Cleanup(30); err != nil {
|
||||
t.Fatalf("Cleanup failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultStore_WithContext(t *testing.T) {
|
||||
cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
store := DefaultStore()
|
||||
ctx := context.Background()
|
||||
|
||||
point := HashratePoint{
|
||||
Timestamp: time.Now(),
|
||||
Hashrate: 2000,
|
||||
}
|
||||
if err := store.InsertHashratePoint(ctx, "ctx-test", "xmrig", point, ResolutionHigh); err != nil {
|
||||
t.Fatalf("InsertHashratePoint with context failed: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHashrateHistory("ctx-test", ResolutionHigh, time.Now().Add(-time.Hour), time.Now().Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("GetHashrateHistory failed: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Errorf("Expected 1 point, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNopStore(t *testing.T) {
|
||||
store := NopStore()
|
||||
|
||||
// All operations should succeed without error
|
||||
point := HashratePoint{
|
||||
Timestamp: time.Now(),
|
||||
Hashrate: 1000,
|
||||
}
|
||||
if err := store.InsertHashratePoint(nil, "test", "xmrig", point, ResolutionHigh); err != nil {
|
||||
t.Errorf("NopStore InsertHashratePoint should not error: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHashrateHistory("test", ResolutionHigh, time.Now().Add(-time.Hour), time.Now())
|
||||
if err != nil {
|
||||
t.Errorf("NopStore GetHashrateHistory should not error: %v", err)
|
||||
}
|
||||
if history != nil {
|
||||
t.Errorf("NopStore GetHashrateHistory should return nil, got %v", history)
|
||||
}
|
||||
|
||||
stats, err := store.GetHashrateStats("test")
|
||||
if err != nil {
|
||||
t.Errorf("NopStore GetHashrateStats should not error: %v", err)
|
||||
}
|
||||
if stats != nil {
|
||||
t.Errorf("NopStore GetHashrateStats should return nil, got %v", stats)
|
||||
}
|
||||
|
||||
allStats, err := store.GetAllMinerStats()
|
||||
if err != nil {
|
||||
t.Errorf("NopStore GetAllMinerStats should not error: %v", err)
|
||||
}
|
||||
if allStats != nil {
|
||||
t.Errorf("NopStore GetAllMinerStats should return nil, got %v", allStats)
|
||||
}
|
||||
|
||||
if err := store.Cleanup(30); err != nil {
|
||||
t.Errorf("NopStore Cleanup should not error: %v", err)
|
||||
}
|
||||
|
||||
if err := store.Close(); err != nil {
|
||||
t.Errorf("NopStore Close should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInterfaceCompatibility ensures all implementations satisfy HashrateStore
|
||||
func TestInterfaceCompatibility(t *testing.T) {
|
||||
var _ HashrateStore = DefaultStore()
|
||||
var _ HashrateStore = NopStore()
|
||||
var _ HashrateStore = &defaultStore{}
|
||||
var _ HashrateStore = &nopStore{}
|
||||
}
|
||||
260 pkg/mining/container.go Normal file
@ -0,0 +1,260 @@
|
|||
package mining
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/Snider/Mining/pkg/database"
|
||||
"github.com/Snider/Mining/pkg/logging"
|
||||
)
|
||||
|
||||
// ContainerConfig holds configuration for the service container.
|
||||
type ContainerConfig struct {
|
||||
// Database configuration
|
||||
Database database.Config
|
||||
|
||||
// ListenAddr is the address to listen on (e.g., ":9090")
|
||||
ListenAddr string
|
||||
|
||||
// DisplayAddr is the address shown in Swagger docs
|
||||
DisplayAddr string
|
||||
|
||||
// SwaggerNamespace is the API path prefix
|
||||
SwaggerNamespace string
|
||||
|
||||
// SimulationMode enables simulation mode for testing
|
||||
SimulationMode bool
|
||||
}
|
||||
|
||||
// DefaultContainerConfig returns sensible defaults for the container.
|
||||
func DefaultContainerConfig() ContainerConfig {
|
||||
return ContainerConfig{
|
||||
Database: database.Config{
|
||||
Enabled: true,
|
||||
RetentionDays: 30,
|
||||
},
|
||||
ListenAddr: ":9090",
|
||||
DisplayAddr: "localhost:9090",
|
||||
SwaggerNamespace: "/api/v1/mining",
|
||||
SimulationMode: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Container manages the lifecycle of all services.
|
||||
// It provides centralized initialization, dependency injection, and graceful shutdown.
|
||||
type Container struct {
|
||||
config ContainerConfig
|
||||
mu sync.RWMutex
|
||||
|
||||
// Core services
|
||||
manager ManagerInterface
|
||||
profileManager *ProfileManager
|
||||
nodeService *NodeService
|
||||
eventHub *EventHub
|
||||
service *Service
|
||||
|
||||
// Database store (interface for testing)
|
||||
hashrateStore database.HashrateStore
|
||||
|
||||
// Initialization state
|
||||
initialized bool
|
||||
transportStarted bool
|
||||
shutdownCh chan struct{}
|
||||
}
|
||||
|
||||
// NewContainer creates a new service container with the given configuration.
|
||||
func NewContainer(config ContainerConfig) *Container {
|
||||
return &Container{
|
||||
config: config,
|
||||
shutdownCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize sets up all services in the correct order.
|
||||
// This should be called before Start().
|
||||
func (c *Container) Initialize(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.initialized {
|
||||
return fmt.Errorf("container already initialized")
|
||||
}
|
||||
|
||||
// 1. Initialize database (optional)
|
||||
if c.config.Database.Enabled {
|
||||
if err := database.Initialize(c.config.Database); err != nil {
|
||||
return fmt.Errorf("failed to initialize database: %w", err)
|
||||
}
|
||||
c.hashrateStore = database.DefaultStore()
|
||||
logging.Info("database initialized", logging.Fields{"retention_days": c.config.Database.RetentionDays})
|
||||
} else {
|
||||
c.hashrateStore = database.NopStore()
|
||||
logging.Info("database disabled, using no-op store", nil)
|
||||
}
|
||||
|
||||
// 2. Initialize profile manager
|
||||
var err error
|
||||
c.profileManager, err = NewProfileManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize profile manager: %w", err)
|
||||
}
|
||||
|
||||
// 3. Initialize miner manager
|
||||
if c.config.SimulationMode {
|
||||
c.manager = NewManagerForSimulation()
|
||||
} else {
|
||||
c.manager = NewManager()
|
||||
}
|
||||
|
||||
// 4. Initialize node service (optional - P2P features)
|
||||
c.nodeService, err = NewNodeService()
|
||||
if err != nil {
|
||||
logging.Warn("node service unavailable", logging.Fields{"error": err})
|
||||
// Continue without node service - P2P features will be unavailable
|
||||
}
|
||||
|
||||
// 5. Initialize event hub for WebSocket
|
||||
c.eventHub = NewEventHub()
|
||||
|
||||
// Wire up event hub to manager
|
||||
if mgr, ok := c.manager.(*Manager); ok {
|
||||
mgr.SetEventHub(c.eventHub)
|
||||
}
|
||||
|
||||
c.initialized = true
|
||||
logging.Info("service container initialized", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start begins all background services.
|
||||
func (c *Container) Start(ctx context.Context) error {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if !c.initialized {
|
||||
return fmt.Errorf("container not initialized")
|
||||
}
|
||||
|
||||
// Start event hub
|
||||
go c.eventHub.Run()
|
||||
|
||||
// Start node transport if available
|
||||
if c.nodeService != nil {
|
||||
if err := c.nodeService.StartTransport(); err != nil {
|
||||
logging.Warn("failed to start node transport", logging.Fields{"error": err})
|
||||
} else {
|
||||
c.transportStarted = true
|
||||
}
|
||||
}
|
||||
|
||||
logging.Info("service container started", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops all services in reverse order.
|
||||
func (c *Container) Shutdown(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.initialized {
|
||||
return nil
|
||||
}
|
||||
|
||||
logging.Info("shutting down service container", nil)
|
||||
|
||||
var errs []error
|
||||
|
||||
// 1. Stop service (HTTP server)
|
||||
if c.service != nil {
|
||||
// Service shutdown is handled externally
|
||||
}
|
||||
|
||||
// 2. Stop node transport (only if it was started)
|
||||
if c.nodeService != nil && c.transportStarted {
|
||||
if err := c.nodeService.StopTransport(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("node transport: %w", err))
|
||||
}
|
||||
c.transportStarted = false
|
||||
}
|
||||
|
||||
// 3. Stop event hub
|
||||
if c.eventHub != nil {
|
||||
c.eventHub.Stop()
|
||||
}
|
||||
|
||||
// 4. Stop miner manager
|
||||
if mgr, ok := c.manager.(*Manager); ok {
|
||||
mgr.Stop()
|
||||
}
|
||||
|
||||
// 5. Close database
|
||||
if err := database.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("database: %w", err))
|
||||
}
|
||||
|
||||
c.initialized = false
|
||||
close(c.shutdownCh)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("shutdown errors: %v", errs)
|
||||
}
|
||||
|
||||
logging.Info("service container shutdown complete", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Manager returns the miner manager.
|
||||
func (c *Container) Manager() ManagerInterface {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.manager
|
||||
}
|
||||
|
||||
// ProfileManager returns the profile manager.
|
||||
func (c *Container) ProfileManager() *ProfileManager {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.profileManager
|
||||
}
|
||||
|
||||
|
||||
// NodeService returns the node service (may be nil if P2P is unavailable).
|
||||
func (c *Container) NodeService() *NodeService {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.nodeService
|
||||
}
|
||||
|
||||
// EventHub returns the event hub for WebSocket connections.
|
||||
func (c *Container) EventHub() *EventHub {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.eventHub
|
||||
}
|
||||
|
||||
// HashrateStore returns the hashrate store interface.
|
||||
func (c *Container) HashrateStore() database.HashrateStore {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.hashrateStore
|
||||
}
|
||||
|
||||
// SetHashrateStore allows injecting a custom hashrate store (useful for testing).
|
||||
func (c *Container) SetHashrateStore(store database.HashrateStore) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.hashrateStore = store
|
||||
}
|
||||
|
||||
// ShutdownCh returns a channel that's closed when shutdown is complete.
|
||||
func (c *Container) ShutdownCh() <-chan struct{} {
|
||||
return c.shutdownCh
|
||||
}
|
||||
|
||||
// IsInitialized returns true if the container has been initialized.
|
||||
func (c *Container) IsInitialized() bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.initialized
|
||||
}
|
||||
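A sketch of how an entry point might drive the new container; the wiring below is illustrative, since the actual call sites are outside this diff, and only the Container methods defined above are real.

// Illustrative lifecycle wiring using container.go.
cfg := mining.DefaultContainerConfig()
cfg.Database.Enabled = true

c := mining.NewContainer(cfg)
ctx := context.Background()

if err := c.Initialize(ctx); err != nil {
	log.Fatalf("initialize: %v", err)
}
if err := c.Start(ctx); err != nil {
	log.Fatalf("start: %v", err)
}
defer func() {
	if err := c.Shutdown(context.Background()); err != nil {
		log.Printf("shutdown: %v", err)
	}
}()

// Dependencies are pulled from the container rather than package globals:
manager := c.Manager()
store := c.HashrateStore()
_, _ = manager, store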
238 pkg/mining/container_test.go Normal file
@ -0,0 +1,238 @@
|
|||
package mining
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Snider/Mining/pkg/database"
|
||||
)
|
||||
|
||||
func setupContainerTestEnv(t *testing.T) func() {
|
||||
tmpDir := t.TempDir()
|
||||
os.Setenv("XDG_CONFIG_HOME", filepath.Join(tmpDir, "config"))
|
||||
os.Setenv("XDG_DATA_HOME", filepath.Join(tmpDir, "data"))
|
||||
return func() {
|
||||
os.Unsetenv("XDG_CONFIG_HOME")
|
||||
os.Unsetenv("XDG_DATA_HOME")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewContainer(t *testing.T) {
|
||||
config := DefaultContainerConfig()
|
||||
container := NewContainer(config)
|
||||
|
||||
if container == nil {
|
||||
t.Fatal("NewContainer returned nil")
|
||||
}
|
||||
|
||||
if container.IsInitialized() {
|
||||
t.Error("Container should not be initialized before Initialize() is called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultContainerConfig(t *testing.T) {
|
||||
config := DefaultContainerConfig()
|
||||
|
||||
if !config.Database.Enabled {
|
||||
t.Error("Database should be enabled by default")
|
||||
}
|
||||
|
||||
if config.Database.RetentionDays != 30 {
|
||||
t.Errorf("Expected 30 retention days, got %d", config.Database.RetentionDays)
|
||||
}
|
||||
|
||||
if config.ListenAddr != ":9090" {
|
||||
t.Errorf("Expected :9090, got %s", config.ListenAddr)
|
||||
}
|
||||
|
||||
if config.SimulationMode {
|
||||
t.Error("SimulationMode should be false by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainer_Initialize(t *testing.T) {
|
||||
cleanup := setupContainerTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
config := DefaultContainerConfig()
|
||||
config.Database.Enabled = true
|
||||
config.Database.Path = filepath.Join(t.TempDir(), "test.db")
|
||||
config.SimulationMode = true // Use simulation mode for faster tests
|
||||
|
||||
container := NewContainer(config)
|
||||
ctx := context.Background()
|
||||
|
||||
if err := container.Initialize(ctx); err != nil {
|
||||
t.Fatalf("Initialize failed: %v", err)
|
||||
}
|
||||
|
||||
if !container.IsInitialized() {
|
||||
t.Error("Container should be initialized after Initialize()")
|
||||
}
|
||||
|
||||
// Verify services are available
|
||||
if container.Manager() == nil {
|
||||
t.Error("Manager should not be nil after initialization")
|
||||
}
|
||||
|
||||
if container.ProfileManager() == nil {
|
||||
t.Error("ProfileManager should not be nil after initialization")
|
||||
}
|
||||
|
||||
if container.EventHub() == nil {
|
||||
t.Error("EventHub should not be nil after initialization")
|
||||
}
|
||||
|
||||
if container.HashrateStore() == nil {
|
||||
t.Error("HashrateStore should not be nil after initialization")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
if err := container.Shutdown(ctx); err != nil {
|
||||
t.Errorf("Shutdown failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainer_InitializeTwice(t *testing.T) {
|
||||
cleanup := setupContainerTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
config := DefaultContainerConfig()
|
||||
config.Database.Enabled = false
|
||||
config.SimulationMode = true
|
||||
|
||||
container := NewContainer(config)
|
||||
ctx := context.Background()
|
||||
|
||||
if err := container.Initialize(ctx); err != nil {
|
||||
t.Fatalf("First Initialize failed: %v", err)
|
||||
}
|
||||
|
||||
// Second initialization should fail
|
||||
if err := container.Initialize(ctx); err == nil {
|
||||
t.Error("Second Initialize should fail")
|
||||
}
|
||||
|
||||
container.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func TestContainer_DatabaseDisabled(t *testing.T) {
|
||||
cleanup := setupContainerTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
config := DefaultContainerConfig()
|
||||
config.Database.Enabled = false
|
||||
config.SimulationMode = true
|
||||
|
||||
container := NewContainer(config)
|
||||
ctx := context.Background()
|
||||
|
||||
if err := container.Initialize(ctx); err != nil {
|
||||
t.Fatalf("Initialize failed: %v", err)
|
||||
}
|
||||
|
||||
// Should use NopStore when database is disabled
|
||||
store := container.HashrateStore()
|
||||
if store == nil {
|
||||
t.Fatal("HashrateStore should not be nil")
|
||||
}
|
||||
|
||||
// NopStore should accept inserts without error
|
||||
point := database.HashratePoint{
|
||||
Timestamp: time.Now(),
|
||||
Hashrate: 1000,
|
||||
}
|
||||
if err := store.InsertHashratePoint(nil, "test", "xmrig", point, database.ResolutionHigh); err != nil {
|
||||
t.Errorf("NopStore insert should not fail: %v", err)
|
||||
}
|
||||
|
||||
container.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func TestContainer_SetHashrateStore(t *testing.T) {
|
||||
cleanup := setupContainerTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
config := DefaultContainerConfig()
|
||||
config.Database.Enabled = false
|
||||
config.SimulationMode = true
|
||||
|
||||
container := NewContainer(config)
|
||||
ctx := context.Background()
|
||||
|
||||
if err := container.Initialize(ctx); err != nil {
|
||||
t.Fatalf("Initialize failed: %v", err)
|
||||
}
|
||||
|
||||
// Inject custom store
|
||||
customStore := database.NopStore()
|
||||
container.SetHashrateStore(customStore)
|
||||
|
||||
if container.HashrateStore() != customStore {
|
||||
t.Error("SetHashrateStore should update the store")
|
||||
}
|
||||
|
||||
container.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func TestContainer_StartWithoutInitialize(t *testing.T) {
|
||||
config := DefaultContainerConfig()
|
||||
container := NewContainer(config)
|
||||
ctx := context.Background()
|
||||
|
||||
if err := container.Start(ctx); err == nil {
|
||||
t.Error("Start should fail if Initialize was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainer_ShutdownWithoutInitialize(t *testing.T) {
|
||||
config := DefaultContainerConfig()
|
||||
container := NewContainer(config)
|
||||
ctx := context.Background()
|
||||
|
||||
// Shutdown on uninitialized container should not error
|
||||
if err := container.Shutdown(ctx); err != nil {
|
||||
t.Errorf("Shutdown on uninitialized container should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainer_ShutdownChannel(t *testing.T) {
|
||||
cleanup := setupContainerTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
config := DefaultContainerConfig()
|
||||
config.Database.Enabled = false
|
||||
config.SimulationMode = true
|
||||
|
||||
container := NewContainer(config)
|
||||
ctx := context.Background()
|
||||
|
||||
if err := container.Initialize(ctx); err != nil {
|
||||
t.Fatalf("Initialize failed: %v", err)
|
||||
}
|
||||
|
||||
shutdownCh := container.ShutdownCh()
|
||||
|
||||
// Channel should be open before shutdown
|
||||
select {
|
||||
case <-shutdownCh:
|
||||
t.Error("ShutdownCh should not be closed before Shutdown()")
|
||||
default:
|
||||
// Expected
|
||||
}
|
||||
|
||||
if err := container.Shutdown(ctx); err != nil {
|
||||
t.Errorf("Shutdown failed: %v", err)
|
||||
}
|
||||
|
||||
// Channel should be closed after shutdown
|
||||
select {
|
||||
case <-shutdownCh:
|
||||
// Expected
|
||||
case <-time.After(time.Second):
|
||||
t.Error("ShutdownCh should be closed after Shutdown()")
|
||||
}
|
||||
}
|
||||
|
|
@@ -614,7 +614,8 @@ func (m *Manager) collectSingleMinerStats(miner Miner, minerType string, now tim
 		Timestamp: point.Timestamp,
 		Hashrate:  point.Hashrate,
 	}
-	if err := database.InsertHashratePoint(minerName, minerType, dbPoint, database.ResolutionHigh); err != nil {
+	// Use nil context to let InsertHashratePoint use its default timeout
+	if err := database.InsertHashratePoint(nil, minerName, minerType, dbPoint, database.ResolutionHigh); err != nil {
 		logging.Warn("failed to persist hashrate", logging.Fields{"miner": minerName, "error": err})
 	}
 }
@@ -246,6 +246,7 @@ func (b *BaseMiner) InstallFromURL(url string) error {
 	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK {
+		io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse
 		return fmt.Errorf("failed to download release: unexpected status code %d", resp.StatusCode)
 	}
 
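The same drain-before-early-return change is applied to the xmrig and tt-miner HTTP calls further down. As background, a generic sketch of why the io.Copy matters: if the body is not read before Close, net/http cannot return the connection to the pool for reuse. The helper below is illustrative, not code from this repository.

// Illustrative helper showing the drain pattern in isolation.
func getJSON(ctx context.Context, client *http.Client, url string, out interface{}) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return err
	}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Drain the body so the underlying connection can be reused.
		io.Copy(io.Discard, resp.Body)
		return fmt.Errorf("unexpected status code %d", resp.StatusCode)
	}
	return json.NewDecoder(resp.Body).Decode(out)
}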
@@ -334,7 +334,8 @@ func (ns *NodeService) handlePeerStats(c *gin.Context) {
 
 // RemoteStartRequest is the request body for starting a remote miner.
 type RemoteStartRequest struct {
-	ProfileID string          `json:"profileId" binding:"required"`
+	MinerType string          `json:"minerType" binding:"required"`
+	ProfileID string          `json:"profileId,omitempty"`
 	Config    json.RawMessage `json:"config,omitempty"`
 }
 
@@ -356,7 +357,7 @@ func (ns *NodeService) handleRemoteStart(c *gin.Context) {
 		return
 	}
 
-	if err := ns.controller.StartRemoteMiner(peerID, req.ProfileID, req.Config); err != nil {
+	if err := ns.controller.StartRemoteMiner(peerID, req.MinerType, req.ProfileID, req.Config); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
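With minerType now required and profileId optional, a caller building the remote-start request body might do the following; the profile ID and config keys are placeholders, only RemoteStartRequest itself comes from the struct above.

// Illustrative request construction.
reqBody := RemoteStartRequest{
	MinerType: "xmrig",
	ProfileID: "low-power",                       // optional
	Config:    json.RawMessage(`{"threads": 4}`), // optional override, keys are illustrative
}
payload, _ := json.Marshal(reqBody)
// payload is then POSTed to the peer's remote-start endpoint.
_ = payload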
196 pkg/mining/repository.go Normal file
@ -0,0 +1,196 @@
|
|||
package mining
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Repository defines a generic interface for data persistence.
|
||||
// Implementations can store data in files, databases, etc.
|
||||
type Repository[T any] interface {
|
||||
// Load reads data from the repository
|
||||
Load() (T, error)
|
||||
|
||||
// Save writes data to the repository
|
||||
Save(data T) error
|
||||
|
||||
// Update atomically loads, modifies, and saves data
|
||||
Update(fn func(*T) error) error
|
||||
}
|
||||
|
||||
// FileRepository provides atomic file-based persistence for JSON data.
|
||||
// It uses atomic writes (temp file + rename) to prevent corruption.
|
||||
type FileRepository[T any] struct {
|
||||
mu sync.RWMutex
|
||||
path string
|
||||
defaults func() T
|
||||
}
|
||||
|
||||
// FileRepositoryOption configures a FileRepository.
|
||||
type FileRepositoryOption[T any] func(*FileRepository[T])
|
||||
|
||||
// WithDefaults sets the default value factory for when the file doesn't exist.
|
||||
func WithDefaults[T any](fn func() T) FileRepositoryOption[T] {
|
||||
return func(r *FileRepository[T]) {
|
||||
r.defaults = fn
|
||||
}
|
||||
}
|
||||
|
||||
// NewFileRepository creates a new file-based repository.
|
||||
func NewFileRepository[T any](path string, opts ...FileRepositoryOption[T]) *FileRepository[T] {
|
||||
r := &FileRepository[T]{
|
||||
path: path,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Load reads and deserializes data from the file.
|
||||
// Returns defaults if file doesn't exist.
|
||||
func (r *FileRepository[T]) Load() (T, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var result T
|
||||
|
||||
data, err := os.ReadFile(r.path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if r.defaults != nil {
|
||||
return r.defaults(), nil
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
return result, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return result, fmt.Errorf("failed to unmarshal data: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Save serializes and writes data to the file atomically.
|
||||
func (r *FileRepository[T]) Save(data T) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
return r.saveUnlocked(data)
|
||||
}
|
||||
|
||||
// saveUnlocked saves data without acquiring the lock (caller must hold lock).
|
||||
func (r *FileRepository[T]) saveUnlocked(data T) error {
|
||||
dir := filepath.Dir(r.path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
jsonData, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal data: %w", err)
|
||||
}
|
||||
|
||||
// Atomic write: write to temp file, sync, then rename
|
||||
tmpFile, err := os.CreateTemp(dir, "repo-*.tmp")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
// Clean up temp file on error
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := tmpFile.Write(jsonData); err != nil {
|
||||
tmpFile.Close()
|
||||
return fmt.Errorf("failed to write temp file: %w", err)
|
||||
}
|
||||
|
||||
if err := tmpFile.Sync(); err != nil {
|
||||
tmpFile.Close()
|
||||
return fmt.Errorf("failed to sync temp file: %w", err)
|
||||
}
|
||||
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close temp file: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tmpPath, 0600); err != nil {
|
||||
return fmt.Errorf("failed to set file permissions: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, r.path); err != nil {
|
||||
return fmt.Errorf("failed to rename temp file: %w", err)
|
||||
}
|
||||
|
||||
success = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update atomically loads, modifies, and saves data.
|
||||
// The modification function receives a pointer to the data.
|
||||
func (r *FileRepository[T]) Update(fn func(*T) error) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Load current data
|
||||
var data T
|
||||
fileData, err := os.ReadFile(r.path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if r.defaults != nil {
|
||||
data = r.defaults()
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal(fileData, &data); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply modification
|
||||
if err := fn(&data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Save atomically
|
||||
return r.saveUnlocked(data)
|
||||
}
|
||||
|
||||
// Path returns the file path of this repository.
|
||||
func (r *FileRepository[T]) Path() string {
|
||||
return r.path
|
||||
}
|
||||
|
||||
// Exists returns true if the repository file exists.
|
||||
func (r *FileRepository[T]) Exists() bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
_, err := os.Stat(r.path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Delete removes the repository file.
|
||||
func (r *FileRepository[T]) Delete() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
err := os.Remove(r.path)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
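A short sketch of the intended usage of the generic repository; the profileList type and the path are hypothetical, while NewFileRepository, WithDefaults, Load and Update are defined in repository.go above.

// Hypothetical payload type persisted through FileRepository.
type profileList struct {
	Profiles []string `json:"profiles"`
}

repo := NewFileRepository[profileList](
	filepath.Join(os.TempDir(), "profiles.json"), // illustrative path
	WithDefaults(func() profileList { return profileList{} }),
)

// Read-modify-write goes through Update, which reloads, applies the change,
// and saves via temp file + rename so a crash cannot leave a half-written file.
if err := repo.Update(func(p *profileList) error {
	p.Profiles = append(p.Profiles, "low-power")
	return nil
}); err != nil {
	log.Fatal(err)
}

current, err := repo.Load()
if err != nil {
	log.Fatal(err)
}
fmt.Printf("%d profiles\n", len(current.Profiles))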
300 pkg/mining/repository_test.go Normal file
@ -0,0 +1,300 @@
|
|||
package mining
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testData struct {
|
||||
Name string `json:"name"`
|
||||
Value int `json:"value"`
|
||||
}
|
||||
|
||||
func TestFileRepository_Load(t *testing.T) {
|
||||
t.Run("NonExistentFile", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "nonexistent.json")
|
||||
repo := NewFileRepository[testData](path)
|
||||
|
||||
data, err := repo.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load should not error for non-existent file: %v", err)
|
||||
}
|
||||
if data.Name != "" || data.Value != 0 {
|
||||
t.Error("Expected zero value for non-existent file")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NonExistentFileWithDefaults", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "nonexistent.json")
|
||||
repo := NewFileRepository[testData](path, WithDefaults(func() testData {
|
||||
return testData{Name: "default", Value: 42}
|
||||
}))
|
||||
|
||||
data, err := repo.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load should not error: %v", err)
|
||||
}
|
||||
if data.Name != "default" || data.Value != 42 {
|
||||
t.Errorf("Expected default values, got %+v", data)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ExistingFile", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "test.json")
|
||||
|
||||
// Write test data
|
||||
if err := os.WriteFile(path, []byte(`{"name":"test","value":123}`), 0600); err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
repo := NewFileRepository[testData](path)
|
||||
data, err := repo.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
if data.Name != "test" || data.Value != 123 {
|
||||
t.Errorf("Unexpected data: %+v", data)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InvalidJSON", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "invalid.json")
|
||||
|
||||
if err := os.WriteFile(path, []byte(`{invalid json}`), 0600); err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
repo := NewFileRepository[testData](path)
|
||||
_, err := repo.Load()
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid JSON")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFileRepository_Save(t *testing.T) {
|
||||
t.Run("NewFile", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "subdir", "new.json")
|
||||
repo := NewFileRepository[testData](path)
|
||||
|
||||
data := testData{Name: "saved", Value: 456}
|
||||
if err := repo.Save(data); err != nil {
|
||||
t.Fatalf("Save failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify file was created
|
||||
if !repo.Exists() {
|
||||
t.Error("File should exist after save")
|
||||
}
|
||||
|
||||
// Verify content
|
||||
loaded, err := repo.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load after save failed: %v", err)
|
||||
}
|
||||
if loaded.Name != "saved" || loaded.Value != 456 {
|
||||
t.Errorf("Unexpected loaded data: %+v", loaded)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OverwriteExisting", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "existing.json")
|
||||
repo := NewFileRepository[testData](path)
|
||||
|
||||
// Save initial data
|
||||
if err := repo.Save(testData{Name: "first", Value: 1}); err != nil {
|
||||
t.Fatalf("First save failed: %v", err)
|
||||
}
|
||||
|
||||
// Overwrite
|
||||
if err := repo.Save(testData{Name: "second", Value: 2}); err != nil {
|
||||
t.Fatalf("Second save failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify overwrite
|
||||
loaded, err := repo.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
if loaded.Name != "second" || loaded.Value != 2 {
|
||||
t.Errorf("Expected overwritten data, got: %+v", loaded)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFileRepository_Update(t *testing.T) {
|
||||
t.Run("UpdateExisting", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "update.json")
|
||||
repo := NewFileRepository[testData](path)
|
||||
|
||||
// Save initial data
|
||||
if err := repo.Save(testData{Name: "initial", Value: 10}); err != nil {
|
||||
t.Fatalf("Initial save failed: %v", err)
|
||||
}
|
||||
|
||||
// Update
|
||||
err := repo.Update(func(data *testData) error {
|
||||
data.Value += 5
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Update failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify update
|
||||
loaded, err := repo.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
if loaded.Value != 15 {
|
||||
t.Errorf("Expected value 15, got %d", loaded.Value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpdateNonExistentWithDefaults", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "new.json")
|
||||
repo := NewFileRepository[testData](path, WithDefaults(func() testData {
|
||||
return testData{Name: "default", Value: 100}
|
||||
}))
|
||||
|
||||
err := repo.Update(func(data *testData) error {
|
||||
data.Value *= 2
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Update failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify update started from defaults
|
||||
loaded, err := repo.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
if loaded.Value != 200 {
|
||||
t.Errorf("Expected value 200, got %d", loaded.Value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpdateWithError", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "error.json")
|
||||
repo := NewFileRepository[testData](path)
|
||||
|
||||
if err := repo.Save(testData{Name: "test", Value: 1}); err != nil {
|
||||
t.Fatalf("Initial save failed: %v", err)
|
||||
}
|
||||
|
||||
// Update that returns error
|
||||
testErr := errors.New("update error")
|
||||
err := repo.Update(func(data *testData) error {
|
||||
data.Value = 999 // This change should not be saved
|
||||
return testErr
|
||||
})
|
||||
if err != testErr {
|
||||
t.Errorf("Expected test error, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify original data unchanged
|
||||
loaded, err := repo.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
if loaded.Value != 1 {
|
||||
t.Errorf("Expected value 1 (unchanged), got %d", loaded.Value)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFileRepository_Delete(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "delete.json")
|
||||
repo := NewFileRepository[testData](path)
|
||||
|
||||
// Save data
|
||||
if err := repo.Save(testData{Name: "temp", Value: 1}); err != nil {
|
||||
t.Fatalf("Save failed: %v", err)
|
||||
}
|
||||
|
||||
if !repo.Exists() {
|
||||
t.Error("File should exist after save")
|
||||
}
|
||||
|
||||
// Delete
|
||||
if err := repo.Delete(); err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
if repo.Exists() {
|
||||
t.Error("File should not exist after delete")
|
||||
}
|
||||
|
||||
// Delete non-existent should not error
|
||||
if err := repo.Delete(); err != nil {
|
||||
t.Errorf("Delete non-existent should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileRepository_Path(t *testing.T) {
|
||||
path := "/some/path/config.json"
|
||||
repo := NewFileRepository[testData](path)
|
||||
|
||||
if repo.Path() != path {
|
||||
t.Errorf("Expected path %s, got %s", path, repo.Path())
|
||||
}
|
||||
}
|
||||
|
||||
// Test with slice data
|
||||
func TestFileRepository_SliceData(t *testing.T) {
|
||||
type item struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "items.json")
|
||||
repo := NewFileRepository[[]item](path, WithDefaults(func() []item {
|
||||
return []item{}
|
||||
}))
|
||||
|
||||
// Save slice
|
||||
items := []item{
|
||||
{ID: "1", Name: "First"},
|
||||
{ID: "2", Name: "Second"},
|
||||
}
|
||||
if err := repo.Save(items); err != nil {
|
||||
t.Fatalf("Save failed: %v", err)
|
||||
}
|
||||
|
||||
// Load and verify
|
||||
loaded, err := repo.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
if len(loaded) != 2 {
|
||||
t.Errorf("Expected 2 items, got %d", len(loaded))
|
||||
}
|
||||
|
||||
// Update slice
|
||||
err = repo.Update(func(data *[]item) error {
|
||||
*data = append(*data, item{ID: "3", Name: "Third"})
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Update failed: %v", err)
|
||||
}
|
||||
|
||||
loaded, _ = repo.Load()
|
||||
if len(loaded) != 3 {
|
||||
t.Errorf("Expected 3 items after update, got %d", len(loaded))
|
||||
}
|
||||
}
|
||||
|
|
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"net/http"
 	"os"
 	"os/exec"
@@ -92,6 +93,7 @@ func (m *TTMiner) GetLatestVersion() (string, error) {
 	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK {
+		io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse
 		return "", fmt.Errorf("failed to get latest release: unexpected status code %d", resp.StatusCode)
 	}
 
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"net/http"
 )
 
@@ -41,6 +42,7 @@ func (m *TTMiner) GetStats(ctx context.Context) (*PerformanceMetrics, error) {
 	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK {
+		io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse
 		return nil, fmt.Errorf("failed to get stats: unexpected status code %d", resp.StatusCode)
 	}
 
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"net/http"
 	"os"
 	"os/exec"
@@ -93,6 +94,7 @@ func (m *XMRigMiner) GetLatestVersion() (string, error) {
 	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK {
+		io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse
 		return "", fmt.Errorf("failed to get latest release: unexpected status code %d", resp.StatusCode)
 	}
 
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"net/http"
 	"time"
 )
@@ -45,6 +46,7 @@ func (m *XMRigMiner) GetStats(ctx context.Context) (*PerformanceMetrics, error)
 	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK {
+		io.Copy(io.Discard, resp.Body) // Drain body to allow connection reuse
 		return nil, fmt.Errorf("failed to get stats: unexpected status code %d", resp.StatusCode)
 	}
 
@ -125,34 +125,27 @@ func (c *Controller) GetRemoteStats(peerID string) (*StatsPayload, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Type == MsgError {
|
||||
var errPayload ErrorPayload
|
||||
if err := resp.ParsePayload(&errPayload); err != nil {
|
||||
return nil, fmt.Errorf("remote error (unable to parse)")
|
||||
}
|
||||
return nil, fmt.Errorf("remote error: %s", errPayload.Message)
|
||||
}
|
||||
|
||||
if resp.Type != MsgStats {
|
||||
return nil, fmt.Errorf("unexpected response type: %s", resp.Type)
|
||||
}
|
||||
|
||||
var stats StatsPayload
|
||||
if err := resp.ParsePayload(&stats); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse stats: %w", err)
|
||||
if err := ParseResponse(resp, MsgStats, &stats); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// StartRemoteMiner requests a remote peer to start a miner with a given profile.
|
||||
func (c *Controller) StartRemoteMiner(peerID, profileID string, configOverride json.RawMessage) error {
|
||||
func (c *Controller) StartRemoteMiner(peerID, minerType, profileID string, configOverride json.RawMessage) error {
|
||||
identity := c.node.GetIdentity()
|
||||
if identity == nil {
|
||||
return fmt.Errorf("node identity not initialized")
|
||||
}
|
||||
|
||||
if minerType == "" {
|
||||
return fmt.Errorf("miner type is required")
|
||||
}
|
||||
|
||||
payload := StartMinerPayload{
|
||||
MinerType: minerType,
|
||||
ProfileID: profileID,
|
||||
Config: configOverride,
|
||||
}
|
||||
|
|
@ -167,21 +160,9 @@ func (c *Controller) StartRemoteMiner(peerID, profileID string, configOverride j
|
|||
return err
|
||||
}
|
||||
|
||||
if resp.Type == MsgError {
|
||||
var errPayload ErrorPayload
|
||||
if err := resp.ParsePayload(&errPayload); err != nil {
|
||||
return fmt.Errorf("remote error (unable to parse)")
|
||||
}
|
||||
return fmt.Errorf("remote error: %s", errPayload.Message)
|
||||
}
|
||||
|
||||
if resp.Type != MsgMinerAck {
|
||||
return fmt.Errorf("unexpected response type: %s", resp.Type)
|
||||
}
|
||||
|
||||
var ack MinerAckPayload
|
||||
if err := resp.ParsePayload(&ack); err != nil {
|
||||
return fmt.Errorf("failed to parse ack: %w", err)
|
||||
if err := ParseResponse(resp, MsgMinerAck, &ack); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !ack.Success {
|
||||
|
|
@ -212,21 +193,9 @@ func (c *Controller) StopRemoteMiner(peerID, minerName string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if resp.Type == MsgError {
|
||||
var errPayload ErrorPayload
|
||||
if err := resp.ParsePayload(&errPayload); err != nil {
|
||||
return fmt.Errorf("remote error (unable to parse)")
|
||||
}
|
||||
return fmt.Errorf("remote error: %s", errPayload.Message)
|
||||
}
|
||||
|
||||
if resp.Type != MsgMinerAck {
|
||||
return fmt.Errorf("unexpected response type: %s", resp.Type)
|
||||
}
|
||||
|
||||
var ack MinerAckPayload
|
||||
if err := resp.ParsePayload(&ack); err != nil {
|
||||
return fmt.Errorf("failed to parse ack: %w", err)
|
||||
if err := ParseResponse(resp, MsgMinerAck, &ack); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !ack.Success {
|
||||
|
|
@ -258,21 +227,9 @@ func (c *Controller) GetRemoteLogs(peerID, minerName string, lines int) ([]strin
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Type == MsgError {
|
||||
var errPayload ErrorPayload
|
||||
if err := resp.ParsePayload(&errPayload); err != nil {
|
||||
return nil, fmt.Errorf("remote error (unable to parse)")
|
||||
}
|
||||
return nil, fmt.Errorf("remote error: %s", errPayload.Message)
|
||||
}
|
||||
|
||||
if resp.Type != MsgLogs {
|
||||
return nil, fmt.Errorf("unexpected response type: %s", resp.Type)
|
||||
}
|
||||
|
||||
var logs LogsPayload
|
||||
if err := resp.ParsePayload(&logs); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse logs: %w", err)
|
||||
if err := ParseResponse(resp, MsgLogs, &logs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logs.Lines, nil
|
||||
|
|
@ -325,8 +282,8 @@ func (c *Controller) PingPeer(peerID string) (float64, error) {
|
|||
return 0, err
|
||||
}
|
||||
|
||||
if resp.Type != MsgPong {
|
||||
return 0, fmt.Errorf("unexpected response type: %s", resp.Type)
|
||||
if err := ValidateResponse(resp, MsgPong); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Calculate round-trip time
|
||||
|
|
|
|||
|
|
@@ -117,7 +117,8 @@ type PongPayload struct {
 
 // StartMinerPayload requests starting a miner.
 type StartMinerPayload struct {
-	ProfileID string          `json:"profileId"`
+	MinerType string          `json:"minerType"` // Required: miner type (e.g., "xmrig", "tt-miner")
+	ProfileID string          `json:"profileId,omitempty"`
 	Config    json.RawMessage `json:"config,omitempty"` // Override profile config
 }
 
|
|||
|
|
@ -92,6 +92,7 @@ func TestMessageReply(t *testing.T) {
func TestParsePayload(t *testing.T) {
	t.Run("ValidPayload", func(t *testing.T) {
		payload := StartMinerPayload{
			MinerType: "xmrig",
			ProfileID: "test-profile",
		}
@ -190,6 +191,7 @@ func TestNewErrorMessage(t *testing.T) {

func TestMessageSerialization(t *testing.T) {
	original, _ := NewMessage(MsgStartMiner, "ctrl", "worker", StartMinerPayload{
		MinerType: "xmrig",
		ProfileID: "my-profile",
	})

88  pkg/node/protocol.go  Normal file
@ -0,0 +1,88 @@
package node

import (
	"fmt"
)

// ProtocolError represents an error from the remote peer.
type ProtocolError struct {
	Code int
	Message string
}

func (e *ProtocolError) Error() string {
	return fmt.Sprintf("remote error (%d): %s", e.Code, e.Message)
}

// ResponseHandler provides helpers for handling protocol responses.
type ResponseHandler struct{}

// ValidateResponse checks if the response is valid and returns a parsed error if it's an error response.
// It checks:
// 1. If response is nil (returns error)
// 2. If response is an error message (returns ProtocolError)
// 3. If response type matches expected (returns error if not)
func (h *ResponseHandler) ValidateResponse(resp *Message, expectedType MessageType) error {
	if resp == nil {
		return fmt.Errorf("nil response")
	}

	// Check for error response
	if resp.Type == MsgError {
		var errPayload ErrorPayload
		if err := resp.ParsePayload(&errPayload); err != nil {
			return &ProtocolError{Code: ErrCodeUnknown, Message: "unable to parse error response"}
		}
		return &ProtocolError{Code: errPayload.Code, Message: errPayload.Message}
	}

	// Check expected type
	if resp.Type != expectedType {
		return fmt.Errorf("unexpected response type: expected %s, got %s", expectedType, resp.Type)
	}

	return nil
}

// ParseResponse validates the response and parses the payload into the target.
// This combines ValidateResponse and ParsePayload into a single call.
func (h *ResponseHandler) ParseResponse(resp *Message, expectedType MessageType, target interface{}) error {
	if err := h.ValidateResponse(resp, expectedType); err != nil {
		return err
	}

	if target != nil {
		if err := resp.ParsePayload(target); err != nil {
			return fmt.Errorf("failed to parse %s payload: %w", expectedType, err)
		}
	}

	return nil
}

// DefaultResponseHandler is the default response handler instance.
var DefaultResponseHandler = &ResponseHandler{}

// ValidateResponse is a convenience function using the default handler.
func ValidateResponse(resp *Message, expectedType MessageType) error {
	return DefaultResponseHandler.ValidateResponse(resp, expectedType)
}

// ParseResponse is a convenience function using the default handler.
func ParseResponse(resp *Message, expectedType MessageType, target interface{}) error {
	return DefaultResponseHandler.ParseResponse(resp, expectedType, target)
}

// IsProtocolError returns true if the error is a ProtocolError.
func IsProtocolError(err error) bool {
	_, ok := err.(*ProtocolError)
	return ok
}

// GetProtocolErrorCode returns the error code if err is a ProtocolError, otherwise returns 0.
func GetProtocolErrorCode(err error) int {
	if pe, ok := err.(*ProtocolError); ok {
		return pe.Code
	}
	return 0
}
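For callers outside this file, the intended pattern is to funnel every reply through ParseResponse and, when needed, recover the structured remote error code via the ProtocolError helpers. A short sketch; requestMinerAck is a hypothetical caller, resp is assumed to come from an existing request/response exchange, and fmt is assumed imported:

// Sketch: requestMinerAck is illustrative; only the helpers above are part of this commit.
func requestMinerAck(resp *Message) (*MinerAckPayload, error) {
	var ack MinerAckPayload
	if err := ParseResponse(resp, MsgMinerAck, &ack); err != nil {
		if IsProtocolError(err) {
			// The peer answered with MsgError; the structured code survives the round trip.
			return nil, fmt.Errorf("peer rejected request (code %d): %w", GetProtocolErrorCode(err), err)
		}
		return nil, err // nil response, wrong type, or payload parse failure
	}
	return &ack, nil
}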
161  pkg/node/protocol_test.go  Normal file
@ -0,0 +1,161 @@
package node

import (
	"fmt"
	"testing"
)

func TestResponseHandler_ValidateResponse(t *testing.T) {
	handler := &ResponseHandler{}

	t.Run("NilResponse", func(t *testing.T) {
		err := handler.ValidateResponse(nil, MsgStats)
		if err == nil {
			t.Error("Expected error for nil response")
		}
	})

	t.Run("ErrorResponse", func(t *testing.T) {
		errMsg, _ := NewErrorMessage("sender", "receiver", ErrCodeOperationFailed, "operation failed", "")
		err := handler.ValidateResponse(errMsg, MsgStats)
		if err == nil {
			t.Fatal("Expected error for error response")
		}

		if !IsProtocolError(err) {
			t.Errorf("Expected ProtocolError, got %T", err)
		}

		if GetProtocolErrorCode(err) != ErrCodeOperationFailed {
			t.Errorf("Expected code %d, got %d", ErrCodeOperationFailed, GetProtocolErrorCode(err))
		}
	})

	t.Run("WrongType", func(t *testing.T) {
		msg, _ := NewMessage(MsgPong, "sender", "receiver", nil)
		err := handler.ValidateResponse(msg, MsgStats)
		if err == nil {
			t.Error("Expected error for wrong type")
		}
		if IsProtocolError(err) {
			t.Error("Should not be a ProtocolError for type mismatch")
		}
	})

	t.Run("ValidResponse", func(t *testing.T) {
		msg, _ := NewMessage(MsgStats, "sender", "receiver", StatsPayload{NodeID: "test"})
		err := handler.ValidateResponse(msg, MsgStats)
		if err != nil {
			t.Errorf("Unexpected error: %v", err)
		}
	})
}

func TestResponseHandler_ParseResponse(t *testing.T) {
	handler := &ResponseHandler{}

	t.Run("ParseStats", func(t *testing.T) {
		payload := StatsPayload{
			NodeID: "node-123",
			NodeName: "Test Node",
			Uptime: 3600,
		}
		msg, _ := NewMessage(MsgStats, "sender", "receiver", payload)

		var parsed StatsPayload
		err := handler.ParseResponse(msg, MsgStats, &parsed)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		if parsed.NodeID != "node-123" {
			t.Errorf("Expected NodeID 'node-123', got '%s'", parsed.NodeID)
		}
		if parsed.Uptime != 3600 {
			t.Errorf("Expected Uptime 3600, got %d", parsed.Uptime)
		}
	})

	t.Run("ParseMinerAck", func(t *testing.T) {
		payload := MinerAckPayload{
			Success: true,
			MinerName: "xmrig-1",
		}
		msg, _ := NewMessage(MsgMinerAck, "sender", "receiver", payload)

		var parsed MinerAckPayload
		err := handler.ParseResponse(msg, MsgMinerAck, &parsed)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		if !parsed.Success {
			t.Error("Expected Success to be true")
		}
		if parsed.MinerName != "xmrig-1" {
			t.Errorf("Expected MinerName 'xmrig-1', got '%s'", parsed.MinerName)
		}
	})

	t.Run("ErrorResponse", func(t *testing.T) {
		errMsg, _ := NewErrorMessage("sender", "receiver", ErrCodeNotFound, "not found", "")

		var parsed StatsPayload
		err := handler.ParseResponse(errMsg, MsgStats, &parsed)
		if err == nil {
			t.Error("Expected error for error response")
		}
		if !IsProtocolError(err) {
			t.Errorf("Expected ProtocolError, got %T", err)
		}
	})

	t.Run("NilTarget", func(t *testing.T) {
		msg, _ := NewMessage(MsgPong, "sender", "receiver", nil)
		err := handler.ParseResponse(msg, MsgPong, nil)
		if err != nil {
			t.Errorf("Unexpected error with nil target: %v", err)
		}
	})
}

func TestProtocolError(t *testing.T) {
	err := &ProtocolError{Code: 1001, Message: "test error"}

	if err.Error() != "remote error (1001): test error" {
		t.Errorf("Unexpected error message: %s", err.Error())
	}

	if !IsProtocolError(err) {
		t.Error("IsProtocolError should return true")
	}

	if GetProtocolErrorCode(err) != 1001 {
		t.Errorf("Expected code 1001, got %d", GetProtocolErrorCode(err))
	}
}

func TestConvenienceFunctions(t *testing.T) {
	msg, _ := NewMessage(MsgStats, "sender", "receiver", StatsPayload{NodeID: "test"})

	// Test ValidateResponse
	if err := ValidateResponse(msg, MsgStats); err != nil {
		t.Errorf("ValidateResponse failed: %v", err)
	}

	// Test ParseResponse
	var parsed StatsPayload
	if err := ParseResponse(msg, MsgStats, &parsed); err != nil {
		t.Errorf("ParseResponse failed: %v", err)
	}
	if parsed.NodeID != "test" {
		t.Errorf("Expected NodeID 'test', got '%s'", parsed.NodeID)
	}
}

func TestGetProtocolErrorCode_NonProtocolError(t *testing.T) {
	err := fmt.Errorf("regular error")
	if GetProtocolErrorCode(err) != 0 {
		t.Error("Expected 0 for non-ProtocolError")
	}
}
@ -198,6 +198,11 @@ func (w *Worker) handleStartMiner(msg *Message) (*Message, error) {
		return nil, fmt.Errorf("invalid start miner payload: %w", err)
	}

	// Validate miner type is provided
	if payload.MinerType == "" {
		return nil, fmt.Errorf("miner type is required")
	}

	// Get the config from the profile or use the override
	var config interface{}
	if payload.Config != nil {

@ -213,7 +218,7 @@ func (w *Worker) handleStartMiner(msg *Message) (*Message, error) {
	}

	// Start the miner
	miner, err := w.minerManager.StartMiner("", config)
	miner, err := w.minerManager.StartMiner(payload.MinerType, config)
	if err != nil {
		ack := MinerAckPayload{
			Success: false,
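With the validation above, a start request that omits MinerType is rejected before the miner manager is ever consulted, instead of silently starting an empty-type miner. A rough, test-style sketch of that failure path; the w variable, its direct handleStartMiner call, and the surrounding *testing.T are assumptions modeled on the worker tests below:

// Sketch: exercises the new guard in handleStartMiner; w is an assumed *Worker.
msg, _ := NewMessage(MsgStartMiner, "controller-id", "worker-id", StartMinerPayload{
	ProfileID: "cpu-default", // MinerType deliberately omitted
})
if _, err := w.handleStartMiner(msg); err == nil {
	t.Fatal("expected 'miner type is required' error")
}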
@ -247,7 +247,7 @@ func TestWorker_HandleStartMiner_NoManager(t *testing.T) {
	if identity == nil {
		t.Fatal("expected identity to be generated")
	}
	payload := StartMinerPayload{ProfileID: "test-profile"}
	payload := StartMinerPayload{MinerType: "xmrig", ProfileID: "test-profile"}
	msg, err := NewMessage(MsgStartMiner, "sender-id", identity.ID, payload)
	if err != nil {
		t.Fatalf("failed to create start_miner message: %v", err)