refactor(node): tighten AX naming across core paths
All checks were successful
Security Scan / security (push) Successful in 11s
Test / test (push) Successful in 1m38s

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-03-30 22:31:11 +00:00
parent dec79b54d6
commit 819862a1a4
4 changed files with 250 additions and 259 deletions

View file

@ -55,14 +55,14 @@ type BundleManifest struct {
// bundle, err := CreateProfileBundle(profileJSON, "xmrig-default", "password")
func CreateProfileBundle(profileJSON []byte, name string, password string) (*Bundle, error) {
// Create a TIM with just the profile config
t, err := tim.New()
timBundle, err := tim.New()
if err != nil {
return nil, core.E("CreateProfileBundle", "failed to create TIM", err)
}
t.Config = profileJSON
timBundle.Config = profileJSON
// Encrypt to STIM format
stimData, err := t.ToSigil(password)
stimData, err := timBundle.ToSigil(password)
if err != nil {
return nil, core.E("CreateProfileBundle", "failed to encrypt bundle", err)
}
@ -112,24 +112,24 @@ func CreateMinerBundle(minerPath string, profileJSON []byte, name string, passwo
}
// Create DataNode from tarball
dn, err := datanode.FromTar(tarData)
dataNode, err := datanode.FromTar(tarData)
if err != nil {
return nil, core.E("CreateMinerBundle", "failed to create datanode", err)
}
// Create TIM from DataNode
t, err := tim.FromDataNode(dn)
timBundle, err := tim.FromDataNode(dataNode)
if err != nil {
return nil, core.E("CreateMinerBundle", "failed to create TIM", err)
}
// Set profile as config if provided
if profileJSON != nil {
t.Config = profileJSON
timBundle.Config = profileJSON
}
// Encrypt to STIM format
stimData, err := t.ToSigil(password)
stimData, err := timBundle.ToSigil(password)
if err != nil {
return nil, core.E("CreateMinerBundle", "failed to encrypt bundle", err)
}
@ -159,12 +159,12 @@ func ExtractProfileBundle(bundle *Bundle, password string) ([]byte, error) {
}
// Decrypt STIM format
t, err := tim.FromSigil(bundle.Data, password)
timBundle, err := tim.FromSigil(bundle.Data, password)
if err != nil {
return nil, core.E("ExtractProfileBundle", "failed to decrypt bundle", err)
}
return t.Config, nil
return timBundle.Config, nil
}
// ExtractMinerBundle decrypts and extracts a miner bundle, returning the miner path and profile.
@ -177,13 +177,13 @@ func ExtractMinerBundle(bundle *Bundle, password string, destDir string) (string
}
// Decrypt STIM format
t, err := tim.FromSigil(bundle.Data, password)
timBundle, err := tim.FromSigil(bundle.Data, password)
if err != nil {
return "", nil, core.E("ExtractMinerBundle", "failed to decrypt bundle", err)
}
// Convert rootfs to tarball and extract
tarData, err := t.RootFS.ToTar()
tarData, err := timBundle.RootFS.ToTar()
if err != nil {
return "", nil, core.E("ExtractMinerBundle", "failed to convert rootfs to tar", err)
}
@ -194,7 +194,7 @@ func ExtractMinerBundle(bundle *Bundle, password string, destDir string) (string
return "", nil, core.E("ExtractMinerBundle", "failed to extract tarball", err)
}
return minerPath, t.Config, nil
return minerPath, timBundle.Config, nil
}
// VerifyBundle checks if a bundle's checksum is valid.
@ -222,24 +222,24 @@ func isJSON(data []byte) bool {
// createTarball creates a tar archive from a map of filename -> content.
func createTarball(files map[string][]byte) ([]byte, error) {
var buf bytes.Buffer
tw := tar.NewWriter(&buf)
tarWriter := tar.NewWriter(&buf)
// Track directories we've created
dirs := make(map[string]bool)
createdDirectories := make(map[string]bool)
for name, content := range files {
// Create parent directories if needed
dir := core.PathDir(name)
if dir != "." && !dirs[dir] {
hdr := &tar.Header{
if dir != "." && !createdDirectories[dir] {
header := &tar.Header{
Name: dir + "/",
Mode: 0755,
Typeflag: tar.TypeDir,
}
if err := tw.WriteHeader(hdr); err != nil {
if err := tarWriter.WriteHeader(header); err != nil {
return nil, err
}
dirs[dir] = true
createdDirectories[dir] = true
}
// Determine file mode (executable for binaries in miners/)
@ -248,20 +248,20 @@ func createTarball(files map[string][]byte) ([]byte, error) {
mode = 0755
}
hdr := &tar.Header{
header := &tar.Header{
Name: name,
Mode: mode,
Size: int64(len(content)),
}
if err := tw.WriteHeader(hdr); err != nil {
if err := tarWriter.WriteHeader(header); err != nil {
return nil, err
}
if _, err := tw.Write(content); err != nil {
if _, err := tarWriter.Write(content); err != nil {
return nil, err
}
}
if err := tw.Close(); err != nil {
if err := tarWriter.Close(); err != nil {
return nil, err
}
@ -290,11 +290,11 @@ func extractTarball(tarData []byte, destDir string) (string, error) {
return "", err
}
tr := tar.NewReader(bytes.NewReader(tarData))
tarReader := tar.NewReader(bytes.NewReader(tarData))
var firstExecutable string
for {
hdr, err := tr.Next()
header, err := tarReader.Next()
if err == io.EOF {
break
}
@ -303,16 +303,16 @@ func extractTarball(tarData []byte, destDir string) (string, error) {
}
// Security: Sanitize the tar entry name to prevent path traversal (Zip Slip)
cleanName := core.CleanPath(hdr.Name, "/")
cleanName := core.CleanPath(header.Name, "/")
// Reject absolute paths
if core.PathIsAbs(cleanName) {
return "", core.E("extractTarball", "invalid tar entry: absolute path not allowed: "+hdr.Name, nil)
return "", core.E("extractTarball", "invalid tar entry: absolute path not allowed: "+header.Name, nil)
}
// Reject paths that escape the destination directory
if core.HasPrefix(cleanName, "../") || cleanName == ".." {
return "", core.E("extractTarball", "invalid tar entry: path traversal attempt: "+hdr.Name, nil)
return "", core.E("extractTarball", "invalid tar entry: path traversal attempt: "+header.Name, nil)
}
// Build the full path and verify it's within destDir
@ -324,10 +324,10 @@ func extractTarball(tarData []byte, destDir string) (string, error) {
allowedPrefix = absDestDir
}
if !core.HasPrefix(fullPath, allowedPrefix) && fullPath != absDestDir {
return "", core.E("extractTarball", "invalid tar entry: path escape attempt: "+hdr.Name, nil)
return "", core.E("extractTarball", "invalid tar entry: path escape attempt: "+header.Name, nil)
}
switch hdr.Typeflag {
switch header.Typeflag {
case tar.TypeDir:
if err := filesystemEnsureDir(fullPath); err != nil {
return "", err
@ -340,21 +340,21 @@ func extractTarball(tarData []byte, destDir string) (string, error) {
// Limit file size to prevent decompression bombs (100MB max per file)
const maxFileSize int64 = 100 * 1024 * 1024
limitedReader := io.LimitReader(tr, maxFileSize+1)
limitedReader := io.LimitReader(tarReader, maxFileSize+1)
content, err := io.ReadAll(limitedReader)
if err != nil {
return "", core.E("extractTarball", "failed to write file "+hdr.Name, err)
return "", core.E("extractTarball", "failed to write file "+header.Name, err)
}
if int64(len(content)) > maxFileSize {
filesystemDelete(fullPath)
return "", core.E("extractTarball", "file "+hdr.Name+" exceeds maximum size", nil)
return "", core.E("extractTarball", "file "+header.Name+" exceeds maximum size", nil)
}
if err := filesystemResultError(localFileSystem.WriteMode(fullPath, string(content), fs.FileMode(hdr.Mode))); err != nil {
return "", core.E("extractTarball", "failed to create file "+hdr.Name, err)
if err := filesystemResultError(localFileSystem.WriteMode(fullPath, string(content), fs.FileMode(header.Mode))); err != nil {
return "", core.E("extractTarball", "failed to create file "+header.Name, err)
}
// Track first executable
if hdr.Mode&0111 != 0 && firstExecutable == "" {
if header.Mode&0111 != 0 && firstExecutable == "" {
firstExecutable = fullPath
}
// Explicitly ignore symlinks and hard links to prevent symlink attacks

View file

@ -41,8 +41,8 @@ type Connection struct {
// WriteTimeout is the deadline applied before each write call.
WriteTimeout time.Duration
conn net.Conn
writeMu sync.Mutex
conn net.Conn
writeMutex sync.Mutex
}
// NewConnection creates a Connection that wraps conn with sensible defaults.
@ -61,7 +61,7 @@ func NewConnection(conn net.Conn) *Connection {
//
// err := conn.WritePacket(CommandPing, payload, true)
func (c *Connection) WritePacket(cmd uint32, payload []byte, expectResponse bool) error {
h := Header{
header := Header{
Signature: Signature,
PayloadSize: uint64(len(payload)),
ExpectResponse: expectResponse,
@ -70,14 +70,14 @@ func (c *Connection) WritePacket(cmd uint32, payload []byte, expectResponse bool
Flags: FlagRequest,
ProtocolVersion: LevinProtocolVersion,
}
return c.writeFrame(&h, payload)
return c.writeFrame(&header, payload)
}
// WriteResponse sends a Levin response packet with the given return code.
//
// err := conn.WriteResponse(CommandPing, payload, ReturnOK)
func (c *Connection) WriteResponse(cmd uint32, payload []byte, returnCode int32) error {
h := Header{
header := Header{
Signature: Signature,
PayloadSize: uint64(len(payload)),
ExpectResponse: false,
@ -86,15 +86,15 @@ func (c *Connection) WriteResponse(cmd uint32, payload []byte, returnCode int32)
Flags: FlagResponse,
ProtocolVersion: LevinProtocolVersion,
}
return c.writeFrame(&h, payload)
return c.writeFrame(&header, payload)
}
// writeFrame serialises header + payload and writes them atomically.
func (c *Connection) writeFrame(h *Header, payload []byte) error {
buf := EncodeHeader(h)
func (c *Connection) writeFrame(header *Header, payload []byte) error {
buf := EncodeHeader(header)
c.writeMu.Lock()
defer c.writeMu.Unlock()
c.writeMutex.Lock()
defer c.writeMutex.Unlock()
if err := c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)); err != nil {
return err
@ -122,32 +122,32 @@ func (c *Connection) ReadPacket() (Header, []byte, error) {
}
// Read header.
var hdrBuf [HeaderSize]byte
if _, err := io.ReadFull(c.conn, hdrBuf[:]); err != nil {
var headerBytes [HeaderSize]byte
if _, err := io.ReadFull(c.conn, headerBytes[:]); err != nil {
return Header{}, nil, err
}
h, err := DecodeHeader(hdrBuf)
header, err := DecodeHeader(headerBytes)
if err != nil {
return Header{}, nil, err
}
// Check against the connection-specific payload limit.
if h.PayloadSize > c.MaxPayloadSize {
if header.PayloadSize > c.MaxPayloadSize {
return Header{}, nil, ErrorPayloadTooBig
}
// Empty payload is valid — return nil data without allocation.
if h.PayloadSize == 0 {
return h, nil, nil
if header.PayloadSize == 0 {
return header, nil, nil
}
payload := make([]byte, h.PayloadSize)
payload := make([]byte, header.PayloadSize)
if _, err := io.ReadFull(c.conn, payload); err != nil {
return Header{}, nil, err
}
return h, payload, nil
return header, payload, nil
}
// Close closes the underlying network connection.

View file

@ -37,8 +37,8 @@ type Peer struct {
Connected bool `json:"-"`
}
// saveDebounceInterval is the minimum time between disk writes.
const saveDebounceInterval = 5 * time.Second
// peerRegistrySaveDebounceInterval is the minimum time between disk writes.
const peerRegistrySaveDebounceInterval = 5 * time.Second
// PeerAuthMode controls how unknown peers are handled
//
@ -58,8 +58,8 @@ const (
PeerNameMaxLength = 64
)
// peerNameRegex validates peer names: alphanumeric, hyphens, underscores, and spaces
var peerNameRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9\-_ ]{0,62}[a-zA-Z0-9]$|^[a-zA-Z0-9]$`)
// peerNamePattern validates peer names: alphanumeric, hyphens, underscores, and spaces.
var peerNamePattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9\-_ ]{0,62}[a-zA-Z0-9]$|^[a-zA-Z0-9]$`)
// safeKeyPrefix returns a truncated key for logging, handling short keys safely
func safeKeyPrefix(key string) string {
@ -85,7 +85,7 @@ func validatePeerName(name string) error {
if len(name) > PeerNameMaxLength {
return core.E("validatePeerName", "peer name too long", nil)
}
if !peerNameRegex.MatchString(name) {
if !peerNamePattern.MatchString(name) {
return core.E("validatePeerName", "peer name contains invalid characters (use alphanumeric, hyphens, underscores, spaces)", nil)
}
return nil
@ -106,11 +106,9 @@ type PeerRegistry struct {
allowedPublicKeyMu sync.RWMutex // Protects allowedPublicKeys
// Debounce disk writes
dirty bool // Whether there are unsaved changes
saveTimer *time.Timer // Timer for debounced save
saveMu sync.Mutex // Protects dirty and saveTimer
stopChan chan struct{} // Signal to stop background save
saveStopOnce sync.Once // Ensure stopChan is closed only once
hasPendingChanges bool // Whether there are unsaved changes
pendingSaveTimer *time.Timer // Timer for debounced save
saveMutex sync.Mutex // Protects pending save state
}
// Dimension weights for peer selection
@ -144,8 +142,7 @@ func NewPeerRegistryFromPath(peersPath string) (*PeerRegistry, error) {
pr := &PeerRegistry{
peers: make(map[string]*Peer),
path: peersPath,
stopChan: make(chan struct{}),
authMode: PeerAuthOpen, // Default to open for backward compatibility
authMode: PeerAuthOpen, // Default to open.
allowedPublicKeys: make(map[string]bool),
}
@ -286,7 +283,8 @@ func (r *PeerRegistry) AddPeer(peer *Peer) error {
r.rebuildKDTree()
r.mu.Unlock()
return r.save()
r.scheduleSave()
return nil
}
// UpdatePeer updates an existing peer's information.
@ -303,7 +301,8 @@ func (r *PeerRegistry) UpdatePeer(peer *Peer) error {
r.rebuildKDTree()
r.mu.Unlock()
return r.save()
r.scheduleSave()
return nil
}
// RemovePeer removes a peer from the registry.
@ -320,7 +319,8 @@ func (r *PeerRegistry) RemovePeer(id string) error {
r.rebuildKDTree()
r.mu.Unlock()
return r.save()
r.scheduleSave()
return nil
}
// Peer returns a copy of the peer with the supplied ID.
@ -363,7 +363,7 @@ func (r *PeerRegistry) Peers() iter.Seq[*Peer] {
// UpdateMetrics updates a peer's performance metrics.
// Note: Persistence is debounced. Call Close() to flush before shutdown.
func (r *PeerRegistry) UpdateMetrics(id string, pingMS, geoKM float64, hops int) error {
func (r *PeerRegistry) UpdateMetrics(id string, pingMilliseconds, geoKilometres float64, hopCount int) error {
r.mu.Lock()
peer, exists := r.peers[id]
@ -372,15 +372,16 @@ func (r *PeerRegistry) UpdateMetrics(id string, pingMS, geoKM float64, hops int)
return core.E("PeerRegistry.UpdateMetrics", "peer "+id+" not found", nil)
}
peer.PingMS = pingMS
peer.GeoKM = geoKM
peer.Hops = hops
peer.PingMS = pingMilliseconds
peer.GeoKM = geoKilometres
peer.Hops = hopCount
peer.LastSeen = time.Now()
r.rebuildKDTree()
r.mu.Unlock()
return r.save()
r.scheduleSave()
return nil
}
// UpdateScore updates a peer's reliability score.
@ -401,7 +402,8 @@ func (r *PeerRegistry) UpdateScore(id string, score float64) error {
r.rebuildKDTree()
r.mu.Unlock()
return r.save()
r.scheduleSave()
return nil
}
// MarkConnected updates a peer's connection state.
@ -441,7 +443,7 @@ func (r *PeerRegistry) RecordSuccess(id string) {
peer.Score = min(peer.Score+ScoreSuccessIncrement, ScoreMaximum)
peer.LastSeen = time.Now()
r.mu.Unlock()
r.save()
r.scheduleSave()
}
// RecordFailure records a failed interaction with a peer, reducing their score.
@ -456,7 +458,7 @@ func (r *PeerRegistry) RecordFailure(id string) {
peer.Score = max(peer.Score-ScoreFailureDecrement, ScoreMinimum)
newScore := peer.Score
r.mu.Unlock()
r.save()
r.scheduleSave()
logging.Debug("peer score decreased", logging.Fields{
"peer_id": id,
@ -477,7 +479,7 @@ func (r *PeerRegistry) RecordTimeout(id string) {
peer.Score = max(peer.Score-ScoreTimeoutDecrement, ScoreMinimum)
newScore := peer.Score
r.mu.Unlock()
r.save()
r.scheduleSave()
logging.Debug("peer score decreased", logging.Fields{
"peer_id": id,
@ -642,26 +644,26 @@ func (r *PeerRegistry) rebuildKDTree() {
}
// scheduleSave schedules a debounced save operation.
// Multiple calls within saveDebounceInterval will be coalesced into a single save.
// Must NOT be called with r.mu held.
// Multiple calls within peerRegistrySaveDebounceInterval will be coalesced into a single save.
// Call it after releasing r.mu so peer state and save state do not interleave.
func (r *PeerRegistry) scheduleSave() {
r.saveMu.Lock()
defer r.saveMu.Unlock()
r.saveMutex.Lock()
defer r.saveMutex.Unlock()
r.dirty = true
r.hasPendingChanges = true
// If timer already running, let it handle the save
if r.saveTimer != nil {
if r.pendingSaveTimer != nil {
return
}
// Start a new timer
r.saveTimer = time.AfterFunc(saveDebounceInterval, func() {
r.saveMu.Lock()
r.saveTimer = nil
shouldSave := r.dirty
r.dirty = false
r.saveMu.Unlock()
r.pendingSaveTimer = time.AfterFunc(peerRegistrySaveDebounceInterval, func() {
r.saveMutex.Lock()
r.pendingSaveTimer = nil
shouldSave := r.hasPendingChanges
r.hasPendingChanges = false
r.saveMutex.Unlock()
if shouldSave {
r.mu.RLock()
@ -709,19 +711,15 @@ func (r *PeerRegistry) saveNow() error {
// Close flushes any pending changes and releases resources.
func (r *PeerRegistry) Close() error {
r.saveStopOnce.Do(func() {
close(r.stopChan)
})
// Cancel pending timer and save immediately if dirty
r.saveMu.Lock()
if r.saveTimer != nil {
r.saveTimer.Stop()
r.saveTimer = nil
// Cancel any pending timer and save immediately if changes are queued.
r.saveMutex.Lock()
if r.pendingSaveTimer != nil {
r.pendingSaveTimer.Stop()
r.pendingSaveTimer = nil
}
shouldSave := r.dirty
r.dirty = false
r.saveMu.Unlock()
shouldSave := r.hasPendingChanges
r.hasPendingChanges = false
r.saveMutex.Unlock()
if shouldSave {
r.mu.RLock()
@ -733,14 +731,6 @@ func (r *PeerRegistry) Close() error {
return nil
}
// save is a helper that schedules a debounced save.
// Kept for backward compatibility but now debounces writes.
// Must NOT be called with r.mu held.
func (r *PeerRegistry) save() error {
r.scheduleSave()
return nil // Errors will be logged asynchronously
}
// load reads peers from disk.
func (r *PeerRegistry) load() error {
content, err := filesystemRead(r.path)

View file

@ -21,11 +21,11 @@ import (
"github.com/gorilla/websocket"
)
// debugLogCounter tracks message counts for rate limiting debug logs
var debugLogCounter atomic.Int64
// messageLogSampleCounter tracks message counts for sampled debug logs.
var messageLogSampleCounter atomic.Int64
// debugLogInterval controls how often we log debug messages in hot paths (1 in N)
const debugLogInterval = 100
// messageLogSampleInterval controls how often we log debug messages in hot paths (1 in N).
const messageLogSampleInterval = 100
// DefaultMaxMessageSize is the default maximum message size (1MB)
const DefaultMaxMessageSize int64 = 1 << 20 // 1MB
@ -66,49 +66,49 @@ func DefaultTransportConfig() TransportConfig {
// var handler MessageHandler = func(conn *PeerConnection, msg *Message) {}
type MessageHandler func(conn *PeerConnection, msg *Message)
// MessageDeduplicator tracks seen message IDs to prevent duplicate processing
// MessageDeduplicator tracks recent message IDs to prevent duplicate processing.
//
// deduplicator := NewMessageDeduplicator(5 * time.Minute)
type MessageDeduplicator struct {
seen map[string]time.Time
mu sync.RWMutex
ttl time.Duration
recentMessageTimes map[string]time.Time
mutex sync.RWMutex
timeToLive time.Duration
}
// NewMessageDeduplicator creates a deduplicator with specified TTL
// NewMessageDeduplicator creates a deduplicator with the supplied retention window.
//
// deduplicator := NewMessageDeduplicator(5 * time.Minute)
func NewMessageDeduplicator(ttl time.Duration) *MessageDeduplicator {
func NewMessageDeduplicator(retentionWindow time.Duration) *MessageDeduplicator {
d := &MessageDeduplicator{
seen: make(map[string]time.Time),
ttl: ttl,
recentMessageTimes: make(map[string]time.Time),
timeToLive: retentionWindow,
}
return d
}
// IsDuplicate checks if a message ID has been seen recently
// IsDuplicate checks whether a message ID is still within the retention window.
func (d *MessageDeduplicator) IsDuplicate(msgID string) bool {
d.mu.RLock()
_, exists := d.seen[msgID]
d.mu.RUnlock()
d.mutex.RLock()
_, exists := d.recentMessageTimes[msgID]
d.mutex.RUnlock()
return exists
}
// Mark records a message ID as seen
// Mark records a message ID as recently seen.
func (d *MessageDeduplicator) Mark(msgID string) {
d.mu.Lock()
d.seen[msgID] = time.Now()
d.mu.Unlock()
d.mutex.Lock()
d.recentMessageTimes[msgID] = time.Now()
d.mutex.Unlock()
}
// Cleanup removes expired entries
// Cleanup removes expired entries from the deduplicator.
func (d *MessageDeduplicator) Cleanup() {
d.mu.Lock()
defer d.mu.Unlock()
d.mutex.Lock()
defer d.mutex.Unlock()
now := time.Now()
for id, seen := range d.seen {
if now.Sub(seen) > d.ttl {
delete(d.seen, id)
for id, seenAt := range d.recentMessageTimes {
if now.Sub(seenAt) > d.timeToLive {
delete(d.recentMessageTimes, id)
}
}
}
@ -117,61 +117,62 @@ func (d *MessageDeduplicator) Cleanup() {
//
// transport := NewTransport(nodeManager, peerRegistry, DefaultTransportConfig())
type Transport struct {
config TransportConfig
server *http.Server
upgrader websocket.Upgrader
conns map[string]*PeerConnection // peer ID -> connection
pendingConns atomic.Int32 // tracks connections during handshake
node *NodeManager
registry *PeerRegistry
handler MessageHandler
dedup *MessageDeduplicator // Message deduplication
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
config TransportConfig
httpServer *http.Server
upgrader websocket.Upgrader
connections map[string]*PeerConnection // peer ID -> connection
pendingHandshakeCount atomic.Int32 // tracks connections during handshake
nodeManager *NodeManager
peerRegistry *PeerRegistry
messageHandler MessageHandler
messageDeduplicator *MessageDeduplicator // Message deduplication
mutex sync.RWMutex
lifecycleContext context.Context
cancelLifecycle context.CancelFunc
waitGroup sync.WaitGroup
}
// PeerRateLimiter implements a simple token bucket rate limiter per peer
// PeerRateLimiter implements a simple token bucket rate limiter per peer.
//
// rateLimiter := NewPeerRateLimiter(100, 50)
type PeerRateLimiter struct {
tokens int
maxTokens int
refillRate int // tokens per second
lastRefill time.Time
mu sync.Mutex
availableTokens int
capacity int
refillPerSecond int // tokens per second
lastRefillTime time.Time
mutex sync.Mutex
}
// NewPeerRateLimiter creates a rate limiter with specified messages/second
// NewPeerRateLimiter creates a token bucket seeded with maxTokens and refilled
// at refillRate tokens per second.
//
// rateLimiter := NewPeerRateLimiter(100, 50)
func NewPeerRateLimiter(maxTokens, refillRate int) *PeerRateLimiter {
func NewPeerRateLimiter(maxTokens, refillPerSecond int) *PeerRateLimiter {
return &PeerRateLimiter{
tokens: maxTokens,
maxTokens: maxTokens,
refillRate: refillRate,
lastRefill: time.Now(),
availableTokens: maxTokens,
capacity: maxTokens,
refillPerSecond: refillPerSecond,
lastRefillTime: time.Now(),
}
}
// Allow checks if a message is allowed and consumes a token if so
func (r *PeerRateLimiter) Allow() bool {
r.mu.Lock()
defer r.mu.Unlock()
r.mutex.Lock()
defer r.mutex.Unlock()
// Refill tokens based on elapsed time
now := time.Now()
elapsed := now.Sub(r.lastRefill)
tokensToAdd := int(elapsed.Seconds()) * r.refillRate
elapsed := now.Sub(r.lastRefillTime)
tokensToAdd := int(elapsed.Seconds()) * r.refillPerSecond
if tokensToAdd > 0 {
r.tokens = min(r.tokens+tokensToAdd, r.maxTokens)
r.lastRefill = now
r.availableTokens = min(r.availableTokens+tokensToAdd, r.capacity)
r.lastRefillTime = now
}
// Check if we have tokens available
if r.tokens > 0 {
r.tokens--
if r.availableTokens > 0 {
r.availableTokens--
return true
}
return false
@ -186,7 +187,7 @@ type PeerConnection struct {
SharedSecret []byte // Derived via X25519 ECDH, used for SMSG
LastActivity time.Time
UserAgent string // Request identity advertised by the peer
writeMu sync.Mutex // Serialize WebSocket writes
writeMutex sync.Mutex // Serialize WebSocket writes
transport *Transport
closeOnce sync.Once // Ensure Close() is only called once
rateLimiter *PeerRateLimiter // Per-peer message rate limiting
@ -196,14 +197,14 @@ type PeerConnection struct {
//
// transport := NewTransport(nodeManager, peerRegistry, DefaultTransportConfig())
func NewTransport(node *NodeManager, registry *PeerRegistry, config TransportConfig) *Transport {
ctx, cancel := context.WithCancel(context.Background())
lifecycleContext, cancelLifecycle := context.WithCancel(context.Background())
return &Transport{
config: config,
node: node,
registry: registry,
conns: make(map[string]*PeerConnection),
dedup: NewMessageDeduplicator(5 * time.Minute), // 5 minute TTL for dedup
config: config,
nodeManager: node,
peerRegistry: registry,
connections: make(map[string]*PeerConnection),
messageDeduplicator: NewMessageDeduplicator(5 * time.Minute), // 5 minute TTL for dedup
upgrader: websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
@ -222,8 +223,8 @@ func NewTransport(node *NodeManager, registry *PeerRegistry, config TransportCon
return host == "localhost" || host == "127.0.0.1" || host == "::1"
},
},
ctx: ctx,
cancel: cancel,
lifecycleContext: lifecycleContext,
cancelLifecycle: cancelLifecycle,
}
}
@ -263,7 +264,7 @@ func agentHeaderToken(value string) string {
// agentUserAgent returns a transparent identity string for request headers.
func (t *Transport) agentUserAgent() string {
identity := t.node.Identity()
identity := t.nodeManager.Identity()
if identity == nil {
return core.Sprintf("%s proto=%s", agentUserAgentPrefix, ProtocolVersion)
}
@ -283,7 +284,7 @@ func (t *Transport) Start() error {
mux := http.NewServeMux()
mux.HandleFunc(t.config.WSPath, t.handleWSUpgrade)
t.server = &http.Server{
t.httpServer = &http.Server{
Addr: t.config.ListenAddr,
Handler: mux,
ReadTimeout: 30 * time.Second,
@ -294,7 +295,7 @@ func (t *Transport) Start() error {
// Apply TLS hardening if TLS is enabled
if t.config.TLSCertPath != "" && t.config.TLSKeyPath != "" {
t.server.TLSConfig = &tls.Config{
t.httpServer.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
// TLS 1.3 ciphers (automatically used when available)
@ -316,12 +317,12 @@ func (t *Transport) Start() error {
}
}
t.wg.Go(func() {
t.waitGroup.Go(func() {
var err error
if t.config.TLSCertPath != "" && t.config.TLSKeyPath != "" {
err = t.server.ListenAndServeTLS(t.config.TLSCertPath, t.config.TLSKeyPath)
err = t.httpServer.ListenAndServeTLS(t.config.TLSCertPath, t.config.TLSKeyPath)
} else {
err = t.server.ListenAndServe()
err = t.httpServer.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
logging.Error("HTTP server error", logging.Fields{"error": err, "addr": t.config.ListenAddr})
@ -329,15 +330,15 @@ func (t *Transport) Start() error {
})
// Start message deduplication cleanup goroutine
t.wg.Go(func() {
t.waitGroup.Go(func() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-t.ctx.Done():
case <-t.lifecycleContext.Done():
return
case <-ticker.C:
t.dedup.Cleanup()
t.messageDeduplicator.Cleanup()
}
}
})
@ -347,28 +348,28 @@ func (t *Transport) Start() error {
// Stop closes active connections and shuts the transport down cleanly.
func (t *Transport) Stop() error {
t.cancel()
t.cancelLifecycle()
// Gracefully close all connections with shutdown message
t.mu.RLock()
conns := slices.Collect(maps.Values(t.conns))
t.mu.RUnlock()
t.mutex.RLock()
connections := slices.Collect(maps.Values(t.connections))
t.mutex.RUnlock()
for _, pc := range conns {
for _, pc := range connections {
pc.GracefulClose("server shutdown", DisconnectShutdown)
}
// Shutdown HTTP server if it was started
if t.server != nil {
if t.httpServer != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := t.server.Shutdown(ctx); err != nil {
if err := t.httpServer.Shutdown(ctx); err != nil {
return core.E("Transport.Stop", "server shutdown error", err)
}
}
t.wg.Wait()
t.waitGroup.Wait()
return nil
}
@ -376,9 +377,9 @@ func (t *Transport) Stop() error {
//
// transport.OnMessage(worker.HandleMessage)
func (t *Transport) OnMessage(handler MessageHandler) {
t.mu.Lock()
defer t.mu.Unlock()
t.handler = handler
t.mutex.Lock()
defer t.mutex.Unlock()
t.messageHandler = handler
}
// Connect dials a peer, completes the handshake, and starts the session loops.
@ -419,9 +420,9 @@ func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) {
}
// Store connection using the real peer ID from handshake
t.mu.Lock()
t.conns[pc.Peer.ID] = pc
t.mu.Unlock()
t.mutex.Lock()
t.connections[pc.Peer.ID] = pc
t.mutex.Unlock()
logging.Debug("connected to peer", logging.Fields{"peer_id": pc.Peer.ID, "secret_len": len(pc.SharedSecret)})
logging.Debug("connected peer metadata", logging.Fields{
@ -430,16 +431,16 @@ func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) {
})
// Update registry
t.registry.MarkConnected(pc.Peer.ID, true)
t.peerRegistry.MarkConnected(pc.Peer.ID, true)
// Start read loop
t.wg.Add(1)
t.waitGroup.Add(1)
go t.readLoop(pc)
logging.Debug("started readLoop for peer", logging.Fields{"peer_id": pc.Peer.ID})
// Start keepalive
t.wg.Add(1)
t.waitGroup.Add(1)
go t.keepalive(pc)
return pc, nil
@ -447,9 +448,9 @@ func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) {
// Send transmits an encrypted message to a connected peer.
func (t *Transport) Send(peerID string, msg *Message) error {
t.mu.RLock()
pc, exists := t.conns[peerID]
t.mu.RUnlock()
t.mutex.RLock()
pc, exists := t.connections[peerID]
t.mutex.RUnlock()
if !exists {
return core.E("Transport.Send", "peer "+peerID+" not connected", nil)
@ -461,10 +462,10 @@ func (t *Transport) Send(peerID string, msg *Message) error {
// Connections returns an iterator over all active peer connections.
func (t *Transport) Connections() iter.Seq[*PeerConnection] {
return func(yield func(*PeerConnection) bool) {
t.mu.RLock()
defer t.mu.RUnlock()
t.mutex.RLock()
defer t.mutex.RUnlock()
for _, pc := range t.conns {
for _, pc := range t.connections {
if !yield(pc) {
return
}
@ -494,9 +495,9 @@ func (t *Transport) Broadcast(msg *Message) error {
//
// connection := transport.Connection("worker-1")
func (t *Transport) Connection(peerID string) *PeerConnection {
t.mu.RLock()
defer t.mu.RUnlock()
return t.conns[peerID]
t.mutex.RLock()
defer t.mutex.RUnlock()
return t.connections[peerID]
}
// handleWSUpgrade handles incoming WebSocket connections.
@ -504,20 +505,20 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
userAgent := r.Header.Get("User-Agent")
// Enforce MaxConns limit (including pending connections during handshake)
t.mu.RLock()
currentConns := len(t.conns)
t.mu.RUnlock()
pendingConns := int(t.pendingConns.Load())
t.mutex.RLock()
currentConnections := len(t.connections)
t.mutex.RUnlock()
pendingHandshakeCount := int(t.pendingHandshakeCount.Load())
totalConns := currentConns + pendingConns
if totalConns >= t.config.MaxConns {
totalConnections := currentConnections + pendingHandshakeCount
if totalConnections >= t.config.MaxConns {
http.Error(w, "Too many connections", http.StatusServiceUnavailable)
return
}
// Track this connection as pending during handshake
t.pendingConns.Add(1)
defer t.pendingConns.Add(-1)
t.pendingHandshakeCount.Add(1)
defer t.pendingHandshakeCount.Add(-1)
conn, err := t.upgrader.Upgrade(w, r, nil)
if err != nil {
@ -568,7 +569,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
"peer_id": payload.Identity.ID,
"user_agent": userAgent,
})
identity := t.node.Identity()
identity := t.nodeManager.Identity()
if identity != nil {
rejectPayload := HandshakeAckPayload{
Identity: *identity,
@ -585,14 +586,14 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
}
// Derive shared secret from peer's public key
sharedSecret, err := t.node.DeriveSharedSecret(payload.Identity.PublicKey)
sharedSecret, err := t.nodeManager.DeriveSharedSecret(payload.Identity.PublicKey)
if err != nil {
conn.Close()
return
}
// Check if peer is allowed to connect (allowlist check)
if !t.registry.IsPeerAllowed(payload.Identity.ID, payload.Identity.PublicKey) {
if !t.peerRegistry.IsPeerAllowed(payload.Identity.ID, payload.Identity.PublicKey) {
logging.Warn("peer connection rejected: not in allowlist", logging.Fields{
"peer_id": payload.Identity.ID,
"peer_name": payload.Identity.Name,
@ -600,7 +601,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
"user_agent": userAgent,
})
// Send rejection before closing
identity := t.node.Identity()
identity := t.nodeManager.Identity()
if identity != nil {
rejectPayload := HandshakeAckPayload{
Identity: *identity,
@ -617,7 +618,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
}
// Create peer if not exists (only if auth passed)
peer := t.registry.Peer(payload.Identity.ID)
peer := t.peerRegistry.Peer(payload.Identity.ID)
if peer == nil {
// Auto-register the peer since they passed allowlist check
peer = &Peer{
@ -628,7 +629,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
AddedAt: time.Now(),
Score: 50,
}
t.registry.AddPeer(peer)
t.peerRegistry.AddPeer(peer)
logging.Info("auto-registered new peer", logging.Fields{
"peer_id": peer.ID,
"peer_name": peer.Name,
@ -646,7 +647,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
}
// Send handshake acknowledgment
identity := t.node.Identity()
identity := t.nodeManager.Identity()
if identity == nil {
conn.Close()
return
@ -683,12 +684,12 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
}
// Store connection
t.mu.Lock()
t.conns[peer.ID] = pc
t.mu.Unlock()
t.mutex.Lock()
t.connections[peer.ID] = pc
t.mutex.Unlock()
// Update registry
t.registry.MarkConnected(peer.ID, true)
t.peerRegistry.MarkConnected(peer.ID, true)
logging.Debug("accepted peer connection", logging.Fields{
"peer_id": peer.ID,
@ -697,11 +698,11 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
})
// Start read loop
t.wg.Add(1)
t.waitGroup.Add(1)
go t.readLoop(pc)
// Start keepalive
t.wg.Add(1)
t.waitGroup.Add(1)
go t.keepalive(pc)
}
@ -717,7 +718,7 @@ func (t *Transport) performHandshake(pc *PeerConnection) error {
pc.Conn.SetReadDeadline(time.Time{})
}()
identity := t.node.Identity()
identity := t.nodeManager.Identity()
if identity == nil {
return ErrorIdentityNotInitialized
}
@ -780,7 +781,7 @@ func (t *Transport) performHandshake(pc *PeerConnection) error {
pc.Peer.Role = ackPayload.Identity.Role
// Verify challenge response - derive shared secret first using the peer's public key
sharedSecret, err := t.node.DeriveSharedSecret(pc.Peer.PublicKey)
sharedSecret, err := t.nodeManager.DeriveSharedSecret(pc.Peer.PublicKey)
if err != nil {
return core.E("Transport.performHandshake", "derive shared secret for challenge verification", err)
}
@ -797,9 +798,9 @@ func (t *Transport) performHandshake(pc *PeerConnection) error {
pc.SharedSecret = sharedSecret
// Update the peer in registry with the real identity
if err := t.registry.UpdatePeer(pc.Peer); err != nil {
if err := t.peerRegistry.UpdatePeer(pc.Peer); err != nil {
// If update fails (peer not found with old ID), add as new
t.registry.AddPeer(pc.Peer)
t.peerRegistry.AddPeer(pc.Peer)
}
logging.Debug("handshake completed with challenge-response verification", logging.Fields{
@ -813,7 +814,7 @@ func (t *Transport) performHandshake(pc *PeerConnection) error {
// readLoop reads messages from a peer connection.
func (t *Transport) readLoop(pc *PeerConnection) {
defer t.wg.Done()
defer t.waitGroup.Done()
defer t.removeConnection(pc)
// Apply message size limit to prevent memory exhaustion attacks
@ -825,7 +826,7 @@ func (t *Transport) readLoop(pc *PeerConnection) {
for {
select {
case <-t.ctx.Done():
case <-t.lifecycleContext.Done():
return
default:
}
@ -859,21 +860,21 @@ func (t *Transport) readLoop(pc *PeerConnection) {
}
// Check for duplicate messages (prevents amplification attacks)
if t.dedup.IsDuplicate(msg.ID) {
if t.messageDeduplicator.IsDuplicate(msg.ID) {
logging.Debug("dropping duplicate message", logging.Fields{"msg_id": msg.ID, "peer_id": pc.Peer.ID})
continue
}
t.dedup.Mark(msg.ID)
t.messageDeduplicator.Mark(msg.ID)
// Rate limit debug logs in hot path to reduce noise (log 1 in N messages)
if debugLogCounter.Add(1)%debugLogInterval == 0 {
if messageLogSampleCounter.Add(1)%messageLogSampleInterval == 0 {
logging.Debug("received message from peer", logging.Fields{"type": msg.Type, "peer_id": pc.Peer.ID, "reply_to": msg.ReplyTo, "sample": "1/100"})
}
// Dispatch to handler (read handler under lock to avoid race)
t.mu.RLock()
handler := t.handler
t.mu.RUnlock()
t.mutex.RLock()
handler := t.messageHandler
t.mutex.RUnlock()
if handler != nil {
handler(pc, msg)
}
@ -882,14 +883,14 @@ func (t *Transport) readLoop(pc *PeerConnection) {
// keepalive sends periodic pings.
func (t *Transport) keepalive(pc *PeerConnection) {
defer t.wg.Done()
defer t.waitGroup.Done()
ticker := time.NewTicker(t.config.PingInterval)
defer ticker.Stop()
for {
select {
case <-t.ctx.Done():
case <-t.lifecycleContext.Done():
return
case <-ticker.C:
// Check if connection is still alive
@ -899,7 +900,7 @@ func (t *Transport) keepalive(pc *PeerConnection) {
}
// Send ping
identity := t.node.Identity()
identity := t.nodeManager.Identity()
pingMsg, err := NewMessage(MessagePing, identity.ID, pc.Peer.ID, PingPayload{
SentAt: time.Now().UnixMilli(),
})
@ -917,18 +918,18 @@ func (t *Transport) keepalive(pc *PeerConnection) {
// removeConnection removes and cleans up a connection.
func (t *Transport) removeConnection(pc *PeerConnection) {
t.mu.Lock()
delete(t.conns, pc.Peer.ID)
t.mu.Unlock()
t.mutex.Lock()
delete(t.connections, pc.Peer.ID)
t.mutex.Unlock()
t.registry.MarkConnected(pc.Peer.ID, false)
t.peerRegistry.MarkConnected(pc.Peer.ID, false)
pc.Close()
}
// Send sends an encrypted message over the connection.
func (pc *PeerConnection) Send(msg *Message) error {
pc.writeMu.Lock()
defer pc.writeMu.Unlock()
pc.writeMutex.Lock()
defer pc.writeMutex.Unlock()
// Encrypt message using SMSG
data, err := pc.transport.encryptMessage(msg, pc.SharedSecret)
@ -976,11 +977,11 @@ func (pc *PeerConnection) GracefulClose(reason string, code int) error {
var err error
pc.closeOnce.Do(func() {
// Try to send disconnect message (best effort).
// Note: we must NOT call SetWriteDeadline outside writeMu — Send()
// Note: we must NOT call SetWriteDeadline outside writeMutex — Send()
// already manages write deadlines under the lock. Setting it here
// without the lock races with concurrent Send() calls (P2P-RACE-1).
if pc.transport != nil && pc.SharedSecret != nil {
identity := pc.transport.node.Identity()
identity := pc.transport.nodeManager.Identity()
if identity != nil {
payload := DisconnectPayload{
Reason: reason,
@ -1042,7 +1043,7 @@ func (t *Transport) decryptMessage(data []byte, sharedSecret []byte) (*Message,
//
// count := transport.ConnectedPeerCount()
func (t *Transport) ConnectedPeerCount() int {
t.mu.RLock()
defer t.mu.RUnlock()
return len(t.conns)
t.mutex.RLock()
defer t.mutex.RUnlock()
return len(t.connections)
}