Compare commits

..

1 commit
dev ... main

Author SHA1 Message Date
Virgil
61fb52e8f2 fix(node): harden bundle stream error handling
All checks were successful
Security Scan / security (push) Successful in 14s
Test / test (push) Successful in 2m13s
Co-Authored-By: Virgil <virgil@lethean.io>
2026-03-29 15:33:52 +00:00
12 changed files with 50 additions and 343 deletions

View file

@ -98,7 +98,7 @@ The `Transport` manages a WebSocket server (gorilla/websocket) and outbound conn
| Timeout | 3.0 (floored at 0) |
| Default (new peer) | 50.0 |
**Peer name validation**: Empty names are permitted. Non-empty names must be 164 characters, start and end with an alphanumeric character, and contain only alphanumeric, hyphen, underscore, or space characters.
**Peer name validation**: Names must be 164 characters, start and end with an alphanumeric character, and contain only alphanumeric, hyphen, underscore, or space characters.
### message.go — Protocol Messages

View file

@ -237,15 +237,15 @@ func createTarball(files map[string][]byte) ([]byte, error) {
Size: int64(len(content)),
}
if err := tw.WriteHeader(hdr); err != nil {
return nil, err
return nil, coreerr.E("createTarball", "failed to write tar header", err)
}
if _, err := tw.Write(content); err != nil {
return nil, err
return nil, coreerr.E("createTarball", "failed to write tar content", err)
}
}
if err := tw.Close(); err != nil {
return nil, err
return nil, coreerr.E("createTarball", "failed to close tar writer", err)
}
return buf.Bytes(), nil
@ -261,11 +261,11 @@ func extractTarball(tarData []byte, destDir string) (string, error) {
absDestDir = filepath.Clean(absDestDir)
if err := coreio.Local.EnsureDir(absDestDir); err != nil {
return "", err
return "", coreerr.E("extractTarball", "failed to ensure destination directory", err)
}
tr := tar.NewReader(bytes.NewReader(tarData))
var firstExecutable string
var firstExecutablePath string
for {
hdr, err := tr.Next()
@ -273,7 +273,7 @@ func extractTarball(tarData []byte, destDir string) (string, error) {
break
}
if err != nil {
return "", err
return "", coreerr.E("extractTarball", "failed to read tar entry", err)
}
// Security: Sanitize the tar entry name to prevent path traversal (Zip Slip)
@ -301,12 +301,12 @@ func extractTarball(tarData []byte, destDir string) (string, error) {
switch hdr.Typeflag {
case tar.TypeDir:
if err := coreio.Local.EnsureDir(fullPath); err != nil {
return "", err
return "", coreerr.E("extractTarball", "failed to create directory "+cleanName, err)
}
case tar.TypeReg:
// Ensure parent directory exists
if err := coreio.Local.EnsureDir(filepath.Dir(fullPath)); err != nil {
return "", err
return "", coreerr.E("extractTarball", "failed to create parent directory for "+cleanName, err)
}
// os.OpenFile is used deliberately here instead of coreio.Local.Create/Write
@ -321,18 +321,24 @@ func extractTarball(tarData []byte, destDir string) (string, error) {
const maxFileSize int64 = 100 * 1024 * 1024
limitedReader := io.LimitReader(tr, maxFileSize+1)
written, err := io.Copy(f, limitedReader)
f.Close()
if err != nil {
_ = f.Close()
return "", coreerr.E("extractTarball", "failed to write file "+hdr.Name, err)
}
if err := f.Close(); err != nil {
return "", coreerr.E("extractTarball", "failed to close extracted file "+hdr.Name, err)
}
if written > maxFileSize {
coreio.Local.Delete(fullPath)
if err := coreio.Local.Delete(fullPath); err != nil {
return "", coreerr.E("extractTarball", "failed to clean up oversized file "+hdr.Name, err)
}
return "", coreerr.E("extractTarball", "file "+hdr.Name+" exceeds maximum size", nil)
}
// Track first executable
if hdr.Mode&0111 != 0 && firstExecutable == "" {
firstExecutable = fullPath
if hdr.Mode&0111 != 0 && firstExecutablePath == "" {
firstExecutablePath = fullPath
}
// Explicitly ignore symlinks and hard links to prevent symlink attacks
case tar.TypeSymlink, tar.TypeLink:
@ -341,13 +347,17 @@ func extractTarball(tarData []byte, destDir string) (string, error) {
}
}
return firstExecutable, nil
return firstExecutablePath, nil
}
// StreamBundle writes a bundle to a writer (for large transfers).
func StreamBundle(bundle *Bundle, w io.Writer) error {
encoder := json.NewEncoder(w)
return encoder.Encode(bundle)
if err := encoder.Encode(bundle); err != nil {
return coreerr.E("StreamBundle", "failed to encode bundle", err)
}
return nil
}
// ReadBundle reads a bundle from a reader.
@ -355,7 +365,7 @@ func ReadBundle(r io.Reader) (*Bundle, error) {
var bundle Bundle
decoder := json.NewDecoder(r)
if err := decoder.Decode(&bundle); err != nil {
return nil, err
return nil, coreerr.E("ReadBundle", "failed to decode bundle", err)
}
return &bundle, nil
}

View file

@ -210,11 +210,6 @@ func (c *Controller) StopRemoteMiner(peerID, minerName string) error {
// GetRemoteLogs requests console logs from a remote miner.
func (c *Controller) GetRemoteLogs(peerID, minerName string, lines int) ([]string, error) {
return c.GetRemoteLogsSince(peerID, minerName, lines, time.Time{})
}
// GetRemoteLogsSince requests console logs from a remote miner after a point in time.
func (c *Controller) GetRemoteLogsSince(peerID, minerName string, lines int, since time.Time) ([]string, error) {
identity := c.node.GetIdentity()
if identity == nil {
return nil, ErrIdentityNotInitialized
@ -224,13 +219,10 @@ func (c *Controller) GetRemoteLogsSince(peerID, minerName string, lines int, sin
MinerName: minerName,
Lines: lines,
}
if !since.IsZero() {
payload.Since = since.UnixMilli()
}
msg, err := NewMessage(MsgGetLogs, identity.ID, peerID, payload)
if err != nil {
return nil, coreerr.E("Controller.GetRemoteLogsSince", "failed to create message", err)
return nil, coreerr.E("Controller.GetRemoteLogs", "failed to create message", err)
}
resp, err := c.sendRequest(peerID, msg, 10*time.Second)

View file

@ -7,7 +7,6 @@ import (
"net/http/httptest"
"net/url"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
@ -515,40 +514,6 @@ type mockMinerFull struct {
func (m *mockMinerFull) GetName() string { return m.name }
func (m *mockMinerFull) GetType() string { return m.minerType }
func (m *mockMinerFull) GetStats() (any, error) { return m.stats, nil }
func (m *mockMinerFull) GetConsoleHistorySince(lines int, since time.Time) []string {
if since.IsZero() {
if lines >= len(m.consoleHistory) {
return m.consoleHistory
}
return m.consoleHistory[:lines]
}
filtered := make([]string, 0, len(m.consoleHistory))
for _, line := range m.consoleHistory {
if lineAfter(line, since) {
filtered = append(filtered, line)
}
}
if lines >= len(filtered) {
return filtered
}
return filtered[:lines]
}
func lineAfter(line string, since time.Time) bool {
start := strings.IndexByte(line, '[')
end := strings.IndexByte(line, ']')
if start != 0 || end <= start+1 {
return true
}
ts, err := time.Parse("2006-01-02 15:04:05", line[start+1:end])
if err != nil {
return true
}
return ts.After(since) || ts.Equal(since)
}
func (m *mockMinerFull) GetConsoleHistory(lines int) []string {
if lines >= len(m.consoleHistory) {
return m.consoleHistory
@ -651,20 +616,6 @@ func TestController_GetRemoteLogs_LimitedLines(t *testing.T) {
assert.Len(t, lines, 1, "should return only 1 line")
}
func TestController_GetRemoteLogsSince(t *testing.T) {
controller, _, tp := setupControllerPairWithMiner(t)
serverID := tp.ServerNode.GetIdentity().ID
since, err := time.Parse("2006-01-02 15:04:05", "2026-02-20 10:00:01")
require.NoError(t, err)
lines, err := controller.GetRemoteLogsSince(serverID, "running-miner", 10, since)
require.NoError(t, err, "GetRemoteLogsSince should succeed")
require.Len(t, lines, 2, "should return only log lines on or after the requested timestamp")
assert.Contains(t, lines[0], "connected to pool")
assert.Contains(t, lines[1], "new job received")
}
func TestController_GetRemoteLogs_NoIdentity(t *testing.T) {
tp := setupTestTransportPair(t)
nmNoID, err := NewNodeManagerWithPaths(

View file

@ -8,7 +8,6 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"os"
"path/filepath"
"sync"
"time"
@ -109,48 +108,6 @@ func NewNodeManagerWithPaths(keyPath, configPath string) (*NodeManager, error) {
return nm, nil
}
// LoadOrCreateIdentity loads the node identity from the default XDG paths or
// generates a new dual-role identity when none exists yet.
func LoadOrCreateIdentity() (*NodeManager, error) {
keyPath, err := xdg.DataFile("lethean-desktop/node/private.key")
if err != nil {
return nil, coreerr.E("LoadOrCreateIdentity", "failed to get key path", err)
}
configPath, err := xdg.ConfigFile("lethean-desktop/node.json")
if err != nil {
return nil, coreerr.E("LoadOrCreateIdentity", "failed to get config path", err)
}
return LoadOrCreateIdentityWithPaths(keyPath, configPath)
}
// LoadOrCreateIdentityWithPaths loads an existing identity from the supplied
// paths or creates a new dual-role identity if no persisted identity exists.
// The generated identity name falls back to the host name, then a stable
// project-specific default if the host name cannot be determined.
func LoadOrCreateIdentityWithPaths(keyPath, configPath string) (*NodeManager, error) {
nm, err := NewNodeManagerWithPaths(keyPath, configPath)
if err != nil {
return nil, err
}
if nm.HasIdentity() {
return nm, nil
}
name, err := os.Hostname()
if err != nil || name == "" {
name = "lethean-node"
}
if err := nm.GenerateIdentity(name, RoleDual); err != nil {
return nil, coreerr.E("LoadOrCreateIdentityWithPaths", "failed to generate identity", err)
}
return nm, nil
}
// HasIdentity returns true if a node identity has been initialized.
func (n *NodeManager) HasIdentity() bool {
n.mu.RLock()
@ -251,13 +208,10 @@ func (n *NodeManager) savePrivateKey() error {
return coreerr.E("NodeManager.savePrivateKey", "failed to create key directory", err)
}
// Write private key and then tighten permissions explicitly.
// Write private key
if err := coreio.Local.Write(n.keyPath, string(n.privateKey)); err != nil {
return coreerr.E("NodeManager.savePrivateKey", "failed to write private key", err)
}
if err := os.Chmod(n.keyPath, 0600); err != nil {
return coreerr.E("NodeManager.savePrivateKey", "failed to set private key permissions", err)
}
return nil
}

View file

@ -74,25 +74,6 @@ func TestNodeIdentity(t *testing.T) {
}
})
t.Run("PrivateKeyPermissions", func(t *testing.T) {
nm, cleanup := setupTestNodeManager(t)
defer cleanup()
err := nm.GenerateIdentity("permission-test", RoleDual)
if err != nil {
t.Fatalf("failed to generate identity: %v", err)
}
info, err := os.Stat(nm.keyPath)
if err != nil {
t.Fatalf("failed to stat private key: %v", err)
}
if got := info.Mode().Perm(); got != 0600 {
t.Fatalf("expected private key permissions 0600, got %04o", got)
}
})
t.Run("LoadExistingIdentity", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "node-load-test")
if err != nil {
@ -215,47 +196,6 @@ func TestNodeIdentity(t *testing.T) {
t.Error("should not have identity after delete")
}
})
t.Run("LoadOrCreateIdentityWithPaths", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "node-load-or-create-test")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
keyPath := filepath.Join(tmpDir, "private.key")
configPath := filepath.Join(tmpDir, "node.json")
nm, err := LoadOrCreateIdentityWithPaths(keyPath, configPath)
if err != nil {
t.Fatalf("failed to load or create identity: %v", err)
}
if !nm.HasIdentity() {
t.Fatal("expected identity to be initialised")
}
identity := nm.GetIdentity()
if identity == nil {
t.Fatal("identity should not be nil")
}
if identity.Name == "" {
t.Error("identity name should be populated")
}
if identity.Role != RoleDual {
t.Errorf("expected default role dual, got %s", identity.Role)
}
if _, err := os.Stat(keyPath); err != nil {
t.Fatalf("expected private key to be persisted: %v", err)
}
if _, err := os.Stat(configPath); err != nil {
t.Fatalf("expected identity config to be persisted: %v", err)
}
})
}
func TestNodeRoles(t *testing.T) {

View file

@ -51,8 +51,9 @@ const (
PeerAuthAllowlist
)
// Peer name validation constants.
// Peer name validation constants
const (
PeerNameMinLength = 1
PeerNameMaxLength = 64
)
@ -71,12 +72,14 @@ func safeKeyPrefix(key string) string {
}
// validatePeerName checks if a peer name is valid.
// Empty names are permitted. Non-empty names must be 1-64 characters,
// start and end with alphanumeric, and contain only alphanumeric,
// hyphens, underscores, and spaces.
// Peer names must be 1-64 characters, start and end with alphanumeric,
// and contain only alphanumeric, hyphens, underscores, and spaces.
func validatePeerName(name string) error {
if name == "" {
return nil
return nil // Empty names are allowed (optional field)
}
if len(name) < PeerNameMinLength {
return coreerr.E("validatePeerName", "peer name too short", nil)
}
if len(name) > PeerNameMaxLength {
return coreerr.E("validatePeerName", "peer name too long", nil)
@ -98,7 +101,6 @@ type PeerRegistry struct {
authMode PeerAuthMode // How to handle unknown peers
allowedPublicKeys map[string]bool // Allowlist of public keys (when authMode is Allowlist)
allowedPublicKeyMu sync.RWMutex // Protects allowedPublicKeys
allowlistPath string // Sidecar file for persisted allowlist keys
// Debounce disk writes
dirty bool // Whether there are unsaved changes
@ -133,7 +135,6 @@ func NewPeerRegistryWithPath(peersPath string) (*PeerRegistry, error) {
pr := &PeerRegistry{
peers: make(map[string]*Peer),
path: peersPath,
allowlistPath: peersPath + ".allowlist.json",
stopChan: make(chan struct{}),
authMode: PeerAuthOpen, // Default to open for backward compatibility
allowedPublicKeys: make(map[string]bool),
@ -143,12 +144,7 @@ func NewPeerRegistryWithPath(peersPath string) (*PeerRegistry, error) {
if err := pr.load(); err != nil {
// No existing peers, that's ok
pr.rebuildKDTree()
}
// Load any persisted allowlist entries. This is best effort so that a
// missing or corrupt sidecar does not block peer registry startup.
if err := pr.loadAllowedPublicKeys(); err != nil {
logging.Warn("failed to load peer allowlist", logging.Fields{"error": err})
return pr, nil
}
pr.rebuildKDTree()
@ -173,25 +169,17 @@ func (r *PeerRegistry) GetAuthMode() PeerAuthMode {
// AllowPublicKey adds a public key to the allowlist.
func (r *PeerRegistry) AllowPublicKey(publicKey string) {
r.allowedPublicKeyMu.Lock()
defer r.allowedPublicKeyMu.Unlock()
r.allowedPublicKeys[publicKey] = true
r.allowedPublicKeyMu.Unlock()
logging.Debug("public key added to allowlist", logging.Fields{"key": safeKeyPrefix(publicKey)})
if err := r.saveAllowedPublicKeys(); err != nil {
logging.Warn("failed to persist peer allowlist", logging.Fields{"error": err})
}
}
// RevokePublicKey removes a public key from the allowlist.
func (r *PeerRegistry) RevokePublicKey(publicKey string) {
r.allowedPublicKeyMu.Lock()
defer r.allowedPublicKeyMu.Unlock()
delete(r.allowedPublicKeys, publicKey)
r.allowedPublicKeyMu.Unlock()
logging.Debug("public key removed from allowlist", logging.Fields{"key": safeKeyPrefix(publicKey)})
if err := r.saveAllowedPublicKeys(); err != nil {
logging.Warn("failed to persist peer allowlist", logging.Fields{"error": err})
}
}
// IsPublicKeyAllowed checks if a public key is in the allowlist.
@ -720,72 +708,6 @@ func (r *PeerRegistry) Close() error {
return nil
}
// saveAllowedPublicKeys persists the allowlist to disk immediately.
// It keeps the allowlist in a separate sidecar file so peer persistence remains
// backwards compatible with the existing peers.json array format.
func (r *PeerRegistry) saveAllowedPublicKeys() error {
r.allowedPublicKeyMu.RLock()
keys := make([]string, 0, len(r.allowedPublicKeys))
for key := range r.allowedPublicKeys {
keys = append(keys, key)
}
r.allowedPublicKeyMu.RUnlock()
slices.Sort(keys)
dir := filepath.Dir(r.allowlistPath)
if err := coreio.Local.EnsureDir(dir); err != nil {
return coreerr.E("PeerRegistry.saveAllowedPublicKeys", "failed to create allowlist directory", err)
}
data, err := json.MarshalIndent(keys, "", " ")
if err != nil {
return coreerr.E("PeerRegistry.saveAllowedPublicKeys", "failed to marshal allowlist", err)
}
tmpPath := r.allowlistPath + ".tmp"
if err := coreio.Local.Write(tmpPath, string(data)); err != nil {
return coreerr.E("PeerRegistry.saveAllowedPublicKeys", "failed to write allowlist temp file", err)
}
if err := coreio.Local.Rename(tmpPath, r.allowlistPath); err != nil {
coreio.Local.Delete(tmpPath)
return coreerr.E("PeerRegistry.saveAllowedPublicKeys", "failed to rename allowlist file", err)
}
return nil
}
// loadAllowedPublicKeys loads the allowlist from disk.
func (r *PeerRegistry) loadAllowedPublicKeys() error {
if !coreio.Local.Exists(r.allowlistPath) {
return nil
}
content, err := coreio.Local.Read(r.allowlistPath)
if err != nil {
return coreerr.E("PeerRegistry.loadAllowedPublicKeys", "failed to read allowlist", err)
}
var keys []string
if err := json.Unmarshal([]byte(content), &keys); err != nil {
return coreerr.E("PeerRegistry.loadAllowedPublicKeys", "failed to unmarshal allowlist", err)
}
r.allowedPublicKeyMu.Lock()
defer r.allowedPublicKeyMu.Unlock()
r.allowedPublicKeys = make(map[string]bool, len(keys))
for _, key := range keys {
if key == "" {
continue
}
r.allowedPublicKeys[key] = true
}
return nil
}
// save is a helper that schedules a debounced save.
// Kept for backward compatibility but now debounces writes.
// Must NOT be called with r.mu held.

View file

@ -389,39 +389,6 @@ func TestPeerRegistry_Persistence(t *testing.T) {
}
}
func TestPeerRegistry_AllowlistPersistence(t *testing.T) {
tmpDir, _ := os.MkdirTemp("", "allowlist-persist-test")
defer os.RemoveAll(tmpDir)
peersPath := filepath.Join(tmpDir, "peers.json")
pr1, err := NewPeerRegistryWithPath(peersPath)
if err != nil {
t.Fatalf("failed to create first registry: %v", err)
}
key := "allowlist-key-1234567890"
pr1.AllowPublicKey(key)
if err := pr1.Close(); err != nil {
t.Fatalf("failed to close first registry: %v", err)
}
pr2, err := NewPeerRegistryWithPath(peersPath)
if err != nil {
t.Fatalf("failed to create second registry: %v", err)
}
if !pr2.IsPublicKeyAllowed(key) {
t.Fatal("expected allowlisted key to survive reload")
}
keys := pr2.ListAllowedPublicKeys()
if !slices.Contains(keys, key) {
t.Fatalf("expected allowlisted key to be listed after reload, got %v", keys)
}
}
// --- Security Feature Tests ---
func TestPeerRegistry_AuthMode(t *testing.T) {

View file

@ -76,20 +76,10 @@ func NewMessageDeduplicator(ttl time.Duration) *MessageDeduplicator {
// IsDuplicate checks if a message ID has been seen recently
func (d *MessageDeduplicator) IsDuplicate(msgID string) bool {
d.mu.Lock()
defer d.mu.Unlock()
seenAt, exists := d.seen[msgID]
if !exists {
return false
}
if d.ttl > 0 && time.Since(seenAt) > d.ttl {
delete(d.seen, msgID)
return false
}
return true
d.mu.RLock()
_, exists := d.seen[msgID]
d.mu.RUnlock()
return exists
}
// Mark records a message ID as seen

View file

@ -159,17 +159,6 @@ func TestMessageDeduplicator(t *testing.T) {
}
})
t.Run("ExpiredEntriesAreNotDuplicates", func(t *testing.T) {
d := NewMessageDeduplicator(25 * time.Millisecond)
d.Mark("msg-expired")
time.Sleep(40 * time.Millisecond)
if d.IsDuplicate("msg-expired") {
t.Error("expired message should not remain a duplicate")
}
})
t.Run("ConcurrentAccess", func(t *testing.T) {
d := NewMessageDeduplicator(5 * time.Minute)
var wg sync.WaitGroup

View file

@ -26,7 +26,7 @@ type MinerInstance interface {
GetName() string
GetType() string
GetStats() (any, error)
GetConsoleHistorySince(lines int, since time.Time) []string
GetConsoleHistory(lines int) []string
}
// ProfileManager interface for profile operations.
@ -55,6 +55,7 @@ func NewWorker(node *NodeManager, transport *Transport) *Worker {
}
}
// SetMinerManager sets the miner manager for handling miner operations.
func (w *Worker) SetMinerManager(manager MinerManager) {
w.minerManager = manager
@ -285,12 +286,7 @@ func (w *Worker) handleGetLogs(msg *Message) (*Message, error) {
return nil, coreerr.E("Worker.handleGetLogs", "miner not found: "+payload.MinerName, nil)
}
var since time.Time
if payload.Since > 0 {
since = time.UnixMilli(payload.Since)
}
lines := miner.GetConsoleHistorySince(payload.Lines, since)
lines := miner.GetConsoleHistory(payload.Lines)
logs := LogsPayload{
MinerName: payload.MinerName,

View file

@ -550,14 +550,10 @@ type mockMinerInstance struct {
stats any
}
func (m *mockMinerInstance) GetName() string { return m.name }
func (m *mockMinerInstance) GetType() string { return m.minerType }
func (m *mockMinerInstance) GetStats() (any, error) {
return m.stats, nil
}
func (m *mockMinerInstance) GetConsoleHistorySince(lines int, since time.Time) []string {
return []string{}
}
func (m *mockMinerInstance) GetName() string { return m.name }
func (m *mockMinerInstance) GetType() string { return m.minerType }
func (m *mockMinerInstance) GetStats() (any, error) { return m.stats, nil }
func (m *mockMinerInstance) GetConsoleHistory(lines int) []string { return []string{} }
type mockProfileManager struct{}