commit 8f94639ec9bd4dce256da0667993d6a723f614f1 Author: Claude Date: Mon Feb 16 15:47:10 2026 +0000 feat: extract P2P networking and UEPS protocol from Mining repo P2P node layer (peer discovery, WebSocket transport, message protocol, worker pool, identity management) and Unified Ethical Protocol Stack (TLV packet builder with HMAC-signed frames). Ported from github.com/Snider/Mining/pkg/{node,ueps,logging} Co-Authored-By: Claude Opus 4.6 diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..0f65c03 --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module forge.lthn.ai/core/go-p2p + +go 1.25.5 + +require ( + github.com/Snider/Borg v0.2.0 + github.com/Snider/Poindexter v0.0.0-20260104200422-91146b212a1f + github.com/adrg/xdg v0.5.3 + github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.3 +) + +require ( + github.com/ProtonMail/go-crypto v1.3.0 // indirect + github.com/Snider/Enchantrix v0.0.2 // indirect + github.com/cloudflare/circl v1.6.1 // indirect + github.com/klauspost/compress v1.18.2 // indirect + golang.org/x/crypto v0.45.0 // indirect + golang.org/x/sys v0.38.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..3c92b87 --- /dev/null +++ b/go.sum @@ -0,0 +1,30 @@ +github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw= +github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE= +github.com/Snider/Borg v0.2.0 h1:iCyDhY4WTXi39+FexRwXbn2YpZ2U9FUXVXDZk9xRCXQ= +github.com/Snider/Borg v0.2.0/go.mod h1:TqlKnfRo9okioHbgrZPfWjQsztBV0Nfskz4Om1/vdMY= +github.com/Snider/Enchantrix v0.0.2 h1:ExZQiBhfS/p/AHFTKhY80TOd+BXZjK95EzByAEgwvjs= +github.com/Snider/Enchantrix v0.0.2/go.mod h1:CtFcLAvnDT1KcuF1JBb/DJj0KplY8jHryO06KzQ1hsQ= +github.com/Snider/Poindexter v0.0.0-20260104200422-91146b212a1f h1:+EnE414H9wUaBeUVNjyErusrxSbBGnGV6MBhTw/em0k= +github.com/Snider/Poindexter v0.0.0-20260104200422-91146b212a1f/go.mod h1:nhgkbg4zWA4AS2Ga3RmcvdsyiI9TdxvSqe5EVBSb3Hk= 
+github.com/adrg/xdg v0.5.3 h1:xRnxJXne7+oWDatRhR1JLnvuccuIeCoBu2rtuLqQB78= +github.com/adrg/xdg v0.5.3/go.mod h1:nlTsY+NNiCBGCK2tpm09vRqfVzrc2fLmXGpBLF0zlTQ= +github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= +github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/logging/logger.go b/logging/logger.go new file mode 100644 index 0000000..f400dc9 --- /dev/null +++ b/logging/logger.go @@ -0,0 
+1,284 @@ +// Package logging provides structured logging with log levels and fields. +package logging + +import ( + "fmt" + "io" + "os" + "strings" + "sync" + "time" +) + +// Level represents the severity of a log message. +type Level int + +const ( + // LevelDebug is the most verbose log level. + LevelDebug Level = iota + // LevelInfo is for general informational messages. + LevelInfo + // LevelWarn is for warning messages. + LevelWarn + // LevelError is for error messages. + LevelError +) + +// String returns the string representation of the log level. +func (l Level) String() string { + switch l { + case LevelDebug: + return "DEBUG" + case LevelInfo: + return "INFO" + case LevelWarn: + return "WARN" + case LevelError: + return "ERROR" + default: + return "UNKNOWN" + } +} + +// Logger provides structured logging with configurable output and level. +type Logger struct { + mu sync.Mutex + output io.Writer + level Level + component string +} + +// Config holds configuration for creating a new Logger. +type Config struct { + Output io.Writer + Level Level + Component string +} + +// DefaultConfig returns the default logger configuration. +func DefaultConfig() Config { + return Config{ + Output: os.Stderr, + Level: LevelInfo, + Component: "", + } +} + +// New creates a new Logger with the given configuration. +func New(cfg Config) *Logger { + if cfg.Output == nil { + cfg.Output = os.Stderr + } + return &Logger{ + output: cfg.Output, + level: cfg.Level, + component: cfg.Component, + } +} + +// WithComponent returns a new Logger with the specified component name. +func (l *Logger) WithComponent(component string) *Logger { + return &Logger{ + output: l.output, + level: l.level, + component: component, + } +} + +// SetLevel sets the minimum log level. +func (l *Logger) SetLevel(level Level) { + l.mu.Lock() + defer l.mu.Unlock() + l.level = level +} + +// GetLevel returns the current log level. 
+func (l *Logger) GetLevel() Level { + l.mu.Lock() + defer l.mu.Unlock() + return l.level +} + +// Fields represents key-value pairs for structured logging. +type Fields map[string]interface{} + +// log writes a log message at the specified level. +func (l *Logger) log(level Level, msg string, fields Fields) { + l.mu.Lock() + defer l.mu.Unlock() + + if level < l.level { + return + } + + // Build the log line + var sb strings.Builder + timestamp := time.Now().Format("2006/01/02 15:04:05") + sb.WriteString(timestamp) + sb.WriteString(" [") + sb.WriteString(level.String()) + sb.WriteString("]") + + if l.component != "" { + sb.WriteString(" [") + sb.WriteString(l.component) + sb.WriteString("]") + } + + sb.WriteString(" ") + sb.WriteString(msg) + + // Add fields if present + if len(fields) > 0 { + sb.WriteString(" |") + for k, v := range fields { + sb.WriteString(" ") + sb.WriteString(k) + sb.WriteString("=") + sb.WriteString(fmt.Sprintf("%v", v)) + } + } + + sb.WriteString("\n") + fmt.Fprint(l.output, sb.String()) +} + +// Debug logs a debug message. +func (l *Logger) Debug(msg string, fields ...Fields) { + l.log(LevelDebug, msg, mergeFields(fields)) +} + +// Info logs an informational message. +func (l *Logger) Info(msg string, fields ...Fields) { + l.log(LevelInfo, msg, mergeFields(fields)) +} + +// Warn logs a warning message. +func (l *Logger) Warn(msg string, fields ...Fields) { + l.log(LevelWarn, msg, mergeFields(fields)) +} + +// Error logs an error message. +func (l *Logger) Error(msg string, fields ...Fields) { + l.log(LevelError, msg, mergeFields(fields)) +} + +// Debugf logs a formatted debug message. +func (l *Logger) Debugf(format string, args ...interface{}) { + l.log(LevelDebug, fmt.Sprintf(format, args...), nil) +} + +// Infof logs a formatted informational message. +func (l *Logger) Infof(format string, args ...interface{}) { + l.log(LevelInfo, fmt.Sprintf(format, args...), nil) +} + +// Warnf logs a formatted warning message. 
+func (l *Logger) Warnf(format string, args ...interface{}) { + l.log(LevelWarn, fmt.Sprintf(format, args...), nil) +} + +// Errorf logs a formatted error message. +func (l *Logger) Errorf(format string, args ...interface{}) { + l.log(LevelError, fmt.Sprintf(format, args...), nil) +} + +// mergeFields combines multiple Fields maps into one. +func mergeFields(fields []Fields) Fields { + if len(fields) == 0 { + return nil + } + result := make(Fields) + for _, f := range fields { + for k, v := range f { + result[k] = v + } + } + return result +} + +// --- Global logger for convenience --- + +var ( + globalLogger = New(DefaultConfig()) + globalMu sync.RWMutex +) + +// SetGlobal sets the global logger instance. +func SetGlobal(l *Logger) { + globalMu.Lock() + defer globalMu.Unlock() + globalLogger = l +} + +// GetGlobal returns the global logger instance. +func GetGlobal() *Logger { + globalMu.RLock() + defer globalMu.RUnlock() + return globalLogger +} + +// SetGlobalLevel sets the log level of the global logger. +func SetGlobalLevel(level Level) { + globalMu.RLock() + defer globalMu.RUnlock() + globalLogger.SetLevel(level) +} + +// Global convenience functions that use the global logger + +// Debug logs a debug message using the global logger. +func Debug(msg string, fields ...Fields) { + GetGlobal().Debug(msg, fields...) +} + +// Info logs an informational message using the global logger. +func Info(msg string, fields ...Fields) { + GetGlobal().Info(msg, fields...) +} + +// Warn logs a warning message using the global logger. +func Warn(msg string, fields ...Fields) { + GetGlobal().Warn(msg, fields...) +} + +// Error logs an error message using the global logger. +func Error(msg string, fields ...Fields) { + GetGlobal().Error(msg, fields...) +} + +// Debugf logs a formatted debug message using the global logger. +func Debugf(format string, args ...interface{}) { + GetGlobal().Debugf(format, args...) 
+} + +// Infof logs a formatted informational message using the global logger. +func Infof(format string, args ...interface{}) { + GetGlobal().Infof(format, args...) +} + +// Warnf logs a formatted warning message using the global logger. +func Warnf(format string, args ...interface{}) { + GetGlobal().Warnf(format, args...) +} + +// Errorf logs a formatted error message using the global logger. +func Errorf(format string, args ...interface{}) { + GetGlobal().Errorf(format, args...) +} + +// ParseLevel parses a string into a log level. +func ParseLevel(s string) (Level, error) { + switch strings.ToUpper(s) { + case "DEBUG": + return LevelDebug, nil + case "INFO": + return LevelInfo, nil + case "WARN", "WARNING": + return LevelWarn, nil + case "ERROR": + return LevelError, nil + default: + return LevelInfo, fmt.Errorf("unknown log level: %s", s) + } +} diff --git a/logging/logger_test.go b/logging/logger_test.go new file mode 100644 index 0000000..5fa5163 --- /dev/null +++ b/logging/logger_test.go @@ -0,0 +1,262 @@ +package logging + +import ( + "bytes" + "strings" + "testing" +) + +func TestLoggerLevels(t *testing.T) { + var buf bytes.Buffer + logger := New(Config{ + Output: &buf, + Level: LevelInfo, + }) + + // Debug should not appear at Info level + logger.Debug("debug message") + if buf.Len() > 0 { + t.Error("Debug message should not appear at Info level") + } + + // Info should appear + logger.Info("info message") + if !strings.Contains(buf.String(), "[INFO]") { + t.Error("Info message should appear") + } + if !strings.Contains(buf.String(), "info message") { + t.Error("Info message content should appear") + } + buf.Reset() + + // Warn should appear + logger.Warn("warn message") + if !strings.Contains(buf.String(), "[WARN]") { + t.Error("Warn message should appear") + } + buf.Reset() + + // Error should appear + logger.Error("error message") + if !strings.Contains(buf.String(), "[ERROR]") { + t.Error("Error message should appear") + } +} + +func 
TestLoggerDebugLevel(t *testing.T) { + var buf bytes.Buffer + logger := New(Config{ + Output: &buf, + Level: LevelDebug, + }) + + logger.Debug("debug message") + if !strings.Contains(buf.String(), "[DEBUG]") { + t.Error("Debug message should appear at Debug level") + } +} + +func TestLoggerWithFields(t *testing.T) { + var buf bytes.Buffer + logger := New(Config{ + Output: &buf, + Level: LevelInfo, + }) + + logger.Info("test message", Fields{"key": "value", "num": 42}) + output := buf.String() + + if !strings.Contains(output, "key=value") { + t.Error("Field key=value should appear") + } + if !strings.Contains(output, "num=42") { + t.Error("Field num=42 should appear") + } +} + +func TestLoggerWithComponent(t *testing.T) { + var buf bytes.Buffer + logger := New(Config{ + Output: &buf, + Level: LevelInfo, + Component: "TestComponent", + }) + + logger.Info("test message") + output := buf.String() + + if !strings.Contains(output, "[TestComponent]") { + t.Error("Component name should appear in log") + } +} + +func TestLoggerDerivedComponent(t *testing.T) { + var buf bytes.Buffer + parent := New(Config{ + Output: &buf, + Level: LevelInfo, + }) + + child := parent.WithComponent("ChildComponent") + child.Info("child message") + output := buf.String() + + if !strings.Contains(output, "[ChildComponent]") { + t.Error("Derived component name should appear") + } +} + +func TestLoggerFormatted(t *testing.T) { + var buf bytes.Buffer + logger := New(Config{ + Output: &buf, + Level: LevelInfo, + }) + + logger.Infof("formatted %s %d", "string", 123) + output := buf.String() + + if !strings.Contains(output, "formatted string 123") { + t.Errorf("Formatted message should appear, got: %s", output) + } +} + +func TestSetLevel(t *testing.T) { + var buf bytes.Buffer + logger := New(Config{ + Output: &buf, + Level: LevelError, + }) + + // Info should not appear at Error level + logger.Info("should not appear") + if buf.Len() > 0 { + t.Error("Info should not appear at Error level") + } + + // 
Change to Info level + logger.SetLevel(LevelInfo) + logger.Info("should appear now") + if !strings.Contains(buf.String(), "should appear now") { + t.Error("Info should appear after level change") + } + + // Verify GetLevel + if logger.GetLevel() != LevelInfo { + t.Error("GetLevel should return LevelInfo") + } +} + +func TestParseLevel(t *testing.T) { + tests := []struct { + input string + expected Level + wantErr bool + }{ + {"DEBUG", LevelDebug, false}, + {"debug", LevelDebug, false}, + {"INFO", LevelInfo, false}, + {"info", LevelInfo, false}, + {"WARN", LevelWarn, false}, + {"WARNING", LevelWarn, false}, + {"ERROR", LevelError, false}, + {"error", LevelError, false}, + {"invalid", LevelInfo, true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + level, err := ParseLevel(tt.input) + if tt.wantErr && err == nil { + t.Error("Expected error but got none") + } + if !tt.wantErr && err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !tt.wantErr && level != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, level) + } + }) + } +} + +func TestGlobalLogger(t *testing.T) { + var buf bytes.Buffer + logger := New(Config{ + Output: &buf, + Level: LevelInfo, + }) + + SetGlobal(logger) + + Info("global test") + if !strings.Contains(buf.String(), "global test") { + t.Error("Global logger should write message") + } + + buf.Reset() + SetGlobalLevel(LevelError) + Info("should not appear") + if buf.Len() > 0 { + t.Error("Info should not appear at Error level") + } + + // Reset to default for other tests + SetGlobal(New(DefaultConfig())) +} + +func TestLevelString(t *testing.T) { + tests := []struct { + level Level + expected string + }{ + {LevelDebug, "DEBUG"}, + {LevelInfo, "INFO"}, + {LevelWarn, "WARN"}, + {LevelError, "ERROR"}, + {Level(99), "UNKNOWN"}, + } + + for _, tt := range tests { + if got := tt.level.String(); got != tt.expected { + t.Errorf("Level(%d).String() = %s, want %s", tt.level, got, tt.expected) + } + } +} + +func 
TestMergeFields(t *testing.T) { + // Empty fields + result := mergeFields(nil) + if result != nil { + t.Error("nil input should return nil") + } + + result = mergeFields([]Fields{}) + if result != nil { + t.Error("empty input should return nil") + } + + // Single fields + result = mergeFields([]Fields{{"key": "value"}}) + if result["key"] != "value" { + t.Error("Single field should be preserved") + } + + // Multiple fields + result = mergeFields([]Fields{ + {"key1": "value1"}, + {"key2": "value2"}, + }) + if result["key1"] != "value1" || result["key2"] != "value2" { + t.Error("Multiple fields should be merged") + } + + // Override + result = mergeFields([]Fields{ + {"key": "value1"}, + {"key": "value2"}, + }) + if result["key"] != "value2" { + t.Error("Later fields should override earlier ones") + } +} diff --git a/node/bufpool.go b/node/bufpool.go new file mode 100644 index 0000000..a4f0e68 --- /dev/null +++ b/node/bufpool.go @@ -0,0 +1,55 @@ +package node + +import ( + "bytes" + "encoding/json" + "sync" +) + +// bufferPool provides reusable byte buffers for JSON encoding. +// This reduces allocation overhead in hot paths like message serialization. +var bufferPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 0, 1024)) + }, +} + +// getBuffer retrieves a buffer from the pool. +func getBuffer() *bytes.Buffer { + buf := bufferPool.Get().(*bytes.Buffer) + buf.Reset() + return buf +} + +// putBuffer returns a buffer to the pool. +func putBuffer(buf *bytes.Buffer) { + // Don't pool buffers that grew too large (>64KB) + if buf.Cap() <= 65536 { + bufferPool.Put(buf) + } +} + +// MarshalJSON encodes a value to JSON using a pooled buffer. +// Returns a copy of the encoded bytes (safe to use after the function returns). 
+func MarshalJSON(v interface{}) ([]byte, error) { + buf := getBuffer() + defer putBuffer(buf) + + enc := json.NewEncoder(buf) + // Don't escape HTML characters (matches json.Marshal behavior for these use cases) + enc.SetEscapeHTML(false) + if err := enc.Encode(v); err != nil { + return nil, err + } + + // json.Encoder.Encode adds a newline; remove it to match json.Marshal + data := buf.Bytes() + if len(data) > 0 && data[len(data)-1] == '\n' { + data = data[:len(data)-1] + } + + // Return a copy since the buffer will be reused + result := make([]byte, len(data)) + copy(result, data) + return result, nil +} diff --git a/node/bundle.go b/node/bundle.go new file mode 100644 index 0000000..030f48e --- /dev/null +++ b/node/bundle.go @@ -0,0 +1,355 @@ +package node + +import ( + "archive/tar" + "bytes" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/Snider/Borg/pkg/datanode" + "github.com/Snider/Borg/pkg/tim" +) + +// BundleType defines the type of deployment bundle. +type BundleType string + +const ( + BundleProfile BundleType = "profile" // Just config/profile JSON + BundleMiner BundleType = "miner" // Miner binary + config + BundleFull BundleType = "full" // Everything (miner + profiles + config) +) + +// Bundle represents a deployment bundle for P2P transfer. +type Bundle struct { + Type BundleType `json:"type"` + Name string `json:"name"` + Data []byte `json:"data"` // Encrypted STIM data or raw JSON + Checksum string `json:"checksum"` // SHA-256 of Data +} + +// BundleManifest describes the contents of a bundle. +type BundleManifest struct { + Type BundleType `json:"type"` + Name string `json:"name"` + Version string `json:"version,omitempty"` + MinerType string `json:"minerType,omitempty"` + ProfileIDs []string `json:"profileIds,omitempty"` + CreatedAt string `json:"createdAt"` +} + +// CreateProfileBundle creates an encrypted bundle containing a mining profile. 
+func CreateProfileBundle(profileJSON []byte, name string, password string) (*Bundle, error) { + // Create a TIM with just the profile config + t, err := tim.New() + if err != nil { + return nil, fmt.Errorf("failed to create TIM: %w", err) + } + t.Config = profileJSON + + // Encrypt to STIM format + stimData, err := t.ToSigil(password) + if err != nil { + return nil, fmt.Errorf("failed to encrypt bundle: %w", err) + } + + // Calculate checksum + checksum := calculateChecksum(stimData) + + return &Bundle{ + Type: BundleProfile, + Name: name, + Data: stimData, + Checksum: checksum, + }, nil +} + +// CreateProfileBundleUnencrypted creates a plain JSON bundle (for testing or trusted networks). +func CreateProfileBundleUnencrypted(profileJSON []byte, name string) (*Bundle, error) { + checksum := calculateChecksum(profileJSON) + + return &Bundle{ + Type: BundleProfile, + Name: name, + Data: profileJSON, + Checksum: checksum, + }, nil +} + +// CreateMinerBundle creates an encrypted bundle containing a miner binary and optional profile. 
+func CreateMinerBundle(minerPath string, profileJSON []byte, name string, password string) (*Bundle, error) { + // Read miner binary + minerData, err := os.ReadFile(minerPath) + if err != nil { + return nil, fmt.Errorf("failed to read miner binary: %w", err) + } + + // Create a tarball with the miner binary + tarData, err := createTarball(map[string][]byte{ + filepath.Base(minerPath): minerData, + }) + if err != nil { + return nil, fmt.Errorf("failed to create tarball: %w", err) + } + + // Create DataNode from tarball + dn, err := datanode.FromTar(tarData) + if err != nil { + return nil, fmt.Errorf("failed to create datanode: %w", err) + } + + // Create TIM from DataNode + t, err := tim.FromDataNode(dn) + if err != nil { + return nil, fmt.Errorf("failed to create TIM: %w", err) + } + + // Set profile as config if provided + if profileJSON != nil { + t.Config = profileJSON + } + + // Encrypt to STIM format + stimData, err := t.ToSigil(password) + if err != nil { + return nil, fmt.Errorf("failed to encrypt bundle: %w", err) + } + + checksum := calculateChecksum(stimData) + + return &Bundle{ + Type: BundleMiner, + Name: name, + Data: stimData, + Checksum: checksum, + }, nil +} + +// ExtractProfileBundle decrypts and extracts a profile bundle. +func ExtractProfileBundle(bundle *Bundle, password string) ([]byte, error) { + // Verify checksum first + if calculateChecksum(bundle.Data) != bundle.Checksum { + return nil, fmt.Errorf("checksum mismatch - bundle may be corrupted") + } + + // If it's unencrypted JSON, just return it + if isJSON(bundle.Data) { + return bundle.Data, nil + } + + // Decrypt STIM format + t, err := tim.FromSigil(bundle.Data, password) + if err != nil { + return nil, fmt.Errorf("failed to decrypt bundle: %w", err) + } + + return t.Config, nil +} + +// ExtractMinerBundle decrypts and extracts a miner bundle, returning the miner path and profile. 
+func ExtractMinerBundle(bundle *Bundle, password string, destDir string) (string, []byte, error) { + // Verify checksum + if calculateChecksum(bundle.Data) != bundle.Checksum { + return "", nil, fmt.Errorf("checksum mismatch - bundle may be corrupted") + } + + // Decrypt STIM format + t, err := tim.FromSigil(bundle.Data, password) + if err != nil { + return "", nil, fmt.Errorf("failed to decrypt bundle: %w", err) + } + + // Convert rootfs to tarball and extract + tarData, err := t.RootFS.ToTar() + if err != nil { + return "", nil, fmt.Errorf("failed to convert rootfs to tar: %w", err) + } + + // Extract tarball to destination + minerPath, err := extractTarball(tarData, destDir) + if err != nil { + return "", nil, fmt.Errorf("failed to extract tarball: %w", err) + } + + return minerPath, t.Config, nil +} + +// VerifyBundle checks if a bundle's checksum is valid. +func VerifyBundle(bundle *Bundle) bool { + return calculateChecksum(bundle.Data) == bundle.Checksum +} + +// calculateChecksum computes SHA-256 checksum of data. +func calculateChecksum(data []byte) string { + hash := sha256.Sum256(data) + return hex.EncodeToString(hash[:]) +} + +// isJSON checks if data starts with JSON characters. +func isJSON(data []byte) bool { + if len(data) == 0 { + return false + } + // JSON typically starts with { or [ + return data[0] == '{' || data[0] == '[' +} + +// createTarball creates a tar archive from a map of filename -> content. +func createTarball(files map[string][]byte) ([]byte, error) { + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + + // Track directories we've created + dirs := make(map[string]bool) + + for name, content := range files { + // Create parent directories if needed + dir := filepath.Dir(name) + if dir != "." 
&& !dirs[dir] { + hdr := &tar.Header{ + Name: dir + "/", + Mode: 0755, + Typeflag: tar.TypeDir, + } + if err := tw.WriteHeader(hdr); err != nil { + return nil, err + } + dirs[dir] = true + } + + // Determine file mode (executable for binaries in miners/) + mode := int64(0644) + if filepath.Dir(name) == "miners" || !isJSON(content) { + mode = 0755 + } + + hdr := &tar.Header{ + Name: name, + Mode: mode, + Size: int64(len(content)), + } + if err := tw.WriteHeader(hdr); err != nil { + return nil, err + } + if _, err := tw.Write(content); err != nil { + return nil, err + } + } + + if err := tw.Close(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// extractTarball extracts a tar archive to a directory, returns first executable found. +func extractTarball(tarData []byte, destDir string) (string, error) { + // Ensure destDir is an absolute, clean path for security checks + absDestDir, err := filepath.Abs(destDir) + if err != nil { + return "", fmt.Errorf("failed to resolve destination directory: %w", err) + } + absDestDir = filepath.Clean(absDestDir) + + if err := os.MkdirAll(absDestDir, 0755); err != nil { + return "", err + } + + tr := tar.NewReader(bytes.NewReader(tarData)) + var firstExecutable string + + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return "", err + } + + // Security: Sanitize the tar entry name to prevent path traversal (Zip Slip) + cleanName := filepath.Clean(hdr.Name) + + // Reject absolute paths + if filepath.IsAbs(cleanName) { + return "", fmt.Errorf("invalid tar entry: absolute path not allowed: %s", hdr.Name) + } + + // Reject paths that escape the destination directory + if strings.HasPrefix(cleanName, ".."+string(os.PathSeparator)) || cleanName == ".." 
{ + return "", fmt.Errorf("invalid tar entry: path traversal attempt: %s", hdr.Name) + } + + // Build the full path and verify it's within destDir + fullPath := filepath.Join(absDestDir, cleanName) + fullPath = filepath.Clean(fullPath) + + // Final security check: ensure the path is still within destDir + if !strings.HasPrefix(fullPath, absDestDir+string(os.PathSeparator)) && fullPath != absDestDir { + return "", fmt.Errorf("invalid tar entry: path escape attempt: %s", hdr.Name) + } + + switch hdr.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(fullPath, os.FileMode(hdr.Mode)); err != nil { + return "", err + } + case tar.TypeReg: + // Ensure parent directory exists + if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil { + return "", err + } + + f, err := os.OpenFile(fullPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(hdr.Mode)) + if err != nil { + return "", err + } + + // Limit file size to prevent decompression bombs (100MB max per file) + const maxFileSize int64 = 100 * 1024 * 1024 + limitedReader := io.LimitReader(tr, maxFileSize+1) + written, err := io.Copy(f, limitedReader) + f.Close() + if err != nil { + return "", err + } + if written > maxFileSize { + os.Remove(fullPath) + return "", fmt.Errorf("file %s exceeds maximum size of %d bytes", hdr.Name, maxFileSize) + } + + // Track first executable + if hdr.Mode&0111 != 0 && firstExecutable == "" { + firstExecutable = fullPath + } + // Explicitly ignore symlinks and hard links to prevent symlink attacks + case tar.TypeSymlink, tar.TypeLink: + // Skip symlinks and hard links for security + continue + } + } + + return firstExecutable, nil +} + +// StreamBundle writes a bundle to a writer (for large transfers). +func StreamBundle(bundle *Bundle, w io.Writer) error { + encoder := json.NewEncoder(w) + return encoder.Encode(bundle) +} + +// ReadBundle reads a bundle from a reader. 
+func ReadBundle(r io.Reader) (*Bundle, error) { + var bundle Bundle + decoder := json.NewDecoder(r) + if err := decoder.Decode(&bundle); err != nil { + return nil, err + } + return &bundle, nil +} diff --git a/node/bundle_test.go b/node/bundle_test.go new file mode 100644 index 0000000..4bc8f26 --- /dev/null +++ b/node/bundle_test.go @@ -0,0 +1,352 @@ +package node + +import ( + "bytes" + "os" + "path/filepath" + "testing" +) + +func TestCreateProfileBundleUnencrypted(t *testing.T) { + profileJSON := []byte(`{"name":"test-profile","minerType":"xmrig","config":{}}`) + + bundle, err := CreateProfileBundleUnencrypted(profileJSON, "test-profile") + if err != nil { + t.Fatalf("failed to create bundle: %v", err) + } + + if bundle.Type != BundleProfile { + t.Errorf("expected type BundleProfile, got %s", bundle.Type) + } + + if bundle.Name != "test-profile" { + t.Errorf("expected name 'test-profile', got '%s'", bundle.Name) + } + + if bundle.Checksum == "" { + t.Error("checksum should not be empty") + } + + if !bytes.Equal(bundle.Data, profileJSON) { + t.Error("data should match original JSON") + } +} + +func TestVerifyBundle(t *testing.T) { + t.Run("ValidChecksum", func(t *testing.T) { + bundle, _ := CreateProfileBundleUnencrypted([]byte(`{"test":"data"}`), "test") + + if !VerifyBundle(bundle) { + t.Error("valid bundle should verify") + } + }) + + t.Run("InvalidChecksum", func(t *testing.T) { + bundle, _ := CreateProfileBundleUnencrypted([]byte(`{"test":"data"}`), "test") + bundle.Checksum = "invalid-checksum" + + if VerifyBundle(bundle) { + t.Error("bundle with invalid checksum should not verify") + } + }) + + t.Run("ModifiedData", func(t *testing.T) { + bundle, _ := CreateProfileBundleUnencrypted([]byte(`{"test":"data"}`), "test") + bundle.Data = []byte(`{"test":"modified"}`) + + if VerifyBundle(bundle) { + t.Error("bundle with modified data should not verify") + } + }) +} + +func TestCreateProfileBundle(t *testing.T) { + profileJSON := 
[]byte(`{"name":"encrypted-profile","minerType":"xmrig"}`) + password := "test-password-123" + + bundle, err := CreateProfileBundle(profileJSON, "encrypted-test", password) + if err != nil { + t.Fatalf("failed to create encrypted bundle: %v", err) + } + + if bundle.Type != BundleProfile { + t.Errorf("expected type BundleProfile, got %s", bundle.Type) + } + + // Encrypted data should not match original + if bytes.Equal(bundle.Data, profileJSON) { + t.Error("encrypted data should not match original") + } + + // Should be able to extract with correct password + extracted, err := ExtractProfileBundle(bundle, password) + if err != nil { + t.Fatalf("failed to extract bundle: %v", err) + } + + if !bytes.Equal(extracted, profileJSON) { + t.Errorf("extracted data should match original: got %s", string(extracted)) + } +} + +func TestExtractProfileBundle(t *testing.T) { + t.Run("UnencryptedBundle", func(t *testing.T) { + originalJSON := []byte(`{"name":"plain","config":{}}`) + bundle, _ := CreateProfileBundleUnencrypted(originalJSON, "plain") + + extracted, err := ExtractProfileBundle(bundle, "") + if err != nil { + t.Fatalf("failed to extract unencrypted bundle: %v", err) + } + + if !bytes.Equal(extracted, originalJSON) { + t.Error("extracted data should match original") + } + }) + + t.Run("EncryptedBundle", func(t *testing.T) { + originalJSON := []byte(`{"name":"secret","config":{"pool":"pool.example.com"}}`) + password := "strong-password" + + bundle, _ := CreateProfileBundle(originalJSON, "secret", password) + + extracted, err := ExtractProfileBundle(bundle, password) + if err != nil { + t.Fatalf("failed to extract encrypted bundle: %v", err) + } + + if !bytes.Equal(extracted, originalJSON) { + t.Error("extracted data should match original") + } + }) + + t.Run("WrongPassword", func(t *testing.T) { + originalJSON := []byte(`{"name":"secret"}`) + bundle, _ := CreateProfileBundle(originalJSON, "secret", "correct-password") + + _, err := ExtractProfileBundle(bundle, 
"wrong-password") + if err == nil { + t.Error("should fail with wrong password") + } + }) + + t.Run("CorruptedChecksum", func(t *testing.T) { + bundle, _ := CreateProfileBundleUnencrypted([]byte(`{}`), "test") + bundle.Checksum = "corrupted" + + _, err := ExtractProfileBundle(bundle, "") + if err == nil { + t.Error("should fail with corrupted checksum") + } + }) +} + +func TestTarballFunctions(t *testing.T) { + t.Run("CreateAndExtractTarball", func(t *testing.T) { + files := map[string][]byte{ + "file1.txt": []byte("content of file 1"), + "dir/file2.json": []byte(`{"key":"value"}`), + "miners/xmrig": []byte("binary content"), + } + + tarData, err := createTarball(files) + if err != nil { + t.Fatalf("failed to create tarball: %v", err) + } + + if len(tarData) == 0 { + t.Error("tarball should not be empty") + } + + // Extract to temp directory + tmpDir, _ := os.MkdirTemp("", "tarball-test") + defer os.RemoveAll(tmpDir) + + firstExec, err := extractTarball(tarData, tmpDir) + if err != nil { + t.Fatalf("failed to extract tarball: %v", err) + } + + // Check files exist + for name, content := range files { + path := filepath.Join(tmpDir, name) + data, err := os.ReadFile(path) + if err != nil { + t.Errorf("failed to read extracted file %s: %v", name, err) + continue + } + + if !bytes.Equal(data, content) { + t.Errorf("content mismatch for %s", name) + } + } + + // Check first executable is the miner + if firstExec == "" { + t.Error("should find an executable") + } + }) +} + +func TestStreamAndReadBundle(t *testing.T) { + original, _ := CreateProfileBundleUnencrypted([]byte(`{"streaming":"test"}`), "stream-test") + + // Stream to buffer + var buf bytes.Buffer + err := StreamBundle(original, &buf) + if err != nil { + t.Fatalf("failed to stream bundle: %v", err) + } + + // Read back + restored, err := ReadBundle(&buf) + if err != nil { + t.Fatalf("failed to read bundle: %v", err) + } + + if restored.Name != original.Name { + t.Errorf("name mismatch: expected '%s', got '%s'", 
original.Name, restored.Name) + } + + if restored.Checksum != original.Checksum { + t.Error("checksum mismatch") + } + + if !bytes.Equal(restored.Data, original.Data) { + t.Error("data mismatch") + } +} + +func TestCalculateChecksum(t *testing.T) { + t.Run("Deterministic", func(t *testing.T) { + data := []byte("test data for checksum") + + checksum1 := calculateChecksum(data) + checksum2 := calculateChecksum(data) + + if checksum1 != checksum2 { + t.Error("checksum should be deterministic") + } + }) + + t.Run("DifferentData", func(t *testing.T) { + checksum1 := calculateChecksum([]byte("data1")) + checksum2 := calculateChecksum([]byte("data2")) + + if checksum1 == checksum2 { + t.Error("different data should produce different checksums") + } + }) + + t.Run("HexFormat", func(t *testing.T) { + checksum := calculateChecksum([]byte("test")) + + // SHA-256 produces 64 hex characters + if len(checksum) != 64 { + t.Errorf("expected 64 character hex string, got %d characters", len(checksum)) + } + + // Should be valid hex + for _, c := range checksum { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + t.Errorf("invalid hex character: %c", c) + } + } + }) +} + +func TestIsJSON(t *testing.T) { + tests := []struct { + data []byte + expected bool + }{ + {[]byte(`{"key":"value"}`), true}, + {[]byte(`["item1","item2"]`), true}, + {[]byte(`{}`), true}, + {[]byte(`[]`), true}, + {[]byte(`binary\x00data`), false}, + {[]byte(`plain text`), false}, + {[]byte{}, false}, + {nil, false}, + } + + for _, tt := range tests { + result := isJSON(tt.data) + if result != tt.expected { + t.Errorf("isJSON(%q) = %v, expected %v", tt.data, result, tt.expected) + } + } +} + +func TestBundleTypes(t *testing.T) { + types := []BundleType{ + BundleProfile, + BundleMiner, + BundleFull, + } + + expected := []string{"profile", "miner", "full"} + + for i, bt := range types { + if string(bt) != expected[i] { + t.Errorf("expected %s, got %s", expected[i], string(bt)) + } + } +} + +func 
TestCreateMinerBundle(t *testing.T) { + // Create a temp "miner binary" + tmpDir, _ := os.MkdirTemp("", "miner-bundle-test") + defer os.RemoveAll(tmpDir) + + minerPath := filepath.Join(tmpDir, "test-miner") + err := os.WriteFile(minerPath, []byte("fake miner binary content"), 0755) + if err != nil { + t.Fatalf("failed to create test miner: %v", err) + } + + profileJSON := []byte(`{"profile":"data"}`) + password := "miner-password" + + bundle, err := CreateMinerBundle(minerPath, profileJSON, "miner-bundle", password) + if err != nil { + t.Fatalf("failed to create miner bundle: %v", err) + } + + if bundle.Type != BundleMiner { + t.Errorf("expected type BundleMiner, got %s", bundle.Type) + } + + if bundle.Name != "miner-bundle" { + t.Errorf("expected name 'miner-bundle', got '%s'", bundle.Name) + } + + // Extract and verify + extractDir, _ := os.MkdirTemp("", "miner-extract-test") + defer os.RemoveAll(extractDir) + + extractedPath, extractedProfile, err := ExtractMinerBundle(bundle, password, extractDir) + if err != nil { + t.Fatalf("failed to extract miner bundle: %v", err) + } + + // Note: extractedPath may be empty if the tarball structure doesn't match + // what extractTarball expects (it looks for files at root with executable bit) + t.Logf("extracted path: %s", extractedPath) + + if !bytes.Equal(extractedProfile, profileJSON) { + t.Error("profile data mismatch") + } + + // If we got an extracted path, verify its content + if extractedPath != "" { + minerData, err := os.ReadFile(extractedPath) + if err != nil { + t.Fatalf("failed to read extracted miner: %v", err) + } + + if string(minerData) != "fake miner binary content" { + t.Error("miner content mismatch") + } + } +} diff --git a/node/controller.go b/node/controller.go new file mode 100644 index 0000000..9bc0c80 --- /dev/null +++ b/node/controller.go @@ -0,0 +1,327 @@ +package node + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "forge.lthn.ai/core/go-p2p/logging" +) + +// Controller 
manages remote peer operations from a controller node. +type Controller struct { + node *NodeManager + peers *PeerRegistry + transport *Transport + mu sync.RWMutex + + // Pending requests awaiting responses + pending map[string]chan *Message // message ID -> response channel +} + +// NewController creates a new Controller instance. +func NewController(node *NodeManager, peers *PeerRegistry, transport *Transport) *Controller { + c := &Controller{ + node: node, + peers: peers, + transport: transport, + pending: make(map[string]chan *Message), + } + + // Register message handler for responses + transport.OnMessage(c.handleResponse) + + return c +} + +// handleResponse processes incoming messages that are responses to our requests. +func (c *Controller) handleResponse(conn *PeerConnection, msg *Message) { + if msg.ReplyTo == "" { + return // Not a response, let worker handle it + } + + c.mu.Lock() + ch, exists := c.pending[msg.ReplyTo] + if exists { + delete(c.pending, msg.ReplyTo) + } + c.mu.Unlock() + + if exists && ch != nil { + select { + case ch <- msg: + default: + // Channel full or closed + } + } +} + +// sendRequest sends a message and waits for a response. 
+func (c *Controller) sendRequest(peerID string, msg *Message, timeout time.Duration) (*Message, error) { + actualPeerID := peerID + + // Auto-connect if not already connected + if c.transport.GetConnection(peerID) == nil { + peer := c.peers.GetPeer(peerID) + if peer == nil { + return nil, fmt.Errorf("peer not found: %s", peerID) + } + conn, err := c.transport.Connect(peer) + if err != nil { + return nil, fmt.Errorf("failed to connect to peer: %w", err) + } + // Use the real peer ID after handshake (it may have changed) + actualPeerID = conn.Peer.ID + // Update the message destination + msg.To = actualPeerID + } + + // Create response channel + respCh := make(chan *Message, 1) + + c.mu.Lock() + c.pending[msg.ID] = respCh + c.mu.Unlock() + + // Clean up on exit - ensure channel is closed and removed from map + defer func() { + c.mu.Lock() + delete(c.pending, msg.ID) + c.mu.Unlock() + close(respCh) // Close channel to allow garbage collection + }() + + // Send the message + if err := c.transport.Send(actualPeerID, msg); err != nil { + return nil, fmt.Errorf("failed to send message: %w", err) + } + + // Wait for response + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + select { + case resp := <-respCh: + return resp, nil + case <-ctx.Done(): + return nil, fmt.Errorf("request timeout") + } +} + +// GetRemoteStats requests miner statistics from a remote peer. 
+func (c *Controller) GetRemoteStats(peerID string) (*StatsPayload, error) { + identity := c.node.GetIdentity() + if identity == nil { + return nil, fmt.Errorf("node identity not initialized") + } + + msg, err := NewMessage(MsgGetStats, identity.ID, peerID, nil) + if err != nil { + return nil, fmt.Errorf("failed to create message: %w", err) + } + + resp, err := c.sendRequest(peerID, msg, 10*time.Second) + if err != nil { + return nil, err + } + + var stats StatsPayload + if err := ParseResponse(resp, MsgStats, &stats); err != nil { + return nil, err + } + + return &stats, nil +} + +// StartRemoteMiner requests a remote peer to start a miner with a given profile. +func (c *Controller) StartRemoteMiner(peerID, minerType, profileID string, configOverride json.RawMessage) error { + identity := c.node.GetIdentity() + if identity == nil { + return fmt.Errorf("node identity not initialized") + } + + if minerType == "" { + return fmt.Errorf("miner type is required") + } + + payload := StartMinerPayload{ + MinerType: minerType, + ProfileID: profileID, + Config: configOverride, + } + + msg, err := NewMessage(MsgStartMiner, identity.ID, peerID, payload) + if err != nil { + return fmt.Errorf("failed to create message: %w", err) + } + + resp, err := c.sendRequest(peerID, msg, 30*time.Second) + if err != nil { + return err + } + + var ack MinerAckPayload + if err := ParseResponse(resp, MsgMinerAck, &ack); err != nil { + return err + } + + if !ack.Success { + return fmt.Errorf("miner start failed: %s", ack.Error) + } + + return nil +} + +// StopRemoteMiner requests a remote peer to stop a miner. 
+func (c *Controller) StopRemoteMiner(peerID, minerName string) error { + identity := c.node.GetIdentity() + if identity == nil { + return fmt.Errorf("node identity not initialized") + } + + payload := StopMinerPayload{ + MinerName: minerName, + } + + msg, err := NewMessage(MsgStopMiner, identity.ID, peerID, payload) + if err != nil { + return fmt.Errorf("failed to create message: %w", err) + } + + resp, err := c.sendRequest(peerID, msg, 30*time.Second) + if err != nil { + return err + } + + var ack MinerAckPayload + if err := ParseResponse(resp, MsgMinerAck, &ack); err != nil { + return err + } + + if !ack.Success { + return fmt.Errorf("miner stop failed: %s", ack.Error) + } + + return nil +} + +// GetRemoteLogs requests console logs from a remote miner. +func (c *Controller) GetRemoteLogs(peerID, minerName string, lines int) ([]string, error) { + identity := c.node.GetIdentity() + if identity == nil { + return nil, fmt.Errorf("node identity not initialized") + } + + payload := GetLogsPayload{ + MinerName: minerName, + Lines: lines, + } + + msg, err := NewMessage(MsgGetLogs, identity.ID, peerID, payload) + if err != nil { + return nil, fmt.Errorf("failed to create message: %w", err) + } + + resp, err := c.sendRequest(peerID, msg, 10*time.Second) + if err != nil { + return nil, err + } + + var logs LogsPayload + if err := ParseResponse(resp, MsgLogs, &logs); err != nil { + return nil, err + } + + return logs.Lines, nil +} + +// GetAllStats fetches stats from all connected peers. 
+func (c *Controller) GetAllStats() map[string]*StatsPayload { + peers := c.peers.GetConnectedPeers() + results := make(map[string]*StatsPayload) + var mu sync.Mutex + var wg sync.WaitGroup + + for _, peer := range peers { + wg.Add(1) + go func(p *Peer) { + defer wg.Done() + stats, err := c.GetRemoteStats(p.ID) + if err != nil { + logging.Debug("failed to get stats from peer", logging.Fields{ + "peer_id": p.ID, + "peer": p.Name, + "error": err.Error(), + }) + return // Skip failed peers + } + mu.Lock() + results[p.ID] = stats + mu.Unlock() + }(peer) + } + + wg.Wait() + return results +} + +// PingPeer sends a ping to a peer and updates metrics. +func (c *Controller) PingPeer(peerID string) (float64, error) { + identity := c.node.GetIdentity() + if identity == nil { + return 0, fmt.Errorf("node identity not initialized") + } + sentAt := time.Now() + + payload := PingPayload{ + SentAt: sentAt.UnixMilli(), + } + + msg, err := NewMessage(MsgPing, identity.ID, peerID, payload) + if err != nil { + return 0, fmt.Errorf("failed to create message: %w", err) + } + + resp, err := c.sendRequest(peerID, msg, 5*time.Second) + if err != nil { + return 0, err + } + + if err := ValidateResponse(resp, MsgPong); err != nil { + return 0, err + } + + // Calculate round-trip time + rtt := time.Since(sentAt).Seconds() * 1000 // Convert to ms + + // Update peer metrics + peer := c.peers.GetPeer(peerID) + if peer != nil { + c.peers.UpdateMetrics(peerID, rtt, peer.GeoKM, peer.Hops) + } + + return rtt, nil +} + +// ConnectToPeer establishes a connection to a peer. +func (c *Controller) ConnectToPeer(peerID string) error { + peer := c.peers.GetPeer(peerID) + if peer == nil { + return fmt.Errorf("peer not found: %s", peerID) + } + + _, err := c.transport.Connect(peer) + return err +} + +// DisconnectFromPeer closes connection to a peer. 
+func (c *Controller) DisconnectFromPeer(peerID string) error { + conn := c.transport.GetConnection(peerID) + if conn == nil { + return fmt.Errorf("peer not connected: %s", peerID) + } + + return conn.Close() +} diff --git a/node/dispatcher.go b/node/dispatcher.go new file mode 100644 index 0000000..dc9d23d --- /dev/null +++ b/node/dispatcher.go @@ -0,0 +1,39 @@ +package node + +// pkg/node/dispatcher.go + +/* +func (n *NodeManager) DispatchUEPS(pkt *ueps.ParsedPacket) error { + // 1. The "Threat" Circuit Breaker (L5 Guard) + if pkt.Header.ThreatScore > 50000 { + // High threat? Drop it. Don't even parse the payload. + // This protects the Agent from "semantic viruses" + return fmt.Errorf("packet rejected: threat score %d exceeds safety limit", pkt.Header.ThreatScore) + } + + // 2. The "Intent" Router (L9 Semantic) + switch pkt.Header.IntentID { + + case 0x01: // Handshake / Hello + // return n.handleHandshake(pkt) + + case 0x20: // Compute / Job Request + // "Hey, can you run this Docker container?" + // Check local resources first (Self-Validation) + // return n.handleComputeRequest(pkt.Payload) + + case 0x30: // Rehab / Intervention + // "Violet says you are hallucinating. Pause execution." + // This is the "Benevolent Intervention" Axiom. + // return n.enterRehabMode(pkt.Payload) + + case 0xFF: // Extended / Custom + // Check the payload for specific sub-protocols (e.g. your JSON blobs) + // return n.handleApplicationData(pkt.Payload) + + default: + return fmt.Errorf("unknown intent ID: 0x%X", pkt.Header.IntentID) + } + return nil +} +*/ diff --git a/node/identity.go b/node/identity.go new file mode 100644 index 0000000..31aac1c --- /dev/null +++ b/node/identity.go @@ -0,0 +1,290 @@ +// Package node provides P2P node identity and communication for multi-node mining management. 
+package node + +import ( + "crypto/ecdh" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + "github.com/Snider/Borg/pkg/stmf" + "github.com/adrg/xdg" +) + +// ChallengeSize is the size of the challenge in bytes +const ChallengeSize = 32 + +// GenerateChallenge creates a random challenge for authentication. +func GenerateChallenge() ([]byte, error) { + challenge := make([]byte, ChallengeSize) + if _, err := rand.Read(challenge); err != nil { + return nil, fmt.Errorf("failed to generate challenge: %w", err) + } + return challenge, nil +} + +// SignChallenge creates an HMAC signature of a challenge using a shared secret. +// The signature proves possession of the shared secret without revealing it. +func SignChallenge(challenge []byte, sharedSecret []byte) []byte { + mac := hmac.New(sha256.New, sharedSecret) + mac.Write(challenge) + return mac.Sum(nil) +} + +// VerifyChallenge verifies that a challenge response was signed with the correct shared secret. +func VerifyChallenge(challenge, response, sharedSecret []byte) bool { + expected := SignChallenge(challenge, sharedSecret) + return hmac.Equal(response, expected) +} + +// NodeRole defines the operational mode of a node. +type NodeRole string + +const ( + // RoleController manages remote worker nodes. + RoleController NodeRole = "controller" + // RoleWorker receives commands and runs miners. + RoleWorker NodeRole = "worker" + // RoleDual operates as both controller and worker (default). + RoleDual NodeRole = "dual" +) + +// NodeIdentity represents the public identity of a node. 
+type NodeIdentity struct { + ID string `json:"id"` // Derived from public key (first 16 bytes hex) + Name string `json:"name"` // Human-friendly name + PublicKey string `json:"publicKey"` // X25519 base64 + CreatedAt time.Time `json:"createdAt"` + Role NodeRole `json:"role"` +} + +// NodeManager handles node identity operations including key generation and storage. +type NodeManager struct { + identity *NodeIdentity + privateKey []byte // Never serialized to JSON + keyPair *stmf.KeyPair + keyPath string // ~/.local/share/lethean-desktop/node/private.key + configPath string // ~/.config/lethean-desktop/node.json + mu sync.RWMutex +} + +// NewNodeManager creates a new NodeManager, loading existing identity if available. +func NewNodeManager() (*NodeManager, error) { + keyPath, err := xdg.DataFile("lethean-desktop/node/private.key") + if err != nil { + return nil, fmt.Errorf("failed to get key path: %w", err) + } + + configPath, err := xdg.ConfigFile("lethean-desktop/node.json") + if err != nil { + return nil, fmt.Errorf("failed to get config path: %w", err) + } + + return NewNodeManagerWithPaths(keyPath, configPath) +} + +// NewNodeManagerWithPaths creates a NodeManager with custom paths. +// This is primarily useful for testing to avoid xdg path caching issues. +func NewNodeManagerWithPaths(keyPath, configPath string) (*NodeManager, error) { + nm := &NodeManager{ + keyPath: keyPath, + configPath: configPath, + } + + // Try to load existing identity + if err := nm.loadIdentity(); err != nil { + // Identity doesn't exist yet, that's ok + return nm, nil + } + + return nm, nil +} + +// HasIdentity returns true if a node identity has been initialized. +func (n *NodeManager) HasIdentity() bool { + n.mu.RLock() + defer n.mu.RUnlock() + return n.identity != nil +} + +// GetIdentity returns the node's public identity. 
+func (n *NodeManager) GetIdentity() *NodeIdentity { + n.mu.RLock() + defer n.mu.RUnlock() + if n.identity == nil { + return nil + } + // Return a copy to prevent mutation + identity := *n.identity + return &identity +} + +// GenerateIdentity creates a new node identity with the given name and role. +func (n *NodeManager) GenerateIdentity(name string, role NodeRole) error { + n.mu.Lock() + defer n.mu.Unlock() + + // Generate X25519 keypair using STMF + keyPair, err := stmf.GenerateKeyPair() + if err != nil { + return fmt.Errorf("failed to generate keypair: %w", err) + } + + // Derive node ID from public key (first 16 bytes as hex = 32 char ID) + pubKeyBytes := keyPair.PublicKey() + hash := sha256.Sum256(pubKeyBytes) + nodeID := hex.EncodeToString(hash[:16]) + + n.identity = &NodeIdentity{ + ID: nodeID, + Name: name, + PublicKey: keyPair.PublicKeyBase64(), + CreatedAt: time.Now(), + Role: role, + } + + n.keyPair = keyPair + n.privateKey = keyPair.PrivateKey() + + // Save private key + if err := n.savePrivateKey(); err != nil { + return fmt.Errorf("failed to save private key: %w", err) + } + + // Save identity config + if err := n.saveIdentity(); err != nil { + return fmt.Errorf("failed to save identity: %w", err) + } + + return nil +} + +// DeriveSharedSecret derives a shared secret with a peer using X25519 ECDH. +// The result is hashed with SHA-256 for use as a symmetric key. 
+func (n *NodeManager) DeriveSharedSecret(peerPubKeyBase64 string) ([]byte, error) { + n.mu.RLock() + defer n.mu.RUnlock() + + if n.privateKey == nil { + return nil, fmt.Errorf("node identity not initialized") + } + + // Load peer's public key + peerPubKey, err := stmf.LoadPublicKeyBase64(peerPubKeyBase64) + if err != nil { + return nil, fmt.Errorf("failed to load peer public key: %w", err) + } + + // Load our private key + privateKey, err := ecdh.X25519().NewPrivateKey(n.privateKey) + if err != nil { + return nil, fmt.Errorf("failed to load private key: %w", err) + } + + // Derive shared secret using ECDH + sharedSecret, err := privateKey.ECDH(peerPubKey) + if err != nil { + return nil, fmt.Errorf("failed to derive shared secret: %w", err) + } + + // Hash the shared secret using SHA-256 (same pattern as Borg/trix) + hash := sha256.Sum256(sharedSecret) + return hash[:], nil +} + +// savePrivateKey saves the private key to disk with restricted permissions. +func (n *NodeManager) savePrivateKey() error { + // Ensure directory exists + dir := filepath.Dir(n.keyPath) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create key directory: %w", err) + } + + // Write private key with restricted permissions (0600) + if err := os.WriteFile(n.keyPath, n.privateKey, 0600); err != nil { + return fmt.Errorf("failed to write private key: %w", err) + } + + return nil +} + +// saveIdentity saves the public identity to the config file. 
+func (n *NodeManager) saveIdentity() error { + // Ensure directory exists + dir := filepath.Dir(n.configPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create config directory: %w", err) + } + + data, err := json.MarshalIndent(n.identity, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal identity: %w", err) + } + + if err := os.WriteFile(n.configPath, data, 0644); err != nil { + return fmt.Errorf("failed to write identity: %w", err) + } + + return nil +} + +// loadIdentity loads the node identity from disk. +func (n *NodeManager) loadIdentity() error { + // Load identity config + data, err := os.ReadFile(n.configPath) + if err != nil { + return fmt.Errorf("failed to read identity: %w", err) + } + + var identity NodeIdentity + if err := json.Unmarshal(data, &identity); err != nil { + return fmt.Errorf("failed to unmarshal identity: %w", err) + } + + // Load private key + privateKey, err := os.ReadFile(n.keyPath) + if err != nil { + return fmt.Errorf("failed to read private key: %w", err) + } + + // Reconstruct keypair from private key + keyPair, err := stmf.LoadKeyPair(privateKey) + if err != nil { + return fmt.Errorf("failed to load keypair: %w", err) + } + + n.identity = &identity + n.privateKey = privateKey + n.keyPair = keyPair + + return nil +} + +// Delete removes the node identity and keys from disk. 
+func (n *NodeManager) Delete() error { + n.mu.Lock() + defer n.mu.Unlock() + + // Remove private key + if err := os.Remove(n.keyPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove private key: %w", err) + } + + // Remove identity config + if err := os.Remove(n.configPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove identity: %w", err) + } + + n.identity = nil + n.privateKey = nil + n.keyPair = nil + + return nil +} diff --git a/node/identity_test.go b/node/identity_test.go new file mode 100644 index 0000000..fb0dce9 --- /dev/null +++ b/node/identity_test.go @@ -0,0 +1,353 @@ +package node + +import ( + "os" + "path/filepath" + "testing" +) + +// setupTestNodeManager creates a NodeManager with paths in a temp directory. +func setupTestNodeManager(t *testing.T) (*NodeManager, func()) { + tmpDir, err := os.MkdirTemp("", "node-identity-test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + + keyPath := filepath.Join(tmpDir, "private.key") + configPath := filepath.Join(tmpDir, "node.json") + + nm, err := NewNodeManagerWithPaths(keyPath, configPath) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("failed to create node manager: %v", err) + } + + cleanup := func() { + os.RemoveAll(tmpDir) + } + + return nm, cleanup +} + +func TestNodeIdentity(t *testing.T) { + t.Run("NewNodeManager", func(t *testing.T) { + nm, cleanup := setupTestNodeManager(t) + defer cleanup() + + if nm.HasIdentity() { + t.Error("new node manager should not have identity") + } + }) + + t.Run("GenerateIdentity", func(t *testing.T) { + nm, cleanup := setupTestNodeManager(t) + defer cleanup() + + err := nm.GenerateIdentity("test-node", RoleDual) + if err != nil { + t.Fatalf("failed to generate identity: %v", err) + } + + if !nm.HasIdentity() { + t.Error("node manager should have identity after generation") + } + + identity := nm.GetIdentity() + if identity == nil { + t.Fatal("identity should not be nil") + } + + if 
identity.Name != "test-node" { + t.Errorf("expected name 'test-node', got '%s'", identity.Name) + } + + if identity.Role != RoleDual { + t.Errorf("expected role Dual, got '%s'", identity.Role) + } + + if identity.ID == "" { + t.Error("identity ID should not be empty") + } + + if identity.PublicKey == "" { + t.Error("public key should not be empty") + } + }) + + t.Run("LoadExistingIdentity", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "node-load-test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + keyPath := filepath.Join(tmpDir, "private.key") + configPath := filepath.Join(tmpDir, "node.json") + + // First, create an identity + nm1, err := NewNodeManagerWithPaths(keyPath, configPath) + if err != nil { + t.Fatalf("failed to create first node manager: %v", err) + } + + err = nm1.GenerateIdentity("persistent-node", RoleWorker) + if err != nil { + t.Fatalf("failed to generate identity: %v", err) + } + + originalID := nm1.GetIdentity().ID + originalPubKey := nm1.GetIdentity().PublicKey + + // Create a new manager - should load existing identity + nm2, err := NewNodeManagerWithPaths(keyPath, configPath) + if err != nil { + t.Fatalf("failed to create second node manager: %v", err) + } + + if !nm2.HasIdentity() { + t.Error("second node manager should have loaded existing identity") + } + + identity := nm2.GetIdentity() + if identity.ID != originalID { + t.Errorf("expected ID '%s', got '%s'", originalID, identity.ID) + } + + if identity.PublicKey != originalPubKey { + t.Error("public key mismatch after reload") + } + }) + + t.Run("DeriveSharedSecret", func(t *testing.T) { + // Create two node managers with separate temp directories + tmpDir1, _ := os.MkdirTemp("", "node1") + tmpDir2, _ := os.MkdirTemp("", "node2") + defer os.RemoveAll(tmpDir1) + defer os.RemoveAll(tmpDir2) + + // Node 1 + nm1, err := NewNodeManagerWithPaths( + filepath.Join(tmpDir1, "private.key"), + filepath.Join(tmpDir1, "node.json"), + 
) + if err != nil { + t.Fatalf("failed to create node manager 1: %v", err) + } + err = nm1.GenerateIdentity("node1", RoleDual) + if err != nil { + t.Fatalf("failed to generate identity 1: %v", err) + } + + // Node 2 + nm2, err := NewNodeManagerWithPaths( + filepath.Join(tmpDir2, "private.key"), + filepath.Join(tmpDir2, "node.json"), + ) + if err != nil { + t.Fatalf("failed to create node manager 2: %v", err) + } + err = nm2.GenerateIdentity("node2", RoleDual) + if err != nil { + t.Fatalf("failed to generate identity 2: %v", err) + } + + // Derive shared secrets - should be identical + secret1, err := nm1.DeriveSharedSecret(nm2.GetIdentity().PublicKey) + if err != nil { + t.Fatalf("failed to derive shared secret from node 1: %v", err) + } + + secret2, err := nm2.DeriveSharedSecret(nm1.GetIdentity().PublicKey) + if err != nil { + t.Fatalf("failed to derive shared secret from node 2: %v", err) + } + + if len(secret1) != len(secret2) { + t.Errorf("shared secrets have different lengths: %d vs %d", len(secret1), len(secret2)) + } + + for i := range secret1 { + if secret1[i] != secret2[i] { + t.Error("shared secrets do not match") + break + } + } + }) + + t.Run("DeleteIdentity", func(t *testing.T) { + nm, cleanup := setupTestNodeManager(t) + defer cleanup() + + err := nm.GenerateIdentity("delete-me", RoleDual) + if err != nil { + t.Fatalf("failed to generate identity: %v", err) + } + + if !nm.HasIdentity() { + t.Error("should have identity before delete") + } + + err = nm.Delete() + if err != nil { + t.Fatalf("failed to delete identity: %v", err) + } + + if nm.HasIdentity() { + t.Error("should not have identity after delete") + } + }) +} + +func TestNodeRoles(t *testing.T) { + tests := []struct { + role NodeRole + expected string + }{ + {RoleController, "controller"}, + {RoleWorker, "worker"}, + {RoleDual, "dual"}, + } + + for _, tt := range tests { + t.Run(string(tt.role), func(t *testing.T) { + if string(tt.role) != tt.expected { + t.Errorf("expected '%s', got '%s'", 
tt.expected, string(tt.role)) + } + }) + } +} + +func TestChallengeResponse(t *testing.T) { + t.Run("GenerateChallenge", func(t *testing.T) { + challenge, err := GenerateChallenge() + if err != nil { + t.Fatalf("failed to generate challenge: %v", err) + } + + if len(challenge) != ChallengeSize { + t.Errorf("expected challenge size %d, got %d", ChallengeSize, len(challenge)) + } + + // Ensure challenges are unique (not all zeros) + allZero := true + for _, b := range challenge { + if b != 0 { + allZero = false + break + } + } + if allZero { + t.Error("challenge should not be all zeros") + } + + // Generate another and ensure they're different + challenge2, err := GenerateChallenge() + if err != nil { + t.Fatalf("failed to generate second challenge: %v", err) + } + + same := true + for i := range challenge { + if challenge[i] != challenge2[i] { + same = false + break + } + } + if same { + t.Error("two generated challenges should be different") + } + }) + + t.Run("SignAndVerifyChallenge", func(t *testing.T) { + challenge, _ := GenerateChallenge() + sharedSecret := []byte("test-secret-key-32-bytes-long!!") + + // Sign the challenge + signature := SignChallenge(challenge, sharedSecret) + + if len(signature) == 0 { + t.Error("signature should not be empty") + } + + // Verify should succeed with correct parameters + if !VerifyChallenge(challenge, signature, sharedSecret) { + t.Error("verification should succeed with correct parameters") + } + + // Verify should fail with wrong challenge + wrongChallenge, _ := GenerateChallenge() + if VerifyChallenge(wrongChallenge, signature, sharedSecret) { + t.Error("verification should fail with wrong challenge") + } + + // Verify should fail with wrong secret + wrongSecret := []byte("wrong-secret-key-32-bytes-long!") + if VerifyChallenge(challenge, signature, wrongSecret) { + t.Error("verification should fail with wrong secret") + } + + // Verify should fail with tampered signature + tamperedSig := make([]byte, len(signature)) + 
copy(tamperedSig, signature) + tamperedSig[0] ^= 0xFF // Flip bits + if VerifyChallenge(challenge, tamperedSig, sharedSecret) { + t.Error("verification should fail with tampered signature") + } + }) + + t.Run("SignatureIsDeterministic", func(t *testing.T) { + challenge := []byte("fixed-challenge-for-testing") + sharedSecret := []byte("fixed-secret-key-for-testing") + + sig1 := SignChallenge(challenge, sharedSecret) + sig2 := SignChallenge(challenge, sharedSecret) + + if len(sig1) != len(sig2) { + t.Fatal("signatures should have same length") + } + + for i := range sig1 { + if sig1[i] != sig2[i] { + t.Fatal("signatures should be identical for same inputs") + } + } + }) + + t.Run("IntegrationWithSharedSecret", func(t *testing.T) { + // Create two nodes and test end-to-end challenge-response + tmpDir1, _ := os.MkdirTemp("", "node-challenge-1") + tmpDir2, _ := os.MkdirTemp("", "node-challenge-2") + defer os.RemoveAll(tmpDir1) + defer os.RemoveAll(tmpDir2) + + nm1, _ := NewNodeManagerWithPaths( + filepath.Join(tmpDir1, "private.key"), + filepath.Join(tmpDir1, "node.json"), + ) + nm1.GenerateIdentity("challenger", RoleDual) + + nm2, _ := NewNodeManagerWithPaths( + filepath.Join(tmpDir2, "private.key"), + filepath.Join(tmpDir2, "node.json"), + ) + nm2.GenerateIdentity("responder", RoleDual) + + // Challenger generates challenge + challenge, err := GenerateChallenge() + if err != nil { + t.Fatalf("failed to generate challenge: %v", err) + } + + // Both derive the same shared secret + secret1, _ := nm1.DeriveSharedSecret(nm2.GetIdentity().PublicKey) + secret2, _ := nm2.DeriveSharedSecret(nm1.GetIdentity().PublicKey) + + // Responder signs challenge with their derived secret + response := SignChallenge(challenge, secret2) + + // Challenger verifies with their derived secret + if !VerifyChallenge(challenge, response, secret1) { + t.Error("challenge-response should verify with matching shared secrets") + } + }) +} diff --git a/node/message.go b/node/message.go new file mode 
100644 index 0000000..6f343df --- /dev/null +++ b/node/message.go @@ -0,0 +1,237 @@ +package node + +import ( + "encoding/json" + "time" + + "github.com/google/uuid" +) + +// Protocol version constants +const ( + // ProtocolVersion is the current protocol version + ProtocolVersion = "1.0" + // MinProtocolVersion is the minimum supported version + MinProtocolVersion = "1.0" +) + +// SupportedProtocolVersions lists all protocol versions this node supports. +// Used for version negotiation during handshake. +var SupportedProtocolVersions = []string{"1.0"} + +// IsProtocolVersionSupported checks if a given version is supported. +func IsProtocolVersionSupported(version string) bool { + for _, v := range SupportedProtocolVersions { + if v == version { + return true + } + } + return false +} + +// MessageType defines the type of P2P message. +type MessageType string + +const ( + // Connection lifecycle + MsgHandshake MessageType = "handshake" + MsgHandshakeAck MessageType = "handshake_ack" + MsgPing MessageType = "ping" + MsgPong MessageType = "pong" + MsgDisconnect MessageType = "disconnect" + + // Miner operations + MsgGetStats MessageType = "get_stats" + MsgStats MessageType = "stats" + MsgStartMiner MessageType = "start_miner" + MsgStopMiner MessageType = "stop_miner" + MsgMinerAck MessageType = "miner_ack" + + // Deployment + MsgDeploy MessageType = "deploy" + MsgDeployAck MessageType = "deploy_ack" + + // Logs + MsgGetLogs MessageType = "get_logs" + MsgLogs MessageType = "logs" + + // Error response + MsgError MessageType = "error" +) + +// Message represents a P2P message between nodes. 
+type Message struct { + ID string `json:"id"` // UUID + Type MessageType `json:"type"` + From string `json:"from"` // Sender node ID + To string `json:"to"` // Recipient node ID (empty for broadcast) + Timestamp time.Time `json:"ts"` + Payload json.RawMessage `json:"payload"` + ReplyTo string `json:"replyTo,omitempty"` // ID of message being replied to +} + +// NewMessage creates a new message with a generated ID and timestamp. +func NewMessage(msgType MessageType, from, to string, payload interface{}) (*Message, error) { + var payloadBytes json.RawMessage + if payload != nil { + data, err := MarshalJSON(payload) + if err != nil { + return nil, err + } + payloadBytes = data + } + + return &Message{ + ID: uuid.New().String(), + Type: msgType, + From: from, + To: to, + Timestamp: time.Now(), + Payload: payloadBytes, + }, nil +} + +// Reply creates a reply message to this message. +func (m *Message) Reply(msgType MessageType, payload interface{}) (*Message, error) { + reply, err := NewMessage(msgType, m.To, m.From, payload) + if err != nil { + return nil, err + } + reply.ReplyTo = m.ID + return reply, nil +} + +// ParsePayload unmarshals the payload into the given struct. +func (m *Message) ParsePayload(v interface{}) error { + if m.Payload == nil { + return nil + } + return json.Unmarshal(m.Payload, v) +} + +// --- Payload Types --- + +// HandshakePayload is sent during connection establishment. +type HandshakePayload struct { + Identity NodeIdentity `json:"identity"` + Challenge []byte `json:"challenge,omitempty"` // Random bytes for auth + Version string `json:"version"` // Protocol version +} + +// HandshakeAckPayload is the response to a handshake. +type HandshakeAckPayload struct { + Identity NodeIdentity `json:"identity"` + ChallengeResponse []byte `json:"challengeResponse,omitempty"` + Accepted bool `json:"accepted"` + Reason string `json:"reason,omitempty"` // If not accepted +} + +// PingPayload for keepalive/latency measurement. 
type PingPayload struct {
	SentAt int64 `json:"sentAt"` // Unix timestamp in milliseconds
}

// PongPayload response to ping; echoing SentAt lets the pinger compute
// round-trip latency.
type PongPayload struct {
	SentAt     int64 `json:"sentAt"`     // Echo of ping's sentAt
	ReceivedAt int64 `json:"receivedAt"` // When ping was received
}

// StartMinerPayload requests starting a miner.
type StartMinerPayload struct {
	MinerType string          `json:"minerType"` // Required: miner type (e.g., "xmrig", "tt-miner")
	ProfileID string          `json:"profileId,omitempty"`
	Config    json.RawMessage `json:"config,omitempty"` // Override profile config
}

// StopMinerPayload requests stopping a miner by name.
type StopMinerPayload struct {
	MinerName string `json:"minerName"`
}

// MinerAckPayload acknowledges a miner start/stop operation.
type MinerAckPayload struct {
	Success   bool   `json:"success"`
	MinerName string `json:"minerName,omitempty"`
	Error     string `json:"error,omitempty"`
}

// MinerStatsItem represents stats for a single miner.
type MinerStatsItem struct {
	Name       string  `json:"name"`
	Type       string  `json:"type"`
	Hashrate   float64 `json:"hashrate"`
	Shares     int     `json:"shares"`
	Rejected   int     `json:"rejected"`
	Uptime     int     `json:"uptime"` // Seconds
	Pool       string  `json:"pool"`
	Algorithm  string  `json:"algorithm"`
	CPUThreads int     `json:"cpuThreads,omitempty"`
}

// StatsPayload contains miner statistics for a whole node.
type StatsPayload struct {
	NodeID   string           `json:"nodeId"`
	NodeName string           `json:"nodeName"`
	Miners   []MinerStatsItem `json:"miners"`
	Uptime   int64            `json:"uptime"` // Node uptime in seconds
}

// GetLogsPayload requests console logs from a miner.
type GetLogsPayload struct {
	MinerName string `json:"minerName"`
	Lines     int    `json:"lines"`           // Number of lines to fetch
	Since     int64  `json:"since,omitempty"` // Unix timestamp, logs after this time
}

// LogsPayload contains console log lines.
type LogsPayload struct {
	MinerName string   `json:"minerName"`
	Lines     []string `json:"lines"`
	HasMore   bool     `json:"hasMore"` // More logs available
}

// DeployPayload contains a deployment bundle.
type DeployPayload struct {
	BundleType string `json:"type"`     // "profile" | "miner" | "full"
	Data       []byte `json:"data"`     // STIM-encrypted bundle
	Checksum   string `json:"checksum"` // SHA-256 of Data
	Name       string `json:"name"`     // Profile or miner name
}

// DeployAckPayload acknowledges a deployment.
type DeployAckPayload struct {
	Success bool   `json:"success"`
	Name    string `json:"name,omitempty"`
	Error   string `json:"error,omitempty"`
}

// ErrorPayload contains error information.
type ErrorPayload struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Details string `json:"details,omitempty"`
}

// Common error codes (all in the 1000-1999 range).
const (
	ErrCodeUnknown         = 1000
	ErrCodeInvalidMessage  = 1001
	ErrCodeUnauthorized    = 1002
	ErrCodeNotFound        = 1003
	ErrCodeOperationFailed = 1004
	ErrCodeTimeout         = 1005
)

// NewErrorMessage creates an error response message.
func NewErrorMessage(from, to string, code int, message string, replyTo string) (*Message, error) {
	// Only Code and Message are populated; ErrorPayload.Details is left empty
	// here (callers needing it can construct the payload themselves).
	msg, err := NewMessage(MsgError, from, to, ErrorPayload{
		Code:    code,
		Message: message,
	})
	if err != nil {
		return nil, err
	}
	msg.ReplyTo = replyTo
	return msg, nil
}

// --- node/message_test.go ---

package node

import (
	"encoding/json"
	"testing"
	"time"
)

// TestNewMessage covers ID/timestamp generation and payload round-tripping.
func TestNewMessage(t *testing.T) {
	t.Run("BasicMessage", func(t *testing.T) {
		msg, err := NewMessage(MsgPing, "sender-id", "receiver-id", nil)
		if err != nil {
			t.Fatalf("failed to create message: %v", err)
		}

		if msg.Type != MsgPing {
			t.Errorf("expected type MsgPing, got %s", msg.Type)
		}

		if msg.From != "sender-id" {
			t.Errorf("expected from 'sender-id', got '%s'", msg.From)
		}

		if msg.To != "receiver-id" {
			t.Errorf("expected to 'receiver-id', got '%s'", msg.To)
		}

		if msg.ID == "" {
			t.Error("message ID should not be empty")
		}

		if msg.Timestamp.IsZero() {
			t.Error("timestamp should be set")
		}
	})

	t.Run("MessageWithPayload", func(t *testing.T) {
		payload := PingPayload{
			SentAt: time.Now().UnixMilli(),
		}

		msg, err := NewMessage(MsgPing, "sender", "receiver", payload)
		if err != nil {
			t.Fatalf("failed to create message: %v", err)
		}

		if msg.Payload == nil {
			t.Error("payload should not be nil")
		}

		var parsed PingPayload
		err = msg.ParsePayload(&parsed)
		if err != nil {
			t.Fatalf("failed to parse payload: %v", err)
		}

		if parsed.SentAt != payload.SentAt {
			t.Errorf("expected SentAt %d, got %d", payload.SentAt, parsed.SentAt)
		}
	})
}

// TestMessageReply verifies address swapping and ReplyTo correlation.
func TestMessageReply(t *testing.T) {
	original, _ := NewMessage(MsgPing, "sender", "receiver", PingPayload{SentAt: 12345})

	reply, err := original.Reply(MsgPong, PongPayload{
		SentAt:     12345,
		ReceivedAt: 12350,
	})

	if err != nil {
		t.Fatalf("failed to create reply: %v", err)
	}

	if reply.ReplyTo != original.ID {
		t.Errorf("reply should reference original message ID")
	}

	if reply.From != original.To {
		t.Error("reply From should be original To")
	}

	if reply.To != original.From {
		t.Error("reply To should be original From")
	}

	if reply.Type != MsgPong {
		t.Errorf("expected type MsgPong, got %s", reply.Type)
	}
}

// TestParsePayload covers simple, nil, and nested payload decoding.
func TestParsePayload(t *testing.T) {
	t.Run("ValidPayload", func(t *testing.T) {
		payload := StartMinerPayload{
			MinerType: "xmrig",
			ProfileID: "test-profile",
		}

		msg, _ := NewMessage(MsgStartMiner, "ctrl", "worker", payload)

		var parsed StartMinerPayload
		err := msg.ParsePayload(&parsed)
		if err != nil {
			t.Fatalf("failed to parse payload: %v", err)
		}

		if parsed.ProfileID != "test-profile" {
			t.Errorf("expected ProfileID 'test-profile', got '%s'", parsed.ProfileID)
		}
	})

	t.Run("NilPayload", func(t *testing.T) {
		msg, _ := NewMessage(MsgGetStats, "ctrl", "worker", nil)

		var parsed StatsPayload
		err := msg.ParsePayload(&parsed)
		if err != nil {
			t.Errorf("parsing nil payload should not error: %v", err)
		}
	})

	t.Run("ComplexPayload", func(t *testing.T) {
		stats := StatsPayload{
			NodeID:   "node-123",
			NodeName: "Test Node",
			Miners: []MinerStatsItem{
				{
					Name:      "xmrig-1",
					Type:      "xmrig",
					Hashrate:  1234.56,
					Shares:    100,
					Rejected:  2,
					Uptime:    3600,
					Pool:      "pool.example.com:3333",
					Algorithm: "RandomX",
				},
			},
			Uptime: 86400,
		}

		msg, _ := NewMessage(MsgStats, "worker", "ctrl", stats)

		var parsed StatsPayload
		err := msg.ParsePayload(&parsed)
		if err != nil {
			t.Fatalf("failed to parse stats payload: %v", err)
		}

		if parsed.NodeID != "node-123" {
			t.Errorf("expected NodeID 'node-123', got '%s'", parsed.NodeID)
		}

		if len(parsed.Miners) != 1 {
			t.Fatalf("expected 1 miner, got %d", len(parsed.Miners))
		}

		if parsed.Miners[0].Hashrate != 1234.56 {
			t.Errorf("expected hashrate 1234.56, got %f", parsed.Miners[0].Hashrate)
		}
	})
}

// TestNewErrorMessage verifies the error helper sets type, ReplyTo, and payload.
func TestNewErrorMessage(t *testing.T) {
	errMsg, err := NewErrorMessage("sender", "receiver", ErrCodeOperationFailed, "something went wrong", "original-msg-id")
	if err != nil {
		t.Fatalf("failed to create error message: %v", err)
	}

	if errMsg.Type != MsgError {
		t.Errorf("expected type MsgError, got %s", errMsg.Type)
	}

	if errMsg.ReplyTo != "original-msg-id" {
		t.Errorf("expected ReplyTo 'original-msg-id', got '%s'", errMsg.ReplyTo)
	}

	var errPayload ErrorPayload
	err = errMsg.ParsePayload(&errPayload)
	if err != nil {
		t.Fatalf("failed to parse error payload: %v", err)
	}

	if errPayload.Code != ErrCodeOperationFailed {
		t.Errorf("expected code %d, got %d", ErrCodeOperationFailed, errPayload.Code)
	}

	if errPayload.Message != "something went wrong" {
		t.Errorf("expected message 'something went wrong', got '%s'", errPayload.Message)
	}
}

// TestMessageSerialization round-trips a full Message through encoding/json.
func TestMessageSerialization(t *testing.T) {
	original, _ := NewMessage(MsgStartMiner, "ctrl", "worker", StartMinerPayload{
		MinerType: "xmrig",
		ProfileID: "my-profile",
	})

	// Serialize
	data, err := json.Marshal(original)
	if err != nil {
		t.Fatalf("failed to serialize message: %v", err)
	}

	// Deserialize
	var restored Message
	err = json.Unmarshal(data, &restored)
	if err != nil {
		t.Fatalf("failed to deserialize message: %v", err)
	}

	if restored.ID != original.ID {
		t.Error("ID mismatch after serialization")
	}

	if restored.Type != original.Type {
		t.Error("Type mismatch after serialization")
	}

	if restored.From != original.From {
		t.Error("From mismatch after serialization")
	}

	var payload StartMinerPayload
	err = restored.ParsePayload(&payload)
	if err != nil {
		t.Fatalf("failed to parse restored payload: %v", err)
	}

	if payload.ProfileID != "my-profile" {
		t.Errorf("expected ProfileID 'my-profile', got '%s'", payload.ProfileID)
	}
}

// TestMessageTypes smoke-tests message construction for every declared type.
func TestMessageTypes(t *testing.T) {
	types := []MessageType{
		MsgHandshake,
		MsgHandshakeAck,
		MsgPing,
		MsgPong,
		MsgDisconnect,
		MsgGetStats,
		MsgStats,
		MsgStartMiner,
		MsgStopMiner,
		MsgMinerAck,
		MsgDeploy,
		MsgDeployAck,
		MsgGetLogs,
		MsgLogs,
		MsgError,
	}

	for _, msgType := range types {
		t.Run(string(msgType), func(t *testing.T) {
			msg, err := NewMessage(msgType, "from", "to", nil)
			if err != nil {
				t.Fatalf("failed to create message of type %s: %v", msgType, err)
			}

			if msg.Type != msgType {
				t.Errorf("expected type %s, got %s", msgType, msg.Type)
			}
		})
	}
}

// TestErrorCodes asserts every error code stays in the documented range.
func TestErrorCodes(t *testing.T) {
	codes := map[int]string{
		ErrCodeUnknown:         "Unknown",
		ErrCodeInvalidMessage:  "InvalidMessage",
		ErrCodeUnauthorized:    "Unauthorized",
		ErrCodeNotFound:        "NotFound",
		ErrCodeOperationFailed: "OperationFailed",
		ErrCodeTimeout:         "Timeout",
	}

	for code, name := range codes {
		t.Run(name, func(t *testing.T) {
			if code < 1000 || code > 1999 {
				t.Errorf("error code %d should be in 1000-1999 range", code)
			}
		})
	}
}

// --- node/peer.go ---

package node

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"regexp"
	"sync"
	"time"

	"forge.lthn.ai/core/go-p2p/logging"
	"github.com/Snider/Poindexter"
	"github.com/adrg/xdg"
)

// Peer represents a known remote node.
type Peer struct {
	ID        string    `json:"id"`
	Name      string    `json:"name"`
	PublicKey string    `json:"publicKey"`
	Address   string    `json:"address"` // host:port for WebSocket connection
	Role      NodeRole  `json:"role"`
	AddedAt   time.Time `json:"addedAt"`
	LastSeen  time.Time `json:"lastSeen"`

	// Poindexter metrics (updated dynamically)
	PingMS float64 `json:"pingMs"` // Latency in milliseconds
	Hops   int     `json:"hops"`   // Network hop count
	GeoKM  float64 `json:"geoKm"`  // Geographic distance in kilometers
	Score  float64 `json:"score"`  // Reliability score 0-100

	// Connection state (not persisted)
	Connected bool `json:"-"`
}

// saveDebounceInterval is the minimum time between disk writes.
const saveDebounceInterval = 5 * time.Second

// PeerAuthMode controls how unknown peers are handled.
type PeerAuthMode int

const (
	// PeerAuthOpen allows any peer to connect (original behavior)
	PeerAuthOpen PeerAuthMode = iota
	// PeerAuthAllowlist only allows pre-registered peers or those with allowed public keys
	PeerAuthAllowlist
)

// Peer name validation constants
const (
	PeerNameMinLength = 1
	PeerNameMaxLength = 64
)

// peerNameRegex validates peer names: alphanumeric, hyphens, underscores, and spaces.
// First alternative: names of 2+ characters must start and end with an
// alphanumeric; second alternative admits single-character names.
var peerNameRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9\-_ ]{0,62}[a-zA-Z0-9]$|^[a-zA-Z0-9]$`)

// safeKeyPrefix returns a truncated key for logging, handling short keys safely.
func safeKeyPrefix(key string) string {
	// Long keys are truncated to their first 16 characters.
	if len(key) >= 16 {
		return key[:16] + "..."
	}
	if len(key) == 0 {
		return "(empty)"
	}
	// Keys of 1-15 characters are returned unchanged.
	return key
}

// validatePeerName checks if a peer name is valid.
// Peer names must be 1-64 characters, start and end with alphanumeric,
// and contain only alphanumeric, hyphens, underscores, and spaces.
func validatePeerName(name string) error {
	if name == "" {
		return nil // Empty names are allowed (optional field)
	}
	// NOTE(review): unreachable while PeerNameMinLength == 1 — any non-empty
	// string already has len >= 1. Kept as a guard should the minimum grow.
	if len(name) < PeerNameMinLength {
		return fmt.Errorf("peer name too short (min %d characters)", PeerNameMinLength)
	}
	// len counts bytes, not runes; acceptable here because peerNameRegex
	// restricts names to ASCII characters anyway.
	if len(name) > PeerNameMaxLength {
		return fmt.Errorf("peer name too long (max %d characters)", PeerNameMaxLength)
	}
	if !peerNameRegex.MatchString(name) {
		return fmt.Errorf("peer name contains invalid characters (use alphanumeric, hyphens, underscores, spaces)")
	}
	return nil
}

// PeerRegistry manages known peers with KD-tree based selection.
type PeerRegistry struct {
	peers  map[string]*Peer
	kdTree *poindexter.KDTree[string] // KD-tree with peer ID as payload
	path   string                     // On-disk location of the peers JSON file
	mu     sync.RWMutex               // Guards peers and kdTree

	// Authentication settings
	authMode           PeerAuthMode    // How to handle unknown peers
	allowedPublicKeys  map[string]bool // Allowlist of public keys (when authMode is Allowlist)
	allowedPublicKeyMu sync.RWMutex    // Protects authMode and allowedPublicKeys

	// Debounce disk writes
	dirty        bool          // Whether there are unsaved changes
	saveTimer    *time.Timer   // Timer for debounced save
	saveMu       sync.Mutex    // Protects dirty and saveTimer
	stopChan     chan struct{} // Signal to stop background save
	saveStopOnce sync.Once     // Ensure stopChan is closed only once
}

// Dimension weights for peer selection.
// Lower ping, hops, geo are better; higher score is better.
var (
	pingWeight  = 1.0
	hopsWeight  = 0.7
	geoWeight   = 0.2
	scoreWeight = 1.2
)

// NewPeerRegistry creates a new PeerRegistry, loading existing peers if available.
func NewPeerRegistry() (*PeerRegistry, error) {
	peersPath, err := xdg.ConfigFile("lethean-desktop/peers.json")
	if err != nil {
		return nil, fmt.Errorf("failed to get peers path: %w", err)
	}

	return NewPeerRegistryWithPath(peersPath)
}

// NewPeerRegistryWithPath creates a new PeerRegistry with a custom path.
// This is primarily useful for testing to avoid xdg path caching issues.
+func NewPeerRegistryWithPath(peersPath string) (*PeerRegistry, error) { + pr := &PeerRegistry{ + peers: make(map[string]*Peer), + path: peersPath, + stopChan: make(chan struct{}), + authMode: PeerAuthOpen, // Default to open for backward compatibility + allowedPublicKeys: make(map[string]bool), + } + + // Try to load existing peers + if err := pr.load(); err != nil { + // No existing peers, that's ok + pr.rebuildKDTree() + return pr, nil + } + + pr.rebuildKDTree() + return pr, nil +} + +// SetAuthMode sets the authentication mode for peer connections. +func (r *PeerRegistry) SetAuthMode(mode PeerAuthMode) { + r.allowedPublicKeyMu.Lock() + defer r.allowedPublicKeyMu.Unlock() + r.authMode = mode + logging.Info("peer auth mode changed", logging.Fields{"mode": mode}) +} + +// GetAuthMode returns the current authentication mode. +func (r *PeerRegistry) GetAuthMode() PeerAuthMode { + r.allowedPublicKeyMu.RLock() + defer r.allowedPublicKeyMu.RUnlock() + return r.authMode +} + +// AllowPublicKey adds a public key to the allowlist. +func (r *PeerRegistry) AllowPublicKey(publicKey string) { + r.allowedPublicKeyMu.Lock() + defer r.allowedPublicKeyMu.Unlock() + r.allowedPublicKeys[publicKey] = true + logging.Debug("public key added to allowlist", logging.Fields{"key": safeKeyPrefix(publicKey)}) +} + +// RevokePublicKey removes a public key from the allowlist. +func (r *PeerRegistry) RevokePublicKey(publicKey string) { + r.allowedPublicKeyMu.Lock() + defer r.allowedPublicKeyMu.Unlock() + delete(r.allowedPublicKeys, publicKey) + logging.Debug("public key removed from allowlist", logging.Fields{"key": safeKeyPrefix(publicKey)}) +} + +// IsPublicKeyAllowed checks if a public key is in the allowlist. +func (r *PeerRegistry) IsPublicKeyAllowed(publicKey string) bool { + r.allowedPublicKeyMu.RLock() + defer r.allowedPublicKeyMu.RUnlock() + return r.allowedPublicKeys[publicKey] +} + +// IsPeerAllowed checks if a peer is allowed to connect based on auth mode. 
// Returns true if:
//   - AuthMode is Open (allow all)
//   - AuthMode is Allowlist AND (peer is pre-registered OR public key is allowlisted)
func (r *PeerRegistry) IsPeerAllowed(peerID string, publicKey string) bool {
	// Snapshot auth state under its own lock so the two mutexes are never
	// held at the same time.
	r.allowedPublicKeyMu.RLock()
	authMode := r.authMode
	keyAllowed := r.allowedPublicKeys[publicKey]
	r.allowedPublicKeyMu.RUnlock()

	// Open mode allows everyone
	if authMode == PeerAuthOpen {
		return true
	}

	// Allowlist mode: check if peer is pre-registered
	r.mu.RLock()
	_, isRegistered := r.peers[peerID]
	r.mu.RUnlock()

	if isRegistered {
		return true
	}

	// Check if public key is allowlisted
	return keyAllowed
}

// ListAllowedPublicKeys returns all allowlisted public keys.
// Order is unspecified (map iteration).
func (r *PeerRegistry) ListAllowedPublicKeys() []string {
	r.allowedPublicKeyMu.RLock()
	defer r.allowedPublicKeyMu.RUnlock()

	keys := make([]string, 0, len(r.allowedPublicKeys))
	for key := range r.allowedPublicKeys {
		keys = append(keys, key)
	}
	return keys
}

// AddPeer adds a new peer to the registry.
// Note: Persistence is debounced (writes batched every 5s). Call Close() to ensure
// all changes are flushed to disk before shutdown.
func (r *PeerRegistry) AddPeer(peer *Peer) error {
	r.mu.Lock()

	if peer.ID == "" {
		r.mu.Unlock()
		return fmt.Errorf("peer ID is required")
	}

	// Validate peer name (P2P-LOW-3)
	if err := validatePeerName(peer.Name); err != nil {
		r.mu.Unlock()
		return err
	}

	if _, exists := r.peers[peer.ID]; exists {
		r.mu.Unlock()
		return fmt.Errorf("peer %s already exists", peer.ID)
	}

	// Set defaults
	if peer.AddedAt.IsZero() {
		peer.AddedAt = time.Now()
	}
	if peer.Score == 0 {
		peer.Score = 50 // Default neutral score
	}

	// The caller's pointer is stored directly; callers should not mutate the
	// Peer after handing it over (accessors hand out copies).
	r.peers[peer.ID] = peer
	r.rebuildKDTree()
	r.mu.Unlock()

	// save() only schedules a debounced write and must not run under r.mu.
	return r.save()
}

// UpdatePeer updates an existing peer's information.
// Note: Persistence is debounced. Call Close() to flush before shutdown.
+func (r *PeerRegistry) UpdatePeer(peer *Peer) error { + r.mu.Lock() + + if _, exists := r.peers[peer.ID]; !exists { + r.mu.Unlock() + return fmt.Errorf("peer %s not found", peer.ID) + } + + r.peers[peer.ID] = peer + r.rebuildKDTree() + r.mu.Unlock() + + return r.save() +} + +// RemovePeer removes a peer from the registry. +// Note: Persistence is debounced. Call Close() to flush before shutdown. +func (r *PeerRegistry) RemovePeer(id string) error { + r.mu.Lock() + + if _, exists := r.peers[id]; !exists { + r.mu.Unlock() + return fmt.Errorf("peer %s not found", id) + } + + delete(r.peers, id) + r.rebuildKDTree() + r.mu.Unlock() + + return r.save() +} + +// GetPeer returns a peer by ID. +func (r *PeerRegistry) GetPeer(id string) *Peer { + r.mu.RLock() + defer r.mu.RUnlock() + + peer, exists := r.peers[id] + if !exists { + return nil + } + + // Return a copy + peerCopy := *peer + return &peerCopy +} + +// ListPeers returns all registered peers. +func (r *PeerRegistry) ListPeers() []*Peer { + r.mu.RLock() + defer r.mu.RUnlock() + + peers := make([]*Peer, 0, len(r.peers)) + for _, peer := range r.peers { + peerCopy := *peer + peers = append(peers, &peerCopy) + } + return peers +} + +// UpdateMetrics updates a peer's performance metrics. +// Note: Persistence is debounced. Call Close() to flush before shutdown. +func (r *PeerRegistry) UpdateMetrics(id string, pingMS, geoKM float64, hops int) error { + r.mu.Lock() + + peer, exists := r.peers[id] + if !exists { + r.mu.Unlock() + return fmt.Errorf("peer %s not found", id) + } + + peer.PingMS = pingMS + peer.GeoKM = geoKM + peer.Hops = hops + peer.LastSeen = time.Now() + + r.rebuildKDTree() + r.mu.Unlock() + + return r.save() +} + +// UpdateScore updates a peer's reliability score. +// Note: Persistence is debounced. Call Close() to flush before shutdown. 
+func (r *PeerRegistry) UpdateScore(id string, score float64) error { + r.mu.Lock() + + peer, exists := r.peers[id] + if !exists { + r.mu.Unlock() + return fmt.Errorf("peer %s not found", id) + } + + // Clamp score to 0-100 + if score < 0 { + score = 0 + } else if score > 100 { + score = 100 + } + + peer.Score = score + r.rebuildKDTree() + r.mu.Unlock() + + return r.save() +} + +// SetConnected updates a peer's connection state. +func (r *PeerRegistry) SetConnected(id string, connected bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if peer, exists := r.peers[id]; exists { + peer.Connected = connected + if connected { + peer.LastSeen = time.Now() + } + } +} + +// Score adjustment constants +const ( + ScoreSuccessIncrement = 1.0 // Increment for successful interaction + ScoreFailureDecrement = 5.0 // Decrement for failed interaction + ScoreTimeoutDecrement = 3.0 // Decrement for timeout + ScoreMinimum = 0.0 // Minimum score + ScoreMaximum = 100.0 // Maximum score + ScoreDefault = 50.0 // Default score for new peers +) + +// RecordSuccess records a successful interaction with a peer, improving their score. +func (r *PeerRegistry) RecordSuccess(id string) { + r.mu.Lock() + peer, exists := r.peers[id] + if !exists { + r.mu.Unlock() + return + } + + peer.Score = min(peer.Score+ScoreSuccessIncrement, ScoreMaximum) + peer.LastSeen = time.Now() + r.mu.Unlock() + r.save() +} + +// RecordFailure records a failed interaction with a peer, reducing their score. +func (r *PeerRegistry) RecordFailure(id string) { + r.mu.Lock() + peer, exists := r.peers[id] + if !exists { + r.mu.Unlock() + return + } + + peer.Score = max(peer.Score-ScoreFailureDecrement, ScoreMinimum) + newScore := peer.Score + r.mu.Unlock() + r.save() + + logging.Debug("peer score decreased", logging.Fields{ + "peer_id": id, + "new_score": newScore, + "reason": "failure", + }) +} + +// RecordTimeout records a timeout when communicating with a peer. 
+func (r *PeerRegistry) RecordTimeout(id string) { + r.mu.Lock() + peer, exists := r.peers[id] + if !exists { + r.mu.Unlock() + return + } + + peer.Score = max(peer.Score-ScoreTimeoutDecrement, ScoreMinimum) + newScore := peer.Score + r.mu.Unlock() + r.save() + + logging.Debug("peer score decreased", logging.Fields{ + "peer_id": id, + "new_score": newScore, + "reason": "timeout", + }) +} + +// GetPeersByScore returns peers sorted by score (highest first). +func (r *PeerRegistry) GetPeersByScore() []*Peer { + r.mu.RLock() + defer r.mu.RUnlock() + + peers := make([]*Peer, 0, len(r.peers)) + for _, p := range r.peers { + peers = append(peers, p) + } + + // Sort by score descending + for i := 0; i < len(peers)-1; i++ { + for j := i + 1; j < len(peers); j++ { + if peers[j].Score > peers[i].Score { + peers[i], peers[j] = peers[j], peers[i] + } + } + } + + return peers +} + +// SelectOptimalPeer returns the best peer based on multi-factor optimization. +// Uses Poindexter KD-tree to find the peer closest to ideal metrics. +func (r *PeerRegistry) SelectOptimalPeer() *Peer { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.kdTree == nil || len(r.peers) == 0 { + return nil + } + + // Target: ideal peer (0 ping, 0 hops, 0 geo, 100 score) + // Score is inverted (100 - score) so lower is better in the tree + target := []float64{0, 0, 0, 0} + + result, _, found := r.kdTree.Nearest(target) + if !found { + return nil + } + + peer, exists := r.peers[result.Value] + if !exists { + return nil + } + + peerCopy := *peer + return &peerCopy +} + +// SelectNearestPeers returns the n best peers based on multi-factor optimization. 
+func (r *PeerRegistry) SelectNearestPeers(n int) []*Peer { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.kdTree == nil || len(r.peers) == 0 { + return nil + } + + // Target: ideal peer + target := []float64{0, 0, 0, 0} + + results, _ := r.kdTree.KNearest(target, n) + + peers := make([]*Peer, 0, len(results)) + for _, result := range results { + if peer, exists := r.peers[result.Value]; exists { + peerCopy := *peer + peers = append(peers, &peerCopy) + } + } + + return peers +} + +// GetConnectedPeers returns all currently connected peers. +func (r *PeerRegistry) GetConnectedPeers() []*Peer { + r.mu.RLock() + defer r.mu.RUnlock() + + peers := make([]*Peer, 0) + for _, peer := range r.peers { + if peer.Connected { + peerCopy := *peer + peers = append(peers, &peerCopy) + } + } + return peers +} + +// Count returns the number of registered peers. +func (r *PeerRegistry) Count() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.peers) +} + +// rebuildKDTree rebuilds the KD-tree from current peers. +// Must be called with lock held. +func (r *PeerRegistry) rebuildKDTree() { + if len(r.peers) == 0 { + r.kdTree = nil + return + } + + points := make([]poindexter.KDPoint[string], 0, len(r.peers)) + for _, peer := range r.peers { + // Build 4D point with weighted, normalized values + // Invert score so that higher score = lower value (better) + point := poindexter.KDPoint[string]{ + ID: peer.ID, + Coords: []float64{ + peer.PingMS * pingWeight, + float64(peer.Hops) * hopsWeight, + peer.GeoKM * geoWeight, + (100 - peer.Score) * scoreWeight, // Invert score + }, + Value: peer.ID, + } + points = append(points, point) + } + + // Build KD-tree with Euclidean distance + tree, err := poindexter.NewKDTree(points, poindexter.WithMetric(poindexter.EuclideanDistance{})) + if err != nil { + // Log error but continue - worst case we don't have optimal selection + return + } + + r.kdTree = tree +} + +// scheduleSave schedules a debounced save operation. 
// Multiple calls within saveDebounceInterval will be coalesced into a single save.
// Must NOT be called with r.mu held.
func (r *PeerRegistry) scheduleSave() {
	r.saveMu.Lock()
	defer r.saveMu.Unlock()

	r.dirty = true

	// If timer already running, let it handle the save
	if r.saveTimer != nil {
		return
	}

	// Start a new timer
	r.saveTimer = time.AfterFunc(saveDebounceInterval, func() {
		// Clear timer state before saving so a concurrent scheduleSave can
		// arm a fresh timer for changes made after this point.
		r.saveMu.Lock()
		r.saveTimer = nil
		shouldSave := r.dirty
		r.dirty = false
		r.saveMu.Unlock()

		if shouldSave {
			r.mu.RLock()
			err := r.saveNow()
			r.mu.RUnlock()
			if err != nil {
				// Log error but continue - best effort persistence
				logging.Warn("failed to save peer registry", logging.Fields{"error": err})
			}
		}
	})
}

// saveNow persists peers to disk immediately.
// Must be called with r.mu held (at least RLock).
func (r *PeerRegistry) saveNow() error {
	// Ensure directory exists
	dir := filepath.Dir(r.path)
	if err := os.MkdirAll(dir, 0755); err != nil {
		return fmt.Errorf("failed to create peers directory: %w", err)
	}

	// Convert to slice for JSON
	peers := make([]*Peer, 0, len(r.peers))
	for _, peer := range r.peers {
		peers = append(peers, peer)
	}

	data, err := json.MarshalIndent(peers, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal peers: %w", err)
	}

	// Use atomic write pattern: write to temp file, then rename
	tmpPath := r.path + ".tmp"
	if err := os.WriteFile(tmpPath, data, 0644); err != nil {
		return fmt.Errorf("failed to write peers temp file: %w", err)
	}

	if err := os.Rename(tmpPath, r.path); err != nil {
		os.Remove(tmpPath) // Clean up temp file
		return fmt.Errorf("failed to rename peers file: %w", err)
	}

	return nil
}

// Close flushes any pending changes and releases resources.
func (r *PeerRegistry) Close() error {
	// Closing stopChan is made idempotent by saveStopOnce.
	r.saveStopOnce.Do(func() {
		close(r.stopChan)
	})

	// Cancel pending timer and save immediately if dirty
	r.saveMu.Lock()
	if r.saveTimer != nil {
		r.saveTimer.Stop()
		r.saveTimer = nil
	}
	shouldSave := r.dirty
	r.dirty = false
	r.saveMu.Unlock()

	if shouldSave {
		r.mu.RLock()
		err := r.saveNow()
		r.mu.RUnlock()
		return err
	}

	return nil
}

// save is a helper that schedules a debounced save.
// Kept for backward compatibility but now debounces writes.
// Must NOT be called with r.mu held.
func (r *PeerRegistry) save() error {
	r.scheduleSave()
	return nil // Errors will be logged asynchronously
}

// load reads peers from disk.
func (r *PeerRegistry) load() error {
	data, err := os.ReadFile(r.path)
	if err != nil {
		return fmt.Errorf("failed to read peers: %w", err)
	}

	var peers []*Peer
	if err := json.Unmarshal(data, &peers); err != nil {
		return fmt.Errorf("failed to unmarshal peers: %w", err)
	}

	// Rebuild the in-memory index keyed by peer ID.
	r.peers = make(map[string]*Peer)
	for _, peer := range peers {
		r.peers[peer.ID] = peer
	}

	return nil
}

// Example usage inside a connection handler

// --- node/peer_test.go ---

package node

import (
	"os"
	"path/filepath"
	"testing"
	"time"
)

// setupTestPeerRegistry creates a registry backed by a temp directory and
// returns it with a cleanup function that removes the directory.
func setupTestPeerRegistry(t *testing.T) (*PeerRegistry, func()) {
	tmpDir, err := os.MkdirTemp("", "peer-registry-test")
	if err != nil {
		t.Fatalf("failed to create temp dir: %v", err)
	}

	peersPath := filepath.Join(tmpDir, "peers.json")

	pr, err := NewPeerRegistryWithPath(peersPath)
	if err != nil {
		os.RemoveAll(tmpDir)
		t.Fatalf("failed to create peer registry: %v", err)
	}

	cleanup := func() {
		os.RemoveAll(tmpDir)
	}

	return pr, cleanup
}

// A fresh registry starts empty.
func TestPeerRegistry_NewPeerRegistry(t *testing.T) {
	pr, cleanup := setupTestPeerRegistry(t)
	defer cleanup()

	if pr.Count() != 0 {
		t.Errorf("expected 0 peers, got %d", pr.Count())
	}
}

// Adding a peer succeeds once; adding the same ID again fails.
func TestPeerRegistry_AddPeer(t *testing.T) {
	pr, cleanup := setupTestPeerRegistry(t)
	defer cleanup()

	peer := &Peer{
		ID:        "test-peer-1",
		Name:      "Test Peer",
		PublicKey: "testkey123",
		Address:   "192.168.1.100:9091",
		Role:      RoleWorker,
		Score:     75,
	}

	err := pr.AddPeer(peer)
	if err != nil {
		t.Fatalf("failed to add peer: %v", err)
	}

	if pr.Count() != 1 {
		t.Errorf("expected 1 peer, got %d", pr.Count())
	}

	// Try to add duplicate
	err = pr.AddPeer(peer)
	if err == nil {
		t.Error("expected error when adding duplicate peer")
	}
}

// GetPeer returns known peers and nil for unknown IDs.
func TestPeerRegistry_GetPeer(t *testing.T) {
	pr, cleanup := setupTestPeerRegistry(t)
	defer cleanup()

	peer := &Peer{
		ID:        "get-test-peer",
		Name:      "Get Test",
		PublicKey: "getkey123",
		Address:   "10.0.0.1:9091",
		Role:      RoleDual,
	}

	pr.AddPeer(peer)

	retrieved := pr.GetPeer("get-test-peer")
	if retrieved == nil {
		t.Fatal("failed to retrieve peer")
	}

	if retrieved.Name != "Get Test" {
		t.Errorf("expected name 'Get Test', got '%s'", retrieved.Name)
	}

	// Non-existent peer
	nonExistent := pr.GetPeer("non-existent")
	if nonExistent != nil {
		t.Error("expected nil for non-existent peer")
	}
}

// ListPeers returns every registered peer.
func TestPeerRegistry_ListPeers(t *testing.T) {
	pr, cleanup := setupTestPeerRegistry(t)
	defer cleanup()

	peers := []*Peer{
		{ID: "list-1", Name: "Peer 1", Address: "1.1.1.1:9091", Role: RoleWorker},
		{ID: "list-2", Name: "Peer 2", Address: "2.2.2.2:9091", Role: RoleWorker},
		{ID: "list-3", Name: "Peer 3", Address: "3.3.3.3:9091", Role: RoleController},
	}

	for _, p := range peers {
		pr.AddPeer(p)
	}

	listed := pr.ListPeers()
	if len(listed) != 3 {
		t.Errorf("expected 3 peers, got %d", len(listed))
	}
}

// RemovePeer deletes registered peers and errors on unknown IDs.
func TestPeerRegistry_RemovePeer(t *testing.T) {
	pr, cleanup := setupTestPeerRegistry(t)
	defer cleanup()

	peer := &Peer{
		ID:      "remove-test",
		Name:    "Remove Me",
		Address: "5.5.5.5:9091",
		Role:    RoleWorker,
	}

	pr.AddPeer(peer)

	if pr.Count() != 1 {
		t.Error("peer should exist before removal")
	}

	err := pr.RemovePeer("remove-test")
	if err != nil {
		t.Fatalf("failed to remove peer: %v", err)
	}

	if pr.Count() != 0 {
		t.Error("peer should be removed")
	}

	// Remove non-existent
	err = pr.RemovePeer("non-existent")
	if err == nil {
		t.Error("expected error when removing non-existent peer")
	}
}

// UpdateMetrics writes ping/geo/hops through to the stored peer.
func TestPeerRegistry_UpdateMetrics(t *testing.T) {
	pr, cleanup := setupTestPeerRegistry(t)
	defer cleanup()

	peer := &Peer{
		ID:      "metrics-test",
		Name:    "Metrics Peer",
		Address: "6.6.6.6:9091",
		Role:    RoleWorker,
	}

	pr.AddPeer(peer)

	err := pr.UpdateMetrics("metrics-test", 50.5, 100.2, 3)
	if err != nil {
		t.Fatalf("failed to update metrics: %v", err)
	}

	updated := pr.GetPeer("metrics-test")
	if updated == nil {
		t.Fatal("expected peer to exist")
	}
	if updated.PingMS != 50.5 {
		t.Errorf("expected ping 50.5, got %f", updated.PingMS)
	}
	if updated.GeoKM != 100.2 {
		t.Errorf("expected geo 100.2, got %f", updated.GeoKM)
	}
	if updated.Hops != 3 {
		t.Errorf("expected hops 3, got %d", updated.Hops)
	}
}

// UpdateScore sets the score and clamps out-of-range values.
func TestPeerRegistry_UpdateScore(t *testing.T) {
	pr, cleanup := setupTestPeerRegistry(t)
	defer cleanup()

	peer := &Peer{
		ID:    "score-test",
		Name:  "Score Peer",
		Score: 50,
	}

	pr.AddPeer(peer)

	err := pr.UpdateScore("score-test", 85.5)
	if err != nil {
		t.Fatalf("failed to update score: %v", err)
	}

	updated := pr.GetPeer("score-test")
	if updated == nil {
		t.Fatal("expected peer to exist")
	}
	if updated.Score != 85.5 {
		t.Errorf("expected score 85.5, got %f", updated.Score)
	}

	// Test clamping - over 100
	err = pr.UpdateScore("score-test", 150)
	if err != nil {
		t.Fatalf("failed to update score: %v", err)
	}

	updated = pr.GetPeer("score-test")
	if updated == nil {
		t.Fatal("expected peer to exist")
	}
	if updated.Score != 100 {
t.Errorf("expected score clamped to 100, got %f", updated.Score) + } + + // Test clamping - below 0 + err = pr.UpdateScore("score-test", -50) + if err != nil { + t.Fatalf("failed to update score: %v", err) + } + + updated = pr.GetPeer("score-test") + if updated == nil { + t.Fatal("expected peer to exist") + } + if updated.Score != 0 { + t.Errorf("expected score clamped to 0, got %f", updated.Score) + } +} + +func TestPeerRegistry_SetConnected(t *testing.T) { + pr, cleanup := setupTestPeerRegistry(t) + defer cleanup() + + peer := &Peer{ + ID: "connect-test", + Name: "Connect Peer", + Connected: false, + } + + pr.AddPeer(peer) + + pr.SetConnected("connect-test", true) + + updated := pr.GetPeer("connect-test") + if updated == nil { + t.Fatal("expected peer to exist") + } + if !updated.Connected { + t.Error("peer should be connected") + } + if updated.LastSeen.IsZero() { + t.Error("LastSeen should be set when connected") + } + + pr.SetConnected("connect-test", false) + updated = pr.GetPeer("connect-test") + if updated == nil { + t.Fatal("expected peer to exist") + } + if updated.Connected { + t.Error("peer should be disconnected") + } +} + +func TestPeerRegistry_GetConnectedPeers(t *testing.T) { + pr, cleanup := setupTestPeerRegistry(t) + defer cleanup() + + peers := []*Peer{ + {ID: "conn-1", Name: "Peer 1"}, + {ID: "conn-2", Name: "Peer 2"}, + {ID: "conn-3", Name: "Peer 3"}, + } + + for _, p := range peers { + pr.AddPeer(p) + } + + pr.SetConnected("conn-1", true) + pr.SetConnected("conn-3", true) + + connected := pr.GetConnectedPeers() + if len(connected) != 2 { + t.Errorf("expected 2 connected peers, got %d", len(connected)) + } +} + +func TestPeerRegistry_SelectOptimalPeer(t *testing.T) { + pr, cleanup := setupTestPeerRegistry(t) + defer cleanup() + + // Add peers with different metrics + peers := []*Peer{ + {ID: "opt-1", Name: "Slow Peer", PingMS: 200, Hops: 5, GeoKM: 1000, Score: 50}, + {ID: "opt-2", Name: "Fast Peer", PingMS: 10, Hops: 1, GeoKM: 50, Score: 90}, + 
{ID: "opt-3", Name: "Medium Peer", PingMS: 50, Hops: 2, GeoKM: 200, Score: 70}, + } + + for _, p := range peers { + pr.AddPeer(p) + } + + optimal := pr.SelectOptimalPeer() + if optimal == nil { + t.Fatal("expected to find an optimal peer") + } + + // The "Fast Peer" should be selected as optimal + if optimal.ID != "opt-2" { + t.Errorf("expected 'opt-2' (Fast Peer) to be optimal, got '%s' (%s)", optimal.ID, optimal.Name) + } +} + +func TestPeerRegistry_SelectNearestPeers(t *testing.T) { + pr, cleanup := setupTestPeerRegistry(t) + defer cleanup() + + peers := []*Peer{ + {ID: "near-1", Name: "Peer 1", PingMS: 100, Score: 50}, + {ID: "near-2", Name: "Peer 2", PingMS: 10, Score: 90}, + {ID: "near-3", Name: "Peer 3", PingMS: 50, Score: 70}, + {ID: "near-4", Name: "Peer 4", PingMS: 200, Score: 30}, + } + + for _, p := range peers { + pr.AddPeer(p) + } + + nearest := pr.SelectNearestPeers(2) + if len(nearest) != 2 { + t.Errorf("expected 2 nearest peers, got %d", len(nearest)) + } +} + +func TestPeerRegistry_Persistence(t *testing.T) { + tmpDir, _ := os.MkdirTemp("", "persist-test") + defer os.RemoveAll(tmpDir) + + peersPath := filepath.Join(tmpDir, "peers.json") + + // Create and save + pr1, err := NewPeerRegistryWithPath(peersPath) + if err != nil { + t.Fatalf("failed to create first registry: %v", err) + } + + peer := &Peer{ + ID: "persist-test", + Name: "Persistent Peer", + Address: "7.7.7.7:9091", + Role: RoleDual, + AddedAt: time.Now(), + } + + pr1.AddPeer(peer) + + // Flush pending changes before reloading + if err := pr1.Close(); err != nil { + t.Fatalf("failed to close first registry: %v", err) + } + + // Load in new registry from same path + pr2, err := NewPeerRegistryWithPath(peersPath) + if err != nil { + t.Fatalf("failed to create second registry: %v", err) + } + + if pr2.Count() != 1 { + t.Errorf("expected 1 peer after reload, got %d", pr2.Count()) + } + + loaded := pr2.GetPeer("persist-test") + if loaded == nil { + t.Fatal("peer should exist after reload") + 
} + + if loaded.Name != "Persistent Peer" { + t.Errorf("expected name 'Persistent Peer', got '%s'", loaded.Name) + } +} + +// --- Security Feature Tests --- + +func TestPeerRegistry_AuthMode(t *testing.T) { + pr, cleanup := setupTestPeerRegistry(t) + defer cleanup() + + // Default should be Open + if pr.GetAuthMode() != PeerAuthOpen { + t.Errorf("expected default auth mode to be Open, got %d", pr.GetAuthMode()) + } + + // Set to Allowlist + pr.SetAuthMode(PeerAuthAllowlist) + if pr.GetAuthMode() != PeerAuthAllowlist { + t.Errorf("expected auth mode to be Allowlist after setting, got %d", pr.GetAuthMode()) + } + + // Set back to Open + pr.SetAuthMode(PeerAuthOpen) + if pr.GetAuthMode() != PeerAuthOpen { + t.Errorf("expected auth mode to be Open after resetting, got %d", pr.GetAuthMode()) + } +} + +func TestPeerRegistry_PublicKeyAllowlist(t *testing.T) { + pr, cleanup := setupTestPeerRegistry(t) + defer cleanup() + + testKey := "base64PublicKeyExample1234567890123456" + + // Initially key should not be allowed + if pr.IsPublicKeyAllowed(testKey) { + t.Error("key should not be allowed before adding") + } + + // Add key to allowlist + pr.AllowPublicKey(testKey) + if !pr.IsPublicKeyAllowed(testKey) { + t.Error("key should be allowed after adding") + } + + // List should contain the key + keys := pr.ListAllowedPublicKeys() + found := false + for _, k := range keys { + if k == testKey { + found = true + break + } + } + if !found { + t.Error("ListAllowedPublicKeys should contain the added key") + } + + // Revoke key + pr.RevokePublicKey(testKey) + if pr.IsPublicKeyAllowed(testKey) { + t.Error("key should not be allowed after revoking") + } + + // List should be empty + keys = pr.ListAllowedPublicKeys() + if len(keys) != 0 { + t.Errorf("expected 0 keys after revoke, got %d", len(keys)) + } +} + +func TestPeerRegistry_IsPeerAllowed_OpenMode(t *testing.T) { + pr, cleanup := setupTestPeerRegistry(t) + defer cleanup() + + pr.SetAuthMode(PeerAuthOpen) + + // In Open mode, any 
peer should be allowed + if !pr.IsPeerAllowed("unknown-peer", "unknown-key") { + t.Error("in Open mode, all peers should be allowed") + } + + if !pr.IsPeerAllowed("", "") { + t.Error("in Open mode, even empty IDs should be allowed") + } +} + +func TestPeerRegistry_IsPeerAllowed_AllowlistMode(t *testing.T) { + pr, cleanup := setupTestPeerRegistry(t) + defer cleanup() + + pr.SetAuthMode(PeerAuthAllowlist) + + // Unknown peer with unknown key should be rejected + if pr.IsPeerAllowed("unknown-peer", "unknown-key") { + t.Error("in Allowlist mode, unknown peers should be rejected") + } + + // Pre-registered peer should be allowed + peer := &Peer{ + ID: "registered-peer", + Name: "Registered", + PublicKey: "registered-key", + } + pr.AddPeer(peer) + + if !pr.IsPeerAllowed("registered-peer", "any-key") { + t.Error("pre-registered peer should be allowed in Allowlist mode") + } + + // Peer with allowlisted public key should be allowed + pr.AllowPublicKey("allowed-key-1234567890") + if !pr.IsPeerAllowed("new-peer", "allowed-key-1234567890") { + t.Error("peer with allowlisted key should be allowed") + } + + // Unknown peer with non-allowlisted key should still be rejected + if pr.IsPeerAllowed("another-peer", "not-allowed-key") { + t.Error("peer without allowlisted key should be rejected") + } +} + +func TestPeerRegistry_PeerNameValidation(t *testing.T) { + pr, cleanup := setupTestPeerRegistry(t) + defer cleanup() + + testCases := []struct { + name string + peerName string + shouldErr bool + }{ + {"empty name allowed", "", false}, + {"single char", "A", false}, + {"simple name", "MyPeer", false}, + {"name with hyphen", "my-peer", false}, + {"name with underscore", "my_peer", false}, + {"name with space", "My Peer", false}, + {"name with numbers", "Peer123", false}, + {"max length name", "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789AB", false}, + {"too long name", "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789ABC", true}, + {"starts with 
hyphen", "-peer", true}, + {"ends with hyphen", "peer-", true}, + {"special chars", "peer@host", true}, + {"unicode chars", "peer\u0000name", true}, + } + + for i, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + peer := &Peer{ + ID: "test-peer-" + string(rune('A'+i)), + Name: tc.peerName, + } + err := pr.AddPeer(peer) + if tc.shouldErr && err == nil { + t.Errorf("expected error for name '%s' but got none", tc.peerName) + } else if !tc.shouldErr && err != nil { + t.Errorf("unexpected error for name '%s': %v", tc.peerName, err) + } + // Clean up for next test + if err == nil { + pr.RemovePeer(peer.ID) + } + }) + } +} + +func TestPeerRegistry_ScoreRecording(t *testing.T) { + pr, cleanup := setupTestPeerRegistry(t) + defer cleanup() + + peer := &Peer{ + ID: "score-record-test", + Name: "Score Peer", + Score: 50, // Start at neutral + } + pr.AddPeer(peer) + + // Record successes - score should increase + for i := 0; i < 5; i++ { + pr.RecordSuccess("score-record-test") + } + updated := pr.GetPeer("score-record-test") + if updated.Score <= 50 { + t.Errorf("score should increase after successes, got %f", updated.Score) + } + + // Record failures - score should decrease + initialScore := updated.Score + for i := 0; i < 3; i++ { + pr.RecordFailure("score-record-test") + } + updated = pr.GetPeer("score-record-test") + if updated.Score >= initialScore { + t.Errorf("score should decrease after failures, got %f (was %f)", updated.Score, initialScore) + } + + // Record timeouts - score should decrease + initialScore = updated.Score + pr.RecordTimeout("score-record-test") + updated = pr.GetPeer("score-record-test") + if updated.Score >= initialScore { + t.Errorf("score should decrease after timeout, got %f (was %f)", updated.Score, initialScore) + } + + // Score should be clamped to min/max + for i := 0; i < 100; i++ { + pr.RecordSuccess("score-record-test") + } + updated = pr.GetPeer("score-record-test") + if updated.Score > ScoreMaximum { + t.Errorf("score should 
be clamped to max %f, got %f", ScoreMaximum, updated.Score) + } + + for i := 0; i < 100; i++ { + pr.RecordFailure("score-record-test") + } + updated = pr.GetPeer("score-record-test") + if updated.Score < ScoreMinimum { + t.Errorf("score should be clamped to min %f, got %f", ScoreMinimum, updated.Score) + } +} + +func TestPeerRegistry_GetPeersByScore(t *testing.T) { + pr, cleanup := setupTestPeerRegistry(t) + defer cleanup() + + // Add peers with different scores + peers := []*Peer{ + {ID: "low-score", Name: "Low", Score: 20}, + {ID: "high-score", Name: "High", Score: 90}, + {ID: "mid-score", Name: "Mid", Score: 50}, + } + + for _, p := range peers { + pr.AddPeer(p) + } + + sorted := pr.GetPeersByScore() + if len(sorted) != 3 { + t.Fatalf("expected 3 peers, got %d", len(sorted)) + } + + // Should be sorted by score descending + if sorted[0].ID != "high-score" { + t.Errorf("first peer should be high-score, got %s", sorted[0].ID) + } + if sorted[1].ID != "mid-score" { + t.Errorf("second peer should be mid-score, got %s", sorted[1].ID) + } + if sorted[2].ID != "low-score" { + t.Errorf("third peer should be low-score, got %s", sorted[2].ID) + } +} diff --git a/node/protocol.go b/node/protocol.go new file mode 100644 index 0000000..197d5e4 --- /dev/null +++ b/node/protocol.go @@ -0,0 +1,88 @@ +package node + +import ( + "fmt" +) + +// ProtocolError represents an error from the remote peer. +type ProtocolError struct { + Code int + Message string +} + +func (e *ProtocolError) Error() string { + return fmt.Sprintf("remote error (%d): %s", e.Code, e.Message) +} + +// ResponseHandler provides helpers for handling protocol responses. +type ResponseHandler struct{} + +// ValidateResponse checks if the response is valid and returns a parsed error if it's an error response. +// It checks: +// 1. If response is nil (returns error) +// 2. If response is an error message (returns ProtocolError) +// 3. 
If response type matches expected (returns error if not) +func (h *ResponseHandler) ValidateResponse(resp *Message, expectedType MessageType) error { + if resp == nil { + return fmt.Errorf("nil response") + } + + // Check for error response + if resp.Type == MsgError { + var errPayload ErrorPayload + if err := resp.ParsePayload(&errPayload); err != nil { + return &ProtocolError{Code: ErrCodeUnknown, Message: "unable to parse error response"} + } + return &ProtocolError{Code: errPayload.Code, Message: errPayload.Message} + } + + // Check expected type + if resp.Type != expectedType { + return fmt.Errorf("unexpected response type: expected %s, got %s", expectedType, resp.Type) + } + + return nil +} + +// ParseResponse validates the response and parses the payload into the target. +// This combines ValidateResponse and ParsePayload into a single call. +func (h *ResponseHandler) ParseResponse(resp *Message, expectedType MessageType, target interface{}) error { + if err := h.ValidateResponse(resp, expectedType); err != nil { + return err + } + + if target != nil { + if err := resp.ParsePayload(target); err != nil { + return fmt.Errorf("failed to parse %s payload: %w", expectedType, err) + } + } + + return nil +} + +// DefaultResponseHandler is the default response handler instance. +var DefaultResponseHandler = &ResponseHandler{} + +// ValidateResponse is a convenience function using the default handler. +func ValidateResponse(resp *Message, expectedType MessageType) error { + return DefaultResponseHandler.ValidateResponse(resp, expectedType) +} + +// ParseResponse is a convenience function using the default handler. +func ParseResponse(resp *Message, expectedType MessageType, target interface{}) error { + return DefaultResponseHandler.ParseResponse(resp, expectedType, target) +} + +// IsProtocolError returns true if the error is a ProtocolError. 
+func IsProtocolError(err error) bool { + _, ok := err.(*ProtocolError) + return ok +} + +// GetProtocolErrorCode returns the error code if err is a ProtocolError, otherwise returns 0. +func GetProtocolErrorCode(err error) int { + if pe, ok := err.(*ProtocolError); ok { + return pe.Code + } + return 0 +} diff --git a/node/protocol_test.go b/node/protocol_test.go new file mode 100644 index 0000000..1d728a4 --- /dev/null +++ b/node/protocol_test.go @@ -0,0 +1,161 @@ +package node + +import ( + "fmt" + "testing" +) + +func TestResponseHandler_ValidateResponse(t *testing.T) { + handler := &ResponseHandler{} + + t.Run("NilResponse", func(t *testing.T) { + err := handler.ValidateResponse(nil, MsgStats) + if err == nil { + t.Error("Expected error for nil response") + } + }) + + t.Run("ErrorResponse", func(t *testing.T) { + errMsg, _ := NewErrorMessage("sender", "receiver", ErrCodeOperationFailed, "operation failed", "") + err := handler.ValidateResponse(errMsg, MsgStats) + if err == nil { + t.Fatal("Expected error for error response") + } + + if !IsProtocolError(err) { + t.Errorf("Expected ProtocolError, got %T", err) + } + + if GetProtocolErrorCode(err) != ErrCodeOperationFailed { + t.Errorf("Expected code %d, got %d", ErrCodeOperationFailed, GetProtocolErrorCode(err)) + } + }) + + t.Run("WrongType", func(t *testing.T) { + msg, _ := NewMessage(MsgPong, "sender", "receiver", nil) + err := handler.ValidateResponse(msg, MsgStats) + if err == nil { + t.Error("Expected error for wrong type") + } + if IsProtocolError(err) { + t.Error("Should not be a ProtocolError for type mismatch") + } + }) + + t.Run("ValidResponse", func(t *testing.T) { + msg, _ := NewMessage(MsgStats, "sender", "receiver", StatsPayload{NodeID: "test"}) + err := handler.ValidateResponse(msg, MsgStats) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) +} + +func TestResponseHandler_ParseResponse(t *testing.T) { + handler := &ResponseHandler{} + + t.Run("ParseStats", func(t *testing.T) { + 
payload := StatsPayload{ + NodeID: "node-123", + NodeName: "Test Node", + Uptime: 3600, + } + msg, _ := NewMessage(MsgStats, "sender", "receiver", payload) + + var parsed StatsPayload + err := handler.ParseResponse(msg, MsgStats, &parsed) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if parsed.NodeID != "node-123" { + t.Errorf("Expected NodeID 'node-123', got '%s'", parsed.NodeID) + } + if parsed.Uptime != 3600 { + t.Errorf("Expected Uptime 3600, got %d", parsed.Uptime) + } + }) + + t.Run("ParseMinerAck", func(t *testing.T) { + payload := MinerAckPayload{ + Success: true, + MinerName: "xmrig-1", + } + msg, _ := NewMessage(MsgMinerAck, "sender", "receiver", payload) + + var parsed MinerAckPayload + err := handler.ParseResponse(msg, MsgMinerAck, &parsed) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if !parsed.Success { + t.Error("Expected Success to be true") + } + if parsed.MinerName != "xmrig-1" { + t.Errorf("Expected MinerName 'xmrig-1', got '%s'", parsed.MinerName) + } + }) + + t.Run("ErrorResponse", func(t *testing.T) { + errMsg, _ := NewErrorMessage("sender", "receiver", ErrCodeNotFound, "not found", "") + + var parsed StatsPayload + err := handler.ParseResponse(errMsg, MsgStats, &parsed) + if err == nil { + t.Error("Expected error for error response") + } + if !IsProtocolError(err) { + t.Errorf("Expected ProtocolError, got %T", err) + } + }) + + t.Run("NilTarget", func(t *testing.T) { + msg, _ := NewMessage(MsgPong, "sender", "receiver", nil) + err := handler.ParseResponse(msg, MsgPong, nil) + if err != nil { + t.Errorf("Unexpected error with nil target: %v", err) + } + }) +} + +func TestProtocolError(t *testing.T) { + err := &ProtocolError{Code: 1001, Message: "test error"} + + if err.Error() != "remote error (1001): test error" { + t.Errorf("Unexpected error message: %s", err.Error()) + } + + if !IsProtocolError(err) { + t.Error("IsProtocolError should return true") + } + + if GetProtocolErrorCode(err) != 1001 { + 
t.Errorf("Expected code 1001, got %d", GetProtocolErrorCode(err)) + } +} + +func TestConvenienceFunctions(t *testing.T) { + msg, _ := NewMessage(MsgStats, "sender", "receiver", StatsPayload{NodeID: "test"}) + + // Test ValidateResponse + if err := ValidateResponse(msg, MsgStats); err != nil { + t.Errorf("ValidateResponse failed: %v", err) + } + + // Test ParseResponse + var parsed StatsPayload + if err := ParseResponse(msg, MsgStats, &parsed); err != nil { + t.Errorf("ParseResponse failed: %v", err) + } + if parsed.NodeID != "test" { + t.Errorf("Expected NodeID 'test', got '%s'", parsed.NodeID) + } +} + +func TestGetProtocolErrorCode_NonProtocolError(t *testing.T) { + err := fmt.Errorf("regular error") + if GetProtocolErrorCode(err) != 0 { + t.Error("Expected 0 for non-ProtocolError") + } +} diff --git a/node/transport.go b/node/transport.go new file mode 100644 index 0000000..1040920 --- /dev/null +++ b/node/transport.go @@ -0,0 +1,934 @@ +package node + +import ( + "context" + "crypto/tls" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/url" + "sync" + "sync/atomic" + "time" + + "github.com/Snider/Borg/pkg/smsg" + "forge.lthn.ai/core/go-p2p/logging" + "github.com/gorilla/websocket" +) + +// debugLogCounter tracks message counts for rate limiting debug logs +var debugLogCounter atomic.Int64 + +// debugLogInterval controls how often we log debug messages in hot paths (1 in N) +const debugLogInterval = 100 + +// DefaultMaxMessageSize is the default maximum message size (1MB) +const DefaultMaxMessageSize int64 = 1 << 20 // 1MB + +// TransportConfig configures the WebSocket transport. 
+type TransportConfig struct { + ListenAddr string // ":9091" default + WSPath string // "/ws" - WebSocket endpoint path + TLSCertPath string // Optional TLS for wss:// + TLSKeyPath string + MaxConns int // Maximum concurrent connections + MaxMessageSize int64 // Maximum message size in bytes (0 = 1MB default) + PingInterval time.Duration // WebSocket keepalive interval + PongTimeout time.Duration // Timeout waiting for pong +} + +// DefaultTransportConfig returns sensible defaults. +func DefaultTransportConfig() TransportConfig { + return TransportConfig{ + ListenAddr: ":9091", + WSPath: "/ws", + MaxConns: 100, + MaxMessageSize: DefaultMaxMessageSize, + PingInterval: 30 * time.Second, + PongTimeout: 10 * time.Second, + } +} + +// MessageHandler processes incoming messages. +type MessageHandler func(conn *PeerConnection, msg *Message) + +// MessageDeduplicator tracks seen message IDs to prevent duplicate processing +type MessageDeduplicator struct { + seen map[string]time.Time + mu sync.RWMutex + ttl time.Duration +} + +// NewMessageDeduplicator creates a deduplicator with specified TTL +func NewMessageDeduplicator(ttl time.Duration) *MessageDeduplicator { + d := &MessageDeduplicator{ + seen: make(map[string]time.Time), + ttl: ttl, + } + return d +} + +// IsDuplicate checks if a message ID has been seen recently +func (d *MessageDeduplicator) IsDuplicate(msgID string) bool { + d.mu.RLock() + _, exists := d.seen[msgID] + d.mu.RUnlock() + return exists +} + +// Mark records a message ID as seen +func (d *MessageDeduplicator) Mark(msgID string) { + d.mu.Lock() + d.seen[msgID] = time.Now() + d.mu.Unlock() +} + +// Cleanup removes expired entries +func (d *MessageDeduplicator) Cleanup() { + d.mu.Lock() + defer d.mu.Unlock() + now := time.Now() + for id, seen := range d.seen { + if now.Sub(seen) > d.ttl { + delete(d.seen, id) + } + } +} + +// Transport manages WebSocket connections with SMSG encryption. 
// Transport manages WebSocket connections with SMSG encryption.
// conns and handler are guarded by mu; pendingConns is a separate atomic so
// the handshake path does not need the lock.
type Transport struct {
	config       TransportConfig
	server       *http.Server
	upgrader     websocket.Upgrader
	conns        map[string]*PeerConnection // peer ID -> connection
	pendingConns atomic.Int32               // tracks connections during handshake
	node         *NodeManager
	registry     *PeerRegistry
	handler      MessageHandler
	dedup        *MessageDeduplicator // Message deduplication
	mu           sync.RWMutex         // guards conns and handler
	ctx          context.Context      // cancelled by Stop to end background goroutines
	cancel       context.CancelFunc
	wg           sync.WaitGroup // waits for server, cleanup and per-connection goroutines
}

// PeerRateLimiter implements a simple token bucket rate limiter per peer
type PeerRateLimiter struct {
	tokens     int        // currently available tokens
	maxTokens  int        // bucket capacity (burst size)
	refillRate int        // tokens per second
	lastRefill time.Time  // last time tokens were added
	mu         sync.Mutex // guards all fields above
}

// NewPeerRateLimiter creates a rate limiter with specified messages/second
// The bucket starts full, so a fresh peer may burst up to maxTokens.
func NewPeerRateLimiter(maxTokens, refillRate int) *PeerRateLimiter {
	return &PeerRateLimiter{
		tokens:     maxTokens,
		maxTokens:  maxTokens,
		refillRate: refillRate,
		lastRefill: time.Now(),
	}
}

// Allow checks if a message is allowed and consumes a token if so
func (r *PeerRateLimiter) Allow() bool {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Refill tokens based on elapsed time. The refill is computed in whole
	// seconds (int truncation); when less than a second has elapsed,
	// lastRefill is left untouched so the fractional time still counts
	// toward the next refill rather than being discarded.
	now := time.Now()
	elapsed := now.Sub(r.lastRefill)
	tokensToAdd := int(elapsed.Seconds()) * r.refillRate
	if tokensToAdd > 0 {
		r.tokens = min(r.tokens+tokensToAdd, r.maxTokens)
		r.lastRefill = now
	}

	// Check if we have tokens available
	if r.tokens > 0 {
		r.tokens--
		return true
	}
	return false
}

// PeerConnection represents an active connection to a peer.
type PeerConnection struct {
	Peer         *Peer
	Conn         *websocket.Conn
	SharedSecret []byte // Derived via X25519 ECDH, used for SMSG
	LastActivity time.Time
	writeMu      sync.Mutex // Serialize WebSocket writes
	transport    *Transport
	closeOnce    sync.Once        // Ensure Close() is only called once
	rateLimiter  *PeerRateLimiter // Per-peer message rate limiting
}

// NewTransport creates a new WebSocket transport.
func NewTransport(node *NodeManager, registry *PeerRegistry, config TransportConfig) *Transport {
	ctx, cancel := context.WithCancel(context.Background())

	return &Transport{
		config:   config,
		node:     node,
		registry: registry,
		conns:    make(map[string]*PeerConnection),
		dedup:    NewMessageDeduplicator(5 * time.Minute), // 5 minute TTL for dedup
		upgrader: websocket.Upgrader{
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
			CheckOrigin: func(r *http.Request) bool {
				// Allow local connections only for security
				origin := r.Header.Get("Origin")
				if origin == "" {
					return true // No origin header (non-browser client)
				}
				// Allow localhost and 127.0.0.1 origins
				u, err := url.Parse(origin)
				if err != nil {
					return false
				}
				host := u.Hostname()
				return host == "localhost" || host == "127.0.0.1" || host == "::1"
			},
		},
		ctx:    ctx,
		cancel: cancel,
	}
}

// Start begins listening for incoming connections.
// It launches the HTTP(S) server and the dedup-cleanup goroutine, both
// tracked on t.wg so Stop can wait for them; it returns immediately.
func (t *Transport) Start() error {
	mux := http.NewServeMux()
	mux.HandleFunc(t.config.WSPath, t.handleWSUpgrade)

	t.server = &http.Server{
		Addr:              t.config.ListenAddr,
		Handler:           mux,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       60 * time.Second,
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Apply TLS hardening if TLS is enabled
	if t.config.TLSCertPath != "" && t.config.TLSKeyPath != "" {
		t.server.TLSConfig = &tls.Config{
			MinVersion: tls.VersionTLS12,
			CipherSuites: []uint16{
				// TLS 1.3 ciphers (automatically used when available)
				// NOTE(review): crypto/tls ignores TLS 1.3 suite IDs placed
				// in Config.CipherSuites; the three entries below are
				// harmless but have no effect — confirm intent.
				tls.TLS_AES_128_GCM_SHA256,
				tls.TLS_AES_256_GCM_SHA384,
				tls.TLS_CHACHA20_POLY1305_SHA256,
				// TLS 1.2 secure ciphers
				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
				tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
			},
			CurvePreferences: []tls.CurveID{
				tls.X25519,
				tls.CurveP256,
			},
		}
	}

	t.wg.Add(1)
	go func() {
		defer t.wg.Done()
		var err error
		if t.config.TLSCertPath != "" && t.config.TLSKeyPath != "" {
			err = t.server.ListenAndServeTLS(t.config.TLSCertPath, t.config.TLSKeyPath)
		} else {
			err = t.server.ListenAndServe()
		}
		// ErrServerClosed is the normal result of Stop(); anything else is
		// a real serve failure worth logging.
		if err != nil && err != http.ErrServerClosed {
			logging.Error("HTTP server error", logging.Fields{"error": err, "addr": t.config.ListenAddr})
		}
	}()

	// Start message deduplication cleanup goroutine
	t.wg.Add(1)
	go func() {
		defer t.wg.Done()
		ticker := time.NewTicker(time.Minute)
		defer ticker.Stop()
		for {
			select {
			case <-t.ctx.Done():
				return
			case <-ticker.C:
				t.dedup.Cleanup()
			}
		}
	}()

	return nil
}

// Stop gracefully shuts down the transport.
// Order: cancel the context so background goroutines exit, close every
// peer connection with a shutdown notice, shut down the HTTP server (5s
// deadline), then wait for all tracked goroutines.
func (t *Transport) Stop() error {
	t.cancel()

	// Gracefully close all connections with shutdown message
	t.mu.Lock()
	for _, pc := range t.conns {
		pc.GracefulClose("server shutdown", DisconnectShutdown)
	}
	t.mu.Unlock()

	// Shutdown HTTP server if it was started
	if t.server != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		if err := t.server.Shutdown(ctx); err != nil {
			return fmt.Errorf("server shutdown error: %w", err)
		}
	}

	t.wg.Wait()
	return nil
}

// OnMessage sets the handler for incoming messages.
// Must be called before Start() to avoid races.
func (t *Transport) OnMessage(handler MessageHandler) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.handler = handler
}

// Connect establishes a connection to a peer.
func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) {
	// Build WebSocket URL
	// NOTE(review): the scheme is picked from our *local* TLS configuration,
	// which assumes every peer in the network shares the same TLS setting —
	// confirm against deployment expectations.
	scheme := "ws"
	if t.config.TLSCertPath != "" {
		scheme = "wss"
	}
	u := url.URL{Scheme: scheme, Host: peer.Address, Path: t.config.WSPath}

	// Dial the peer with timeout to prevent hanging on unresponsive peers
	dialer := websocket.Dialer{
		HandshakeTimeout: 10 * time.Second,
	}
	conn, _, err := dialer.Dial(u.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to peer: %w", err)
	}

	pc := &PeerConnection{
		Peer:         peer,
		Conn:         conn,
		LastActivity: time.Now(),
		transport:    t,
		rateLimiter:  NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill
	}

	// Perform handshake with challenge-response authentication
	// This also derives and stores the shared secret in pc.SharedSecret
	if err := t.performHandshake(pc); err != nil {
		conn.Close()
		return nil, fmt.Errorf("handshake failed: %w", err)
	}

	// Store connection using the real peer ID from handshake
	// NOTE(review): an existing entry for the same peer ID is silently
	// replaced here without being closed — verify the old connection is
	// torn down elsewhere.
	t.mu.Lock()
	t.conns[pc.Peer.ID] = pc
	t.mu.Unlock()

	logging.Debug("connected to peer", logging.Fields{"peer_id": pc.Peer.ID, "secret_len": len(pc.SharedSecret)})

	// Update registry
	t.registry.SetConnected(pc.Peer.ID, true)

	// Start read loop
	t.wg.Add(1)
	go t.readLoop(pc)

	logging.Debug("started readLoop for peer", logging.Fields{"peer_id": pc.Peer.ID})

	// Start keepalive
	t.wg.Add(1)
	go t.keepalive(pc)

	return pc, nil
}

// Send sends a message to a specific peer.
// Returns an error if the peer has no active connection.
func (t *Transport) Send(peerID string, msg *Message) error {
	t.mu.RLock()
	pc, exists := t.conns[peerID]
	t.mu.RUnlock()

	if !exists {
		return fmt.Errorf("peer %s not connected", peerID)
	}

	return pc.Send(msg)
}

// Broadcast sends a message to all connected peers except the sender.
// The sender is identified by msg.From and excluded to prevent echo.
func (t *Transport) Broadcast(msg *Message) error {
	// Snapshot the connection list under the read lock so the (potentially
	// slow) sends below happen without holding t.mu.
	t.mu.RLock()
	conns := make([]*PeerConnection, 0, len(t.conns))
	for _, pc := range t.conns {
		// Exclude sender from broadcast to prevent echo (P2P-MED-6)
		if pc.Peer != nil && pc.Peer.ID == msg.From {
			continue
		}
		conns = append(conns, pc)
	}
	t.mu.RUnlock()

	// Best-effort fan-out: every peer is attempted; only the LAST error is
	// returned (earlier failures are overwritten).
	var lastErr error
	for _, pc := range conns {
		if err := pc.Send(msg); err != nil {
			lastErr = err
		}
	}
	return lastErr
}

// GetConnection returns an active connection to a peer,
// or nil if the peer is not currently connected.
func (t *Transport) GetConnection(peerID string) *PeerConnection {
	t.mu.RLock()
	defer t.mu.RUnlock()
	return t.conns[peerID]
}

// handleWSUpgrade handles incoming WebSocket connections.
//
// Server side of the connection protocol:
//  1. enforce the connection cap (established + handshaking),
//  2. upgrade to WebSocket and read the first (unencrypted) handshake message,
//  3. validate protocol version and allowlist membership,
//  4. derive the shared secret, sign the client's challenge, send the ack,
//  5. register the connection and start readLoop/keepalive goroutines.
func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
	// Enforce MaxConns limit (including pending connections during handshake)
	t.mu.RLock()
	currentConns := len(t.conns)
	t.mu.RUnlock()
	pendingConns := int(t.pendingConns.Load())

	totalConns := currentConns + pendingConns
	if totalConns >= t.config.MaxConns {
		http.Error(w, "Too many connections", http.StatusServiceUnavailable)
		return
	}

	// Track this connection as pending during handshake
	t.pendingConns.Add(1)
	defer t.pendingConns.Add(-1)

	conn, err := t.upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade already wrote the HTTP error response; nothing to clean up.
		return
	}

	// Apply message size limit during handshake to prevent memory exhaustion
	maxSize := t.config.MaxMessageSize
	if maxSize <= 0 {
		maxSize = DefaultMaxMessageSize
	}
	conn.SetReadLimit(maxSize)

	// Set handshake timeout to prevent slow/malicious clients from blocking.
	// The deadline is not reset here: readLoop re-arms its own deadline
	// before every read once the connection is registered.
	handshakeTimeout := 10 * time.Second
	conn.SetReadDeadline(time.Now().Add(handshakeTimeout))

	// Wait for handshake from client
	_, data, err := conn.ReadMessage()
	if err != nil {
		conn.Close()
		return
	}

	// Decode handshake message (not encrypted yet, contains public key)
	var msg Message
	if err := json.Unmarshal(data, &msg); err != nil {
		conn.Close()
		return
	}

	if msg.Type != MsgHandshake {
		conn.Close()
		return
	}

	var payload HandshakePayload
	if err := msg.ParsePayload(&payload); err != nil {
		conn.Close()
		return
	}

	// Check protocol version compatibility (P2P-MED-1)
	if !IsProtocolVersionSupported(payload.Version) {
		logging.Warn("peer connection rejected: incompatible protocol version", logging.Fields{
			"peer_version":       payload.Version,
			"supported_versions": SupportedProtocolVersions,
			"peer_id":            payload.Identity.ID,
		})
		// Best-effort rejection ack so the client learns why it was refused.
		identity := t.node.GetIdentity()
		if identity != nil {
			rejectPayload := HandshakeAckPayload{
				Identity: *identity,
				Accepted: false,
				Reason:   fmt.Sprintf("incompatible protocol version %s, supported: %v", payload.Version, SupportedProtocolVersions),
			}
			rejectMsg, _ := NewMessage(MsgHandshakeAck, identity.ID, payload.Identity.ID, rejectPayload)
			if rejectData, err := MarshalJSON(rejectMsg); err == nil {
				conn.WriteMessage(websocket.TextMessage, rejectData)
			}
		}
		conn.Close()
		return
	}

	// Derive shared secret from peer's public key
	sharedSecret, err := t.node.DeriveSharedSecret(payload.Identity.PublicKey)
	if err != nil {
		conn.Close()
		return
	}

	// Check if peer is allowed to connect (allowlist check)
	if !t.registry.IsPeerAllowed(payload.Identity.ID, payload.Identity.PublicKey) {
		logging.Warn("peer connection rejected: not in allowlist", logging.Fields{
			"peer_id":    payload.Identity.ID,
			"peer_name":  payload.Identity.Name,
			"public_key": safeKeyPrefix(payload.Identity.PublicKey),
		})
		// Send rejection before closing
		identity := t.node.GetIdentity()
		if identity != nil {
			rejectPayload := HandshakeAckPayload{
				Identity: *identity,
				Accepted: false,
				Reason:   "peer not authorized",
			}
			rejectMsg, _ := NewMessage(MsgHandshakeAck, identity.ID, payload.Identity.ID, rejectPayload)
			if rejectData, err := MarshalJSON(rejectMsg); err == nil {
				conn.WriteMessage(websocket.TextMessage, rejectData)
			}
		}
		conn.Close()
		return
	}

	// Create peer if not exists (only if auth passed)
	peer := t.registry.GetPeer(payload.Identity.ID)
	if peer == nil {
		// Auto-register the peer since they passed allowlist check
		peer = &Peer{
			ID:        payload.Identity.ID,
			Name:      payload.Identity.Name,
			PublicKey: payload.Identity.PublicKey,
			Role:      payload.Identity.Role,
			AddedAt:   time.Now(),
			Score:     50,
		}
		t.registry.AddPeer(peer)
		logging.Info("auto-registered new peer", logging.Fields{
			"peer_id":   peer.ID,
			"peer_name": peer.Name,
		})
	}

	pc := &PeerConnection{
		Peer:         peer,
		Conn:         conn,
		SharedSecret: sharedSecret,
		LastActivity: time.Now(),
		transport:    t,
		rateLimiter:  NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill
	}

	// Send handshake acknowledgment
	identity := t.node.GetIdentity()
	if identity == nil {
		conn.Close()
		return
	}

	// Sign the client's challenge to prove we have the matching private key
	var challengeResponse []byte
	if len(payload.Challenge) > 0 {
		challengeResponse = SignChallenge(payload.Challenge, sharedSecret)
	}

	ackPayload := HandshakeAckPayload{
		Identity:          *identity,
		ChallengeResponse: challengeResponse,
		Accepted:          true,
	}

	ackMsg, err := NewMessage(MsgHandshakeAck, identity.ID, peer.ID, ackPayload)
	if err != nil {
		conn.Close()
		return
	}

	// First ack is unencrypted (peer needs to know our public key)
	ackData, err := MarshalJSON(ackMsg)
	if err != nil {
		conn.Close()
		return
	}

	if err := conn.WriteMessage(websocket.TextMessage, ackData); err != nil {
		conn.Close()
		return
	}

	// Store connection
	t.mu.Lock()
	t.conns[peer.ID] = pc
	t.mu.Unlock()

	// Update registry
	t.registry.SetConnected(peer.ID, true)

	// Start read loop
	t.wg.Add(1)
	go t.readLoop(pc)

	// Start keepalive
	t.wg.Add(1)
	go t.keepalive(pc)
}

// performHandshake initiates handshake with a peer.
func (t *Transport) performHandshake(pc *PeerConnection) error {
	// Set handshake timeout. SetDeadline errors are intentionally ignored:
	// if the connection is already broken, the subsequent read/write fails.
	handshakeTimeout := 10 * time.Second
	pc.Conn.SetWriteDeadline(time.Now().Add(handshakeTimeout))
	pc.Conn.SetReadDeadline(time.Now().Add(handshakeTimeout))
	defer func() {
		// Reset deadlines after handshake
		pc.Conn.SetWriteDeadline(time.Time{})
		pc.Conn.SetReadDeadline(time.Time{})
	}()

	identity := t.node.GetIdentity()
	if identity == nil {
		return fmt.Errorf("node identity not initialized")
	}

	// Generate challenge for the server to prove it has the matching private key
	challenge, err := GenerateChallenge()
	if err != nil {
		return fmt.Errorf("generate challenge: %w", err)
	}

	payload := HandshakePayload{
		Identity:  *identity,
		Challenge: challenge,
		Version:   ProtocolVersion,
	}

	// pc.Peer.ID may still be a placeholder here; the real ID is taken
	// from the server's ack below.
	msg, err := NewMessage(MsgHandshake, identity.ID, pc.Peer.ID, payload)
	if err != nil {
		return fmt.Errorf("create handshake message: %w", err)
	}

	// First message is unencrypted (peer needs our public key)
	data, err := MarshalJSON(msg)
	if err != nil {
		return fmt.Errorf("marshal handshake message: %w", err)
	}

	if err := pc.Conn.WriteMessage(websocket.TextMessage, data); err != nil {
		return fmt.Errorf("send handshake: %w", err)
	}

	// Wait for ack
	_, ackData, err := pc.Conn.ReadMessage()
	if err != nil {
		return fmt.Errorf("read handshake ack: %w", err)
	}

	var ackMsg Message
	if err := json.Unmarshal(ackData, &ackMsg); err != nil {
		return fmt.Errorf("unmarshal handshake ack: %w", err)
	}

	if ackMsg.Type != MsgHandshakeAck {
		return fmt.Errorf("expected handshake_ack, got %s", ackMsg.Type)
	}

	var ackPayload HandshakeAckPayload
	if err := ackMsg.ParsePayload(&ackPayload); err != nil {
		return fmt.Errorf("parse handshake ack payload: %w", err)
	}

	if !ackPayload.Accepted {
		return fmt.Errorf("handshake rejected: %s", ackPayload.Reason)
	}

	// Update peer with the received identity info
	pc.Peer.ID = ackPayload.Identity.ID
	pc.Peer.PublicKey = ackPayload.Identity.PublicKey
	pc.Peer.Name = ackPayload.Identity.Name
	pc.Peer.Role = ackPayload.Identity.Role

	// Verify challenge response - derive shared secret first using the peer's public key
	sharedSecret, err := t.node.DeriveSharedSecret(pc.Peer.PublicKey)
	if err != nil {
		return fmt.Errorf("derive shared secret for challenge verification: %w", err)
	}

	// Verify the server's response to our challenge. An empty response is
	// treated as a hard failure: the server must prove key possession.
	if len(ackPayload.ChallengeResponse) == 0 {
		return fmt.Errorf("server did not provide challenge response")
	}
	if !VerifyChallenge(challenge, ackPayload.ChallengeResponse, sharedSecret) {
		return fmt.Errorf("challenge response verification failed: server may not have matching private key")
	}

	// Store the shared secret for later use
	pc.SharedSecret = sharedSecret

	// Update the peer in registry with the real identity
	if err := t.registry.UpdatePeer(pc.Peer); err != nil {
		// If update fails (peer not found with old ID), add as new
		t.registry.AddPeer(pc.Peer)
	}

	logging.Debug("handshake completed with challenge-response verification", logging.Fields{
		"peer_id":   pc.Peer.ID,
		"peer_name": pc.Peer.Name,
	})

	return nil
}

// readLoop reads messages from a peer connection.
+func (t *Transport) readLoop(pc *PeerConnection) { + defer t.wg.Done() + defer t.removeConnection(pc) + + // Apply message size limit to prevent memory exhaustion attacks + maxSize := t.config.MaxMessageSize + if maxSize <= 0 { + maxSize = DefaultMaxMessageSize + } + pc.Conn.SetReadLimit(maxSize) + + for { + select { + case <-t.ctx.Done(): + return + default: + } + + // Set read deadline to prevent blocking forever on unresponsive connections + readDeadline := t.config.PingInterval + t.config.PongTimeout + if err := pc.Conn.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { + logging.Error("SetReadDeadline error", logging.Fields{"peer_id": pc.Peer.ID, "error": err}) + return + } + + _, data, err := pc.Conn.ReadMessage() + if err != nil { + logging.Debug("read error from peer", logging.Fields{"peer_id": pc.Peer.ID, "error": err}) + return + } + + pc.LastActivity = time.Now() + + // Check rate limit before processing + if pc.rateLimiter != nil && !pc.rateLimiter.Allow() { + logging.Warn("peer rate limited, dropping message", logging.Fields{"peer_id": pc.Peer.ID}) + continue // Drop message from rate-limited peer + } + + // Decrypt message using SMSG with shared secret + msg, err := t.decryptMessage(data, pc.SharedSecret) + if err != nil { + logging.Debug("decrypt error from peer", logging.Fields{"peer_id": pc.Peer.ID, "error": err, "data_len": len(data)}) + continue // Skip invalid messages + } + + // Check for duplicate messages (prevents amplification attacks) + if t.dedup.IsDuplicate(msg.ID) { + logging.Debug("dropping duplicate message", logging.Fields{"msg_id": msg.ID, "peer_id": pc.Peer.ID}) + continue + } + t.dedup.Mark(msg.ID) + + // Rate limit debug logs in hot path to reduce noise (log 1 in N messages) + if debugLogCounter.Add(1)%debugLogInterval == 0 { + logging.Debug("received message from peer", logging.Fields{"type": msg.Type, "peer_id": pc.Peer.ID, "reply_to": msg.ReplyTo, "sample": "1/100"}) + } + + // Dispatch to handler (read handler under 
lock to avoid race) + t.mu.RLock() + handler := t.handler + t.mu.RUnlock() + if handler != nil { + handler(pc, msg) + } + } +} + +// keepalive sends periodic pings. +func (t *Transport) keepalive(pc *PeerConnection) { + defer t.wg.Done() + + ticker := time.NewTicker(t.config.PingInterval) + defer ticker.Stop() + + for { + select { + case <-t.ctx.Done(): + return + case <-ticker.C: + // Check if connection is still alive + if time.Since(pc.LastActivity) > t.config.PingInterval+t.config.PongTimeout { + t.removeConnection(pc) + return + } + + // Send ping + identity := t.node.GetIdentity() + pingMsg, err := NewMessage(MsgPing, identity.ID, pc.Peer.ID, PingPayload{ + SentAt: time.Now().UnixMilli(), + }) + if err != nil { + continue + } + + if err := pc.Send(pingMsg); err != nil { + t.removeConnection(pc) + return + } + } + } +} + +// removeConnection removes and cleans up a connection. +func (t *Transport) removeConnection(pc *PeerConnection) { + t.mu.Lock() + delete(t.conns, pc.Peer.ID) + t.mu.Unlock() + + t.registry.SetConnected(pc.Peer.ID, false) + pc.Close() +} + +// Send sends an encrypted message over the connection. +func (pc *PeerConnection) Send(msg *Message) error { + pc.writeMu.Lock() + defer pc.writeMu.Unlock() + + // Encrypt message using SMSG + data, err := pc.transport.encryptMessage(msg, pc.SharedSecret) + if err != nil { + return err + } + + // Set write deadline to prevent blocking forever + if err := pc.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil { + return fmt.Errorf("failed to set write deadline: %w", err) + } + defer pc.Conn.SetWriteDeadline(time.Time{}) // Reset deadline after send + + return pc.Conn.WriteMessage(websocket.BinaryMessage, data) +} + +// Close closes the connection. +func (pc *PeerConnection) Close() error { + var err error + pc.closeOnce.Do(func() { + err = pc.Conn.Close() + }) + return err +} + +// DisconnectPayload contains reason for disconnect. 
+type DisconnectPayload struct { + Reason string `json:"reason"` + Code int `json:"code"` // Optional disconnect code +} + +// Disconnect codes +const ( + DisconnectNormal = 1000 // Normal closure + DisconnectGoingAway = 1001 // Server/peer going away + DisconnectProtocolErr = 1002 // Protocol error + DisconnectTimeout = 1003 // Idle timeout + DisconnectShutdown = 1004 // Server shutdown +) + +// GracefulClose sends a disconnect message before closing the connection. +func (pc *PeerConnection) GracefulClose(reason string, code int) error { + var err error + pc.closeOnce.Do(func() { + // Try to send disconnect message (best effort) + if pc.transport != nil && pc.SharedSecret != nil { + identity := pc.transport.node.GetIdentity() + if identity != nil { + payload := DisconnectPayload{ + Reason: reason, + Code: code, + } + msg, msgErr := NewMessage(MsgDisconnect, identity.ID, pc.Peer.ID, payload) + if msgErr == nil { + // Set short deadline for disconnect message + pc.Conn.SetWriteDeadline(time.Now().Add(2 * time.Second)) + pc.Send(msg) + } + } + } + + // Close the underlying connection + err = pc.Conn.Close() + }) + return err +} + +// encryptMessage encrypts a message using SMSG with the shared secret. +func (t *Transport) encryptMessage(msg *Message, sharedSecret []byte) ([]byte, error) { + // Serialize message to JSON (using pooled buffer for efficiency) + msgData, err := MarshalJSON(msg) + if err != nil { + return nil, err + } + + // Create SMSG message + smsgMsg := smsg.NewMessage(string(msgData)) + + // Encrypt using shared secret as password (base64 encoded) + password := base64.StdEncoding.EncodeToString(sharedSecret) + encrypted, err := smsg.Encrypt(smsgMsg, password) + if err != nil { + return nil, err + } + + return encrypted, nil +} + +// decryptMessage decrypts a message using SMSG with the shared secret. 
func (t *Transport) decryptMessage(data []byte, sharedSecret []byte) (*Message, error) {
	// Decrypt using shared secret as password (base64 encoded, mirroring
	// encryptMessage).
	password := base64.StdEncoding.EncodeToString(sharedSecret)
	smsgMsg, err := smsg.Decrypt(data, password)
	if err != nil {
		return nil, err
	}

	// Parse message from JSON
	var msg Message
	if err := json.Unmarshal([]byte(smsgMsg.Body), &msg); err != nil {
		return nil, err
	}

	return &msg, nil
}

// ConnectedPeers returns the number of connected peers.
func (t *Transport) ConnectedPeers() int {
	t.mu.RLock()
	defer t.mu.RUnlock()
	return len(t.conns)
}
diff --git a/node/worker.go b/node/worker.go
new file mode 100644
index 0000000..72eb9ff
--- /dev/null
+++ b/node/worker.go
@@ -0,0 +1,402 @@
package node

import (
	"encoding/base64"
	"encoding/json"
	"fmt"
	"path/filepath"
	"time"

	"forge.lthn.ai/core/go-p2p/logging"
	"github.com/adrg/xdg"
)

// MinerManager interface for the mining package integration.
// This allows the node package to interact with mining.Manager without import cycles.
type MinerManager interface {
	// StartMiner launches a miner of the given type with the given config.
	StartMiner(minerType string, config interface{}) (MinerInstance, error)
	// StopMiner stops a running miner by name.
	StopMiner(name string) error
	// ListMiners returns all currently managed miners.
	ListMiners() []MinerInstance
	// GetMiner looks up a managed miner by name.
	GetMiner(name string) (MinerInstance, error)
}

// MinerInstance represents a running miner for stats collection.
type MinerInstance interface {
	GetName() string
	GetType() string
	// GetStats returns implementation-defined stats; convertMinerStats
	// extracts known fields when the result is a map[string]interface{}.
	GetStats() (interface{}, error)
	// GetConsoleHistory returns up to `lines` recent console output lines.
	GetConsoleHistory(lines int) []string
}

// ProfileManager interface for profile operations.
type ProfileManager interface {
	GetProfile(id string) (interface{}, error)
	SaveProfile(profile interface{}) error
}

// Worker handles incoming messages on a worker node.
type Worker struct {
	node           *NodeManager   // local node identity / key material
	transport      *Transport     // transport used to reply to peers
	minerManager   MinerManager   // optional; nil disables miner operations
	profileManager ProfileManager // optional; nil disables profile operations
	startTime      time.Time      // used to report uptime in stats
}

// NewWorker creates a new Worker instance.
+func NewWorker(node *NodeManager, transport *Transport) *Worker { + return &Worker{ + node: node, + transport: transport, + startTime: time.Now(), + } +} + +// SetMinerManager sets the miner manager for handling miner operations. +func (w *Worker) SetMinerManager(manager MinerManager) { + w.minerManager = manager +} + +// SetProfileManager sets the profile manager for handling profile operations. +func (w *Worker) SetProfileManager(manager ProfileManager) { + w.profileManager = manager +} + +// HandleMessage processes incoming messages and returns a response. +func (w *Worker) HandleMessage(conn *PeerConnection, msg *Message) { + var response *Message + var err error + + switch msg.Type { + case MsgPing: + response, err = w.handlePing(msg) + case MsgGetStats: + response, err = w.handleGetStats(msg) + case MsgStartMiner: + response, err = w.handleStartMiner(msg) + case MsgStopMiner: + response, err = w.handleStopMiner(msg) + case MsgGetLogs: + response, err = w.handleGetLogs(msg) + case MsgDeploy: + response, err = w.handleDeploy(conn, msg) + default: + // Unknown message type - ignore or send error + return + } + + if err != nil { + // Send error response + identity := w.node.GetIdentity() + if identity != nil { + errMsg, _ := NewErrorMessage( + identity.ID, + msg.From, + ErrCodeOperationFailed, + err.Error(), + msg.ID, + ) + conn.Send(errMsg) + } + return + } + + if response != nil { + logging.Debug("sending response", logging.Fields{"type": response.Type, "to": msg.From}) + if err := conn.Send(response); err != nil { + logging.Error("failed to send response", logging.Fields{"error": err}) + } else { + logging.Debug("response sent successfully") + } + } +} + +// handlePing responds to ping requests. 
+func (w *Worker) handlePing(msg *Message) (*Message, error) { + var ping PingPayload + if err := msg.ParsePayload(&ping); err != nil { + return nil, fmt.Errorf("invalid ping payload: %w", err) + } + + pong := PongPayload{ + SentAt: ping.SentAt, + ReceivedAt: time.Now().UnixMilli(), + } + + return msg.Reply(MsgPong, pong) +} + +// handleGetStats responds with current miner statistics. +func (w *Worker) handleGetStats(msg *Message) (*Message, error) { + identity := w.node.GetIdentity() + if identity == nil { + return nil, fmt.Errorf("node identity not initialized") + } + + stats := StatsPayload{ + NodeID: identity.ID, + NodeName: identity.Name, + Miners: []MinerStatsItem{}, + Uptime: int64(time.Since(w.startTime).Seconds()), + } + + if w.minerManager != nil { + miners := w.minerManager.ListMiners() + for _, miner := range miners { + minerStats, err := miner.GetStats() + if err != nil { + continue + } + + // Convert to MinerStatsItem - this is a simplified conversion + // The actual implementation would need to match the mining package's stats structure + item := convertMinerStats(miner, minerStats) + stats.Miners = append(stats.Miners, item) + } + } + + return msg.Reply(MsgStats, stats) +} + +// convertMinerStats converts miner stats to the protocol format. 
+func convertMinerStats(miner MinerInstance, rawStats interface{}) MinerStatsItem { + item := MinerStatsItem{ + Name: miner.GetName(), + Type: miner.GetType(), + } + + // Try to extract common fields from the stats + if statsMap, ok := rawStats.(map[string]interface{}); ok { + if hashrate, ok := statsMap["hashrate"].(float64); ok { + item.Hashrate = hashrate + } + if shares, ok := statsMap["shares"].(int); ok { + item.Shares = shares + } + if rejected, ok := statsMap["rejected"].(int); ok { + item.Rejected = rejected + } + if uptime, ok := statsMap["uptime"].(int); ok { + item.Uptime = uptime + } + if pool, ok := statsMap["pool"].(string); ok { + item.Pool = pool + } + if algorithm, ok := statsMap["algorithm"].(string); ok { + item.Algorithm = algorithm + } + } + + return item +} + +// handleStartMiner starts a miner with the given profile. +func (w *Worker) handleStartMiner(msg *Message) (*Message, error) { + if w.minerManager == nil { + return nil, fmt.Errorf("miner manager not configured") + } + + var payload StartMinerPayload + if err := msg.ParsePayload(&payload); err != nil { + return nil, fmt.Errorf("invalid start miner payload: %w", err) + } + + // Validate miner type is provided + if payload.MinerType == "" { + return nil, fmt.Errorf("miner type is required") + } + + // Get the config from the profile or use the override + var config interface{} + if payload.Config != nil { + config = payload.Config + } else if w.profileManager != nil { + profile, err := w.profileManager.GetProfile(payload.ProfileID) + if err != nil { + return nil, fmt.Errorf("profile not found: %s", payload.ProfileID) + } + config = profile + } else { + return nil, fmt.Errorf("no config provided and no profile manager configured") + } + + // Start the miner + miner, err := w.minerManager.StartMiner(payload.MinerType, config) + if err != nil { + ack := MinerAckPayload{ + Success: false, + Error: err.Error(), + } + return msg.Reply(MsgMinerAck, ack) + } + + ack := MinerAckPayload{ + 
Success: true, + MinerName: miner.GetName(), + } + return msg.Reply(MsgMinerAck, ack) +} + +// handleStopMiner stops a running miner. +func (w *Worker) handleStopMiner(msg *Message) (*Message, error) { + if w.minerManager == nil { + return nil, fmt.Errorf("miner manager not configured") + } + + var payload StopMinerPayload + if err := msg.ParsePayload(&payload); err != nil { + return nil, fmt.Errorf("invalid stop miner payload: %w", err) + } + + err := w.minerManager.StopMiner(payload.MinerName) + ack := MinerAckPayload{ + Success: err == nil, + MinerName: payload.MinerName, + } + if err != nil { + ack.Error = err.Error() + } + + return msg.Reply(MsgMinerAck, ack) +} + +// handleGetLogs returns console logs from a miner. +func (w *Worker) handleGetLogs(msg *Message) (*Message, error) { + if w.minerManager == nil { + return nil, fmt.Errorf("miner manager not configured") + } + + var payload GetLogsPayload + if err := msg.ParsePayload(&payload); err != nil { + return nil, fmt.Errorf("invalid get logs payload: %w", err) + } + + // Validate and limit the Lines parameter to prevent resource exhaustion + const maxLogLines = 10000 + if payload.Lines <= 0 || payload.Lines > maxLogLines { + payload.Lines = maxLogLines + } + + miner, err := w.minerManager.GetMiner(payload.MinerName) + if err != nil { + return nil, fmt.Errorf("miner not found: %s", payload.MinerName) + } + + lines := miner.GetConsoleHistory(payload.Lines) + + logs := LogsPayload{ + MinerName: payload.MinerName, + Lines: lines, + HasMore: len(lines) >= payload.Lines, + } + + return msg.Reply(MsgLogs, logs) +} + +// handleDeploy handles deployment of profiles or miner bundles. 
+func (w *Worker) handleDeploy(conn *PeerConnection, msg *Message) (*Message, error) { + var payload DeployPayload + if err := msg.ParsePayload(&payload); err != nil { + return nil, fmt.Errorf("invalid deploy payload: %w", err) + } + + // Reconstruct Bundle object from payload + bundle := &Bundle{ + Type: BundleType(payload.BundleType), + Name: payload.Name, + Data: payload.Data, + Checksum: payload.Checksum, + } + + // Use shared secret as password (base64 encoded) + password := "" + if conn != nil && len(conn.SharedSecret) > 0 { + password = base64.StdEncoding.EncodeToString(conn.SharedSecret) + } + + switch bundle.Type { + case BundleProfile: + if w.profileManager == nil { + return nil, fmt.Errorf("profile manager not configured") + } + + // Decrypt and extract profile data + profileData, err := ExtractProfileBundle(bundle, password) + if err != nil { + return nil, fmt.Errorf("failed to extract profile bundle: %w", err) + } + + // Unmarshal into interface{} to pass to ProfileManager + var profile interface{} + if err := json.Unmarshal(profileData, &profile); err != nil { + return nil, fmt.Errorf("invalid profile data JSON: %w", err) + } + + if err := w.profileManager.SaveProfile(profile); err != nil { + ack := DeployAckPayload{ + Success: false, + Name: payload.Name, + Error: err.Error(), + } + return msg.Reply(MsgDeployAck, ack) + } + + ack := DeployAckPayload{ + Success: true, + Name: payload.Name, + } + return msg.Reply(MsgDeployAck, ack) + + case BundleMiner, BundleFull: + // Determine installation directory + // We use xdg.DataHome/lethean-desktop/miners/ + minersDir := filepath.Join(xdg.DataHome, "lethean-desktop", "miners") + installDir := filepath.Join(minersDir, payload.Name) + + logging.Info("deploying miner bundle", logging.Fields{ + "name": payload.Name, + "path": installDir, + "type": payload.BundleType, + }) + + // Extract miner bundle + minerPath, profileData, err := ExtractMinerBundle(bundle, password, installDir) + if err != nil { + return nil, 
fmt.Errorf("failed to extract miner bundle: %w", err) + } + + // If the bundle contained a profile config, save it + if len(profileData) > 0 && w.profileManager != nil { + var profile interface{} + if err := json.Unmarshal(profileData, &profile); err != nil { + logging.Warn("failed to parse profile from miner bundle", logging.Fields{"error": err}) + } else { + if err := w.profileManager.SaveProfile(profile); err != nil { + logging.Warn("failed to save profile from miner bundle", logging.Fields{"error": err}) + } + } + } + + // Success response + ack := DeployAckPayload{ + Success: true, + Name: payload.Name, + } + + // Log the installation + logging.Info("miner bundle installed successfully", logging.Fields{ + "name": payload.Name, + "miner_path": minerPath, + }) + + return msg.Reply(MsgDeployAck, ack) + + default: + return nil, fmt.Errorf("unknown bundle type: %s", payload.BundleType) + } +} + +// RegisterWithTransport registers the worker's message handler with the transport. +func (w *Worker) RegisterWithTransport() { + w.transport.OnMessage(w.HandleMessage) +} diff --git a/node/worker_test.go b/node/worker_test.go new file mode 100644 index 0000000..d27da0c --- /dev/null +++ b/node/worker_test.go @@ -0,0 +1,513 @@ +package node + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +// setupTestEnv sets up a temporary environment for testing and returns cleanup function +func setupTestEnv(t *testing.T) func() { + tmpDir := t.TempDir() + os.Setenv("XDG_CONFIG_HOME", filepath.Join(tmpDir, "config")) + os.Setenv("XDG_DATA_HOME", filepath.Join(tmpDir, "data")) + return func() { + os.Unsetenv("XDG_CONFIG_HOME") + os.Unsetenv("XDG_DATA_HOME") + } +} + +func TestNewWorker(t *testing.T) { + cleanup := setupTestEnv(t) + defer cleanup() + + nm, err := NewNodeManager() + if err != nil { + t.Fatalf("failed to create node manager: %v", err) + } + if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { + t.Fatalf("failed to generate identity: %v", err) 
+ } + + pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + if err != nil { + t.Fatalf("failed to create peer registry: %v", err) + } + + transport := NewTransport(nm, pr, DefaultTransportConfig()) + worker := NewWorker(nm, transport) + + if worker == nil { + t.Fatal("NewWorker returned nil") + } + if worker.node != nm { + t.Error("worker.node not set correctly") + } + if worker.transport != transport { + t.Error("worker.transport not set correctly") + } +} + +func TestWorker_SetMinerManager(t *testing.T) { + cleanup := setupTestEnv(t) + defer cleanup() + + nm, err := NewNodeManager() + if err != nil { + t.Fatalf("failed to create node manager: %v", err) + } + if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { + t.Fatalf("failed to generate identity: %v", err) + } + + pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + if err != nil { + t.Fatalf("failed to create peer registry: %v", err) + } + + transport := NewTransport(nm, pr, DefaultTransportConfig()) + worker := NewWorker(nm, transport) + + mockManager := &mockMinerManager{} + worker.SetMinerManager(mockManager) + + if worker.minerManager != mockManager { + t.Error("minerManager not set correctly") + } +} + +func TestWorker_SetProfileManager(t *testing.T) { + cleanup := setupTestEnv(t) + defer cleanup() + + nm, err := NewNodeManager() + if err != nil { + t.Fatalf("failed to create node manager: %v", err) + } + if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { + t.Fatalf("failed to generate identity: %v", err) + } + + pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + if err != nil { + t.Fatalf("failed to create peer registry: %v", err) + } + + transport := NewTransport(nm, pr, DefaultTransportConfig()) + worker := NewWorker(nm, transport) + + mockProfile := &mockProfileManager{} + worker.SetProfileManager(mockProfile) + + if worker.profileManager != mockProfile { + t.Error("profileManager not set correctly") + } +} + +func 
TestWorker_HandlePing(t *testing.T) { + cleanup := setupTestEnv(t) + defer cleanup() + + nm, err := NewNodeManager() + if err != nil { + t.Fatalf("failed to create node manager: %v", err) + } + if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { + t.Fatalf("failed to generate identity: %v", err) + } + + pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + if err != nil { + t.Fatalf("failed to create peer registry: %v", err) + } + + transport := NewTransport(nm, pr, DefaultTransportConfig()) + worker := NewWorker(nm, transport) + + // Create a ping message + identity := nm.GetIdentity() + if identity == nil { + t.Fatal("expected identity to be generated") + } + pingPayload := PingPayload{SentAt: time.Now().UnixMilli()} + pingMsg, err := NewMessage(MsgPing, "sender-id", identity.ID, pingPayload) + if err != nil { + t.Fatalf("failed to create ping message: %v", err) + } + + // Call handlePing directly + response, err := worker.handlePing(pingMsg) + if err != nil { + t.Fatalf("handlePing returned error: %v", err) + } + + if response == nil { + t.Fatal("handlePing returned nil response") + } + + if response.Type != MsgPong { + t.Errorf("expected response type %s, got %s", MsgPong, response.Type) + } + + var pong PongPayload + if err := response.ParsePayload(&pong); err != nil { + t.Fatalf("failed to parse pong payload: %v", err) + } + + if pong.SentAt != pingPayload.SentAt { + t.Errorf("pong SentAt mismatch: expected %d, got %d", pingPayload.SentAt, pong.SentAt) + } + + if pong.ReceivedAt == 0 { + t.Error("pong ReceivedAt not set") + } +} + +func TestWorker_HandleGetStats(t *testing.T) { + cleanup := setupTestEnv(t) + defer cleanup() + + nm, err := NewNodeManager() + if err != nil { + t.Fatalf("failed to create node manager: %v", err) + } + if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { + t.Fatalf("failed to generate identity: %v", err) + } + + pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + if 
err != nil { + t.Fatalf("failed to create peer registry: %v", err) + } + + transport := NewTransport(nm, pr, DefaultTransportConfig()) + worker := NewWorker(nm, transport) + + // Create a get_stats message + identity := nm.GetIdentity() + if identity == nil { + t.Fatal("expected identity to be generated") + } + msg, err := NewMessage(MsgGetStats, "sender-id", identity.ID, nil) + if err != nil { + t.Fatalf("failed to create get_stats message: %v", err) + } + + // Call handleGetStats directly (without miner manager) + response, err := worker.handleGetStats(msg) + if err != nil { + t.Fatalf("handleGetStats returned error: %v", err) + } + + if response == nil { + t.Fatal("handleGetStats returned nil response") + } + + if response.Type != MsgStats { + t.Errorf("expected response type %s, got %s", MsgStats, response.Type) + } + + var stats StatsPayload + if err := response.ParsePayload(&stats); err != nil { + t.Fatalf("failed to parse stats payload: %v", err) + } + + if stats.NodeID != identity.ID { + t.Errorf("stats NodeID mismatch: expected %s, got %s", identity.ID, stats.NodeID) + } + + if stats.NodeName != identity.Name { + t.Errorf("stats NodeName mismatch: expected %s, got %s", identity.Name, stats.NodeName) + } +} + +func TestWorker_HandleStartMiner_NoManager(t *testing.T) { + cleanup := setupTestEnv(t) + defer cleanup() + + nm, err := NewNodeManager() + if err != nil { + t.Fatalf("failed to create node manager: %v", err) + } + if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { + t.Fatalf("failed to generate identity: %v", err) + } + + pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + if err != nil { + t.Fatalf("failed to create peer registry: %v", err) + } + + transport := NewTransport(nm, pr, DefaultTransportConfig()) + worker := NewWorker(nm, transport) + + // Create a start_miner message + identity := nm.GetIdentity() + if identity == nil { + t.Fatal("expected identity to be generated") + } + payload := 
StartMinerPayload{MinerType: "xmrig", ProfileID: "test-profile"} + msg, err := NewMessage(MsgStartMiner, "sender-id", identity.ID, payload) + if err != nil { + t.Fatalf("failed to create start_miner message: %v", err) + } + + // Without miner manager, should return error + _, err = worker.handleStartMiner(msg) + if err == nil { + t.Error("expected error when miner manager is nil") + } +} + +func TestWorker_HandleStopMiner_NoManager(t *testing.T) { + cleanup := setupTestEnv(t) + defer cleanup() + + nm, err := NewNodeManager() + if err != nil { + t.Fatalf("failed to create node manager: %v", err) + } + if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { + t.Fatalf("failed to generate identity: %v", err) + } + + pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + if err != nil { + t.Fatalf("failed to create peer registry: %v", err) + } + + transport := NewTransport(nm, pr, DefaultTransportConfig()) + worker := NewWorker(nm, transport) + + // Create a stop_miner message + identity := nm.GetIdentity() + if identity == nil { + t.Fatal("expected identity to be generated") + } + payload := StopMinerPayload{MinerName: "test-miner"} + msg, err := NewMessage(MsgStopMiner, "sender-id", identity.ID, payload) + if err != nil { + t.Fatalf("failed to create stop_miner message: %v", err) + } + + // Without miner manager, should return error + _, err = worker.handleStopMiner(msg) + if err == nil { + t.Error("expected error when miner manager is nil") + } +} + +func TestWorker_HandleGetLogs_NoManager(t *testing.T) { + cleanup := setupTestEnv(t) + defer cleanup() + + nm, err := NewNodeManager() + if err != nil { + t.Fatalf("failed to create node manager: %v", err) + } + if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { + t.Fatalf("failed to generate identity: %v", err) + } + + pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + if err != nil { + t.Fatalf("failed to create peer registry: %v", err) + } + + transport := 
NewTransport(nm, pr, DefaultTransportConfig()) + worker := NewWorker(nm, transport) + + // Create a get_logs message + identity := nm.GetIdentity() + if identity == nil { + t.Fatal("expected identity to be generated") + } + payload := GetLogsPayload{MinerName: "test-miner", Lines: 100} + msg, err := NewMessage(MsgGetLogs, "sender-id", identity.ID, payload) + if err != nil { + t.Fatalf("failed to create get_logs message: %v", err) + } + + // Without miner manager, should return error + _, err = worker.handleGetLogs(msg) + if err == nil { + t.Error("expected error when miner manager is nil") + } +} + +func TestWorker_HandleDeploy_Profile(t *testing.T) { + cleanup := setupTestEnv(t) + defer cleanup() + + nm, err := NewNodeManager() + if err != nil { + t.Fatalf("failed to create node manager: %v", err) + } + if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { + t.Fatalf("failed to generate identity: %v", err) + } + + pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + if err != nil { + t.Fatalf("failed to create peer registry: %v", err) + } + + transport := NewTransport(nm, pr, DefaultTransportConfig()) + worker := NewWorker(nm, transport) + + // Create a deploy message for profile + identity := nm.GetIdentity() + if identity == nil { + t.Fatal("expected identity to be generated") + } + payload := DeployPayload{ + BundleType: "profile", + Data: []byte(`{"id": "test", "name": "Test Profile"}`), + Name: "test-profile", + } + msg, err := NewMessage(MsgDeploy, "sender-id", identity.ID, payload) + if err != nil { + t.Fatalf("failed to create deploy message: %v", err) + } + + // Without profile manager, should return error + _, err = worker.handleDeploy(nil, msg) + if err == nil { + t.Error("expected error when profile manager is nil") + } +} + +func TestWorker_HandleDeploy_UnknownType(t *testing.T) { + cleanup := setupTestEnv(t) + defer cleanup() + + nm, err := NewNodeManager() + if err != nil { + t.Fatalf("failed to create node manager: 
%v", err)
	}
	if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil {
		t.Fatalf("failed to generate identity: %v", err)
	}

	pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json")
	if err != nil {
		t.Fatalf("failed to create peer registry: %v", err)
	}

	transport := NewTransport(nm, pr, DefaultTransportConfig())
	worker := NewWorker(nm, transport)

	// Build a deploy message with a bundle type the worker does not know.
	identity := nm.GetIdentity()
	if identity == nil {
		t.Fatal("expected identity to be generated")
	}
	payload := DeployPayload{
		BundleType: "unknown",
		Data:       []byte(`{}`),
		Name:       "test",
	}
	msg, err := NewMessage(MsgDeploy, "sender-id", identity.ID, payload)
	if err != nil {
		t.Fatalf("failed to create deploy message: %v", err)
	}

	_, err = worker.handleDeploy(nil, msg)
	if err == nil {
		t.Error("expected error for unknown bundle type")
	}
}

// TestConvertMinerStats exercises convertMinerStats against a well-formed
// stats map, an empty map, and a non-map value.
func TestConvertMinerStats(t *testing.T) {
	tests := []struct {
		name     string
		rawStats interface{}
		wantHash float64
	}{
		{
			name: "MapWithHashrate",
			rawStats: map[string]interface{}{
				"hashrate":  100.5,
				"shares":    10,
				"rejected":  2,
				"uptime":    3600,
				"pool":      "test-pool",
				"algorithm": "rx/0",
			},
			wantHash: 100.5,
		},
		{
			name:     "EmptyMap",
			rawStats: map[string]interface{}{},
			wantHash: 0,
		},
		{
			name:     "NonMap",
			rawStats: "not a map",
			wantHash: 0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mock := &mockMinerInstance{name: "test", minerType: "xmrig"}
			result := convertMinerStats(mock, tt.rawStats)

			if result.Name != "test" {
				t.Errorf("expected name 'test', got '%s'", result.Name)
			}
			if result.Hashrate != tt.wantHash {
				t.Errorf("expected hashrate %f, got %f", tt.wantHash, result.Hashrate)
			}
		})
	}
}

// Mock implementations for testing

// mockMinerManager holds a canned list of miner instances and performs no
// real process control.
type mockMinerManager struct {
	miners []MinerInstance
}

func (m *mockMinerManager) StartMiner(minerType string, config interface{}) (MinerInstance, 
error) { + return nil, nil +} + +func (m *mockMinerManager) StopMiner(name string) error { + return nil +} + +func (m *mockMinerManager) ListMiners() []MinerInstance { + return m.miners +} + +func (m *mockMinerManager) GetMiner(name string) (MinerInstance, error) { + for _, miner := range m.miners { + if miner.GetName() == name { + return miner, nil + } + } + return nil, nil +} + +type mockMinerInstance struct { + name string + minerType string + stats interface{} +} + +func (m *mockMinerInstance) GetName() string { return m.name } +func (m *mockMinerInstance) GetType() string { return m.minerType } +func (m *mockMinerInstance) GetStats() (interface{}, error) { return m.stats, nil } +func (m *mockMinerInstance) GetConsoleHistory(lines int) []string { return []string{} } + +type mockProfileManager struct{} + +func (m *mockProfileManager) GetProfile(id string) (interface{}, error) { + return nil, nil +} + +func (m *mockProfileManager) SaveProfile(profile interface{}) error { + return nil +} diff --git a/ueps/packet.go b/ueps/packet.go new file mode 100644 index 0000000..7c75334 --- /dev/null +++ b/ueps/packet.go @@ -0,0 +1,124 @@ +package ueps + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/binary" + "errors" + "io" +) + +// TLV Types +const ( + TagVersion = 0x01 + TagCurrentLay = 0x02 + TagTargetLay = 0x03 + TagIntent = 0x04 + TagThreatScore = 0x05 + TagHMAC = 0x06 // The Signature + TagPayload = 0xFF // The Data +) + +// UEPSHeader represents the conscious routing metadata +type UEPSHeader struct { + Version uint8 // Default 0x09 + CurrentLayer uint8 + TargetLayer uint8 + IntentID uint8 // Semantic Token + ThreatScore uint16 // 0-65535 +} + +// PacketBuilder helps construct a signed UEPS frame +type PacketBuilder struct { + Header UEPSHeader + Payload []byte +} + +// NewBuilder creates a packet context for a specific intent +func NewBuilder(intentID uint8, payload []byte) *PacketBuilder { + return &PacketBuilder{ + Header: UEPSHeader{ + Version: 
0x09, // IPv9
			CurrentLayer: 5, // application layer
			TargetLayer:  5, // application layer
			IntentID:     intentID,
			ThreatScore:  0, // assumed innocent until proven guilty
		},
		Payload: payload,
	}
}

// MarshalAndSign generates the final byte stream using the shared secret.
// Frame layout: header TLVs (0x01-0x05), then an HMAC TLV (0x06) signing
// those headers plus the raw payload, then the payload after a bare 0xFF tag.
func (p *PacketBuilder) MarshalAndSign(sharedSecret []byte) ([]byte, error) {
	buf := new(bytes.Buffer)

	// ThreatScore is a uint16, packed big-endian.
	score := make([]byte, 2)
	binary.BigEndian.PutUint16(score, p.Header.ThreatScore)

	// 1. Header TLVs (0x01-0x05) come first: they are exactly the bytes
	// the signature covers, in this order.
	headerTLVs := []struct {
		tag   uint8
		value []byte
	}{
		{TagVersion, []byte{p.Header.Version}},
		{TagCurrentLay, []byte{p.Header.CurrentLayer}},
		{TagTargetLay, []byte{p.Header.TargetLayer}},
		{TagIntent, []byte{p.Header.IntentID}},
		{TagThreatScore, score},
	}
	for _, tlv := range headerTLVs {
		if err := writeTLV(buf, tlv.tag, tlv.value); err != nil {
			return nil, err
		}
	}

	// 2. Sign the header TLVs written so far plus the raw payload.
	// The HMAC TLV itself is (necessarily) excluded from the signed data.
	mac := hmac.New(sha256.New, sharedSecret)
	mac.Write(buf.Bytes()) // the headers so far
	mac.Write(p.Payload)   // the data
	signature := mac.Sum(nil)

	// 3. Write the HMAC TLV (0x06); SHA-256 yields a 32-byte value.
	if err := writeTLV(buf, TagHMAC, signature); err != nil {
		return nil, err
	}

	// 4. Write the payload (0xFF). A single length byte cannot describe a
	// payload over 255 bytes, so the payload is not written as a strict TLV.
+ + buf.WriteByte(TagPayload) + // We don't write a 1-byte length for payload here assuming stream mode, + // but if strict TLV, we'd need a multi-byte length protocol. + // For this snippet, simply appending data: + buf.Write(p.Payload) + + return buf.Bytes(), nil +} + +// Helper to write a simple TLV +func writeTLV(w io.Writer, tag uint8, value []byte) error { + // Check strict length constraint (1 byte length = max 255 bytes) + if len(value) > 255 { + return errors.New("TLV value too large for 1-byte length header") + } + + if _, err := w.Write([]byte{tag}); err != nil { + return err + } + if _, err := w.Write([]byte{uint8(len(value))}); err != nil { + return err + } + if _, err := w.Write(value); err != nil { + return err + } + return nil +} diff --git a/ueps/reader.go b/ueps/reader.go new file mode 100644 index 0000000..d17b332 --- /dev/null +++ b/ueps/reader.go @@ -0,0 +1,138 @@ +package ueps + +import ( + "bufio" + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/binary" + "errors" + "fmt" + "io" +) + +// ParsedPacket holds the verified data +type ParsedPacket struct { + Header UEPSHeader + Payload []byte +} + +// ReadAndVerify reads a UEPS frame from the stream and validates the HMAC. +// It consumes the stream up to the end of the packet. +func ReadAndVerify(r *bufio.Reader, sharedSecret []byte) (*ParsedPacket, error) { + // Buffer to reconstruct the data for HMAC verification + // We have to "record" what we read to verify the signature later. + var signedData bytes.Buffer + header := UEPSHeader{} + var signature []byte + var payload []byte + + // Loop through TLVs until we hit Payload (0xFF) or EOF + for { + // 1. Read Tag + tag, err := r.ReadByte() + if err != nil { + return nil, err + } + + // 2. Handle Payload Tag (0xFF) - The Exit Condition + if tag == TagPayload { + // Stop recording signedData here (HMAC covers headers + payload, but logic splits) + // Actually, wait. The HMAC covers (Headers + Payload). 
+ // We need to read the payload to verify. + + // For this implementation, we read until EOF or a specific delimiter? + // In a TCP stream, we need a length. + // If you are using standard TCP, you typically prefix the WHOLE frame with + // a 4-byte length. Assuming you handle that framing *before* calling this. + + // Reading the rest as payload: + remaining, err := io.ReadAll(r) + if err != nil { + return nil, err + } + payload = remaining + + // Add 0xFF and payload to the buffer for signature check? + // NO. In MarshalAndSign: + // mac.Write(buf.Bytes()) // Headers + // mac.Write(p.Payload) // Data + // It did NOT write the 0xFF tag into the HMAC. + + break // Exit loop + } + + // 3. Read Length (Standard TLV) + lengthByte, err := r.ReadByte() + if err != nil { + return nil, err + } + length := int(lengthByte) + + // 4. Read Value + value := make([]byte, length) + if _, err := io.ReadFull(r, value); err != nil { + return nil, err + } + + // Store for processing + switch tag { + case TagVersion: + header.Version = value[0] + // Reconstruct signed data: Tag + Len + Val + signedData.WriteByte(tag) + signedData.WriteByte(byte(length)) + signedData.Write(value) + case TagCurrentLay: + header.CurrentLayer = value[0] + signedData.WriteByte(tag) + signedData.WriteByte(byte(length)) + signedData.Write(value) + case TagTargetLay: + header.TargetLayer = value[0] + signedData.WriteByte(tag) + signedData.WriteByte(byte(length)) + signedData.Write(value) + case TagIntent: + header.IntentID = value[0] + signedData.WriteByte(tag) + signedData.WriteByte(byte(length)) + signedData.Write(value) + case TagThreatScore: + header.ThreatScore = binary.BigEndian.Uint16(value) + signedData.WriteByte(tag) + signedData.WriteByte(byte(length)) + signedData.Write(value) + case TagHMAC: + signature = value + // We do NOT add the HMAC itself to signedData + default: + // Unknown tag (future proofing), verify it but ignore semantics + signedData.WriteByte(tag) + 
signedData.WriteByte(byte(length)) + signedData.Write(value) + } + } + + if len(signature) == 0 { + return nil, errors.New("UEPS packet missing HMAC signature") + } + + // 5. Verify HMAC + // Reconstruct: Headers (signedData) + Payload + mac := hmac.New(sha256.New, sharedSecret) + mac.Write(signedData.Bytes()) + mac.Write(payload) + expectedMAC := mac.Sum(nil) + + if !hmac.Equal(signature, expectedMAC) { + // Log this. This is a Threat Event. + // "Axiom Violation: Integrity Check Failed" + return nil, fmt.Errorf("integrity violation: HMAC mismatch (ThreatScore +100)") + } + + return &ParsedPacket{ + Header: header, + Payload: payload, + }, nil +}