refactor(node): tighten AX naming across core paths
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
dec79b54d6
commit
819862a1a4
4 changed files with 250 additions and 259 deletions
|
|
@ -55,14 +55,14 @@ type BundleManifest struct {
|
|||
// bundle, err := CreateProfileBundle(profileJSON, "xmrig-default", "password")
|
||||
func CreateProfileBundle(profileJSON []byte, name string, password string) (*Bundle, error) {
|
||||
// Create a TIM with just the profile config
|
||||
t, err := tim.New()
|
||||
timBundle, err := tim.New()
|
||||
if err != nil {
|
||||
return nil, core.E("CreateProfileBundle", "failed to create TIM", err)
|
||||
}
|
||||
t.Config = profileJSON
|
||||
timBundle.Config = profileJSON
|
||||
|
||||
// Encrypt to STIM format
|
||||
stimData, err := t.ToSigil(password)
|
||||
stimData, err := timBundle.ToSigil(password)
|
||||
if err != nil {
|
||||
return nil, core.E("CreateProfileBundle", "failed to encrypt bundle", err)
|
||||
}
|
||||
|
|
@ -112,24 +112,24 @@ func CreateMinerBundle(minerPath string, profileJSON []byte, name string, passwo
|
|||
}
|
||||
|
||||
// Create DataNode from tarball
|
||||
dn, err := datanode.FromTar(tarData)
|
||||
dataNode, err := datanode.FromTar(tarData)
|
||||
if err != nil {
|
||||
return nil, core.E("CreateMinerBundle", "failed to create datanode", err)
|
||||
}
|
||||
|
||||
// Create TIM from DataNode
|
||||
t, err := tim.FromDataNode(dn)
|
||||
timBundle, err := tim.FromDataNode(dataNode)
|
||||
if err != nil {
|
||||
return nil, core.E("CreateMinerBundle", "failed to create TIM", err)
|
||||
}
|
||||
|
||||
// Set profile as config if provided
|
||||
if profileJSON != nil {
|
||||
t.Config = profileJSON
|
||||
timBundle.Config = profileJSON
|
||||
}
|
||||
|
||||
// Encrypt to STIM format
|
||||
stimData, err := t.ToSigil(password)
|
||||
stimData, err := timBundle.ToSigil(password)
|
||||
if err != nil {
|
||||
return nil, core.E("CreateMinerBundle", "failed to encrypt bundle", err)
|
||||
}
|
||||
|
|
@ -159,12 +159,12 @@ func ExtractProfileBundle(bundle *Bundle, password string) ([]byte, error) {
|
|||
}
|
||||
|
||||
// Decrypt STIM format
|
||||
t, err := tim.FromSigil(bundle.Data, password)
|
||||
timBundle, err := tim.FromSigil(bundle.Data, password)
|
||||
if err != nil {
|
||||
return nil, core.E("ExtractProfileBundle", "failed to decrypt bundle", err)
|
||||
}
|
||||
|
||||
return t.Config, nil
|
||||
return timBundle.Config, nil
|
||||
}
|
||||
|
||||
// ExtractMinerBundle decrypts and extracts a miner bundle, returning the miner path and profile.
|
||||
|
|
@ -177,13 +177,13 @@ func ExtractMinerBundle(bundle *Bundle, password string, destDir string) (string
|
|||
}
|
||||
|
||||
// Decrypt STIM format
|
||||
t, err := tim.FromSigil(bundle.Data, password)
|
||||
timBundle, err := tim.FromSigil(bundle.Data, password)
|
||||
if err != nil {
|
||||
return "", nil, core.E("ExtractMinerBundle", "failed to decrypt bundle", err)
|
||||
}
|
||||
|
||||
// Convert rootfs to tarball and extract
|
||||
tarData, err := t.RootFS.ToTar()
|
||||
tarData, err := timBundle.RootFS.ToTar()
|
||||
if err != nil {
|
||||
return "", nil, core.E("ExtractMinerBundle", "failed to convert rootfs to tar", err)
|
||||
}
|
||||
|
|
@ -194,7 +194,7 @@ func ExtractMinerBundle(bundle *Bundle, password string, destDir string) (string
|
|||
return "", nil, core.E("ExtractMinerBundle", "failed to extract tarball", err)
|
||||
}
|
||||
|
||||
return minerPath, t.Config, nil
|
||||
return minerPath, timBundle.Config, nil
|
||||
}
|
||||
|
||||
// VerifyBundle checks if a bundle's checksum is valid.
|
||||
|
|
@ -222,24 +222,24 @@ func isJSON(data []byte) bool {
|
|||
// createTarball creates a tar archive from a map of filename -> content.
|
||||
func createTarball(files map[string][]byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
tw := tar.NewWriter(&buf)
|
||||
tarWriter := tar.NewWriter(&buf)
|
||||
|
||||
// Track directories we've created
|
||||
dirs := make(map[string]bool)
|
||||
createdDirectories := make(map[string]bool)
|
||||
|
||||
for name, content := range files {
|
||||
// Create parent directories if needed
|
||||
dir := core.PathDir(name)
|
||||
if dir != "." && !dirs[dir] {
|
||||
hdr := &tar.Header{
|
||||
if dir != "." && !createdDirectories[dir] {
|
||||
header := &tar.Header{
|
||||
Name: dir + "/",
|
||||
Mode: 0755,
|
||||
Typeflag: tar.TypeDir,
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
if err := tarWriter.WriteHeader(header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dirs[dir] = true
|
||||
createdDirectories[dir] = true
|
||||
}
|
||||
|
||||
// Determine file mode (executable for binaries in miners/)
|
||||
|
|
@ -248,20 +248,20 @@ func createTarball(files map[string][]byte) ([]byte, error) {
|
|||
mode = 0755
|
||||
}
|
||||
|
||||
hdr := &tar.Header{
|
||||
header := &tar.Header{
|
||||
Name: name,
|
||||
Mode: mode,
|
||||
Size: int64(len(content)),
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
if err := tarWriter.WriteHeader(header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := tw.Write(content); err != nil {
|
||||
if _, err := tarWriter.Write(content); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tw.Close(); err != nil {
|
||||
if err := tarWriter.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -290,11 +290,11 @@ func extractTarball(tarData []byte, destDir string) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
|
||||
tr := tar.NewReader(bytes.NewReader(tarData))
|
||||
tarReader := tar.NewReader(bytes.NewReader(tarData))
|
||||
var firstExecutable string
|
||||
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
header, err := tarReader.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
|
|
@ -303,16 +303,16 @@ func extractTarball(tarData []byte, destDir string) (string, error) {
|
|||
}
|
||||
|
||||
// Security: Sanitize the tar entry name to prevent path traversal (Zip Slip)
|
||||
cleanName := core.CleanPath(hdr.Name, "/")
|
||||
cleanName := core.CleanPath(header.Name, "/")
|
||||
|
||||
// Reject absolute paths
|
||||
if core.PathIsAbs(cleanName) {
|
||||
return "", core.E("extractTarball", "invalid tar entry: absolute path not allowed: "+hdr.Name, nil)
|
||||
return "", core.E("extractTarball", "invalid tar entry: absolute path not allowed: "+header.Name, nil)
|
||||
}
|
||||
|
||||
// Reject paths that escape the destination directory
|
||||
if core.HasPrefix(cleanName, "../") || cleanName == ".." {
|
||||
return "", core.E("extractTarball", "invalid tar entry: path traversal attempt: "+hdr.Name, nil)
|
||||
return "", core.E("extractTarball", "invalid tar entry: path traversal attempt: "+header.Name, nil)
|
||||
}
|
||||
|
||||
// Build the full path and verify it's within destDir
|
||||
|
|
@ -324,10 +324,10 @@ func extractTarball(tarData []byte, destDir string) (string, error) {
|
|||
allowedPrefix = absDestDir
|
||||
}
|
||||
if !core.HasPrefix(fullPath, allowedPrefix) && fullPath != absDestDir {
|
||||
return "", core.E("extractTarball", "invalid tar entry: path escape attempt: "+hdr.Name, nil)
|
||||
return "", core.E("extractTarball", "invalid tar entry: path escape attempt: "+header.Name, nil)
|
||||
}
|
||||
|
||||
switch hdr.Typeflag {
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if err := filesystemEnsureDir(fullPath); err != nil {
|
||||
return "", err
|
||||
|
|
@ -340,21 +340,21 @@ func extractTarball(tarData []byte, destDir string) (string, error) {
|
|||
|
||||
// Limit file size to prevent decompression bombs (100MB max per file)
|
||||
const maxFileSize int64 = 100 * 1024 * 1024
|
||||
limitedReader := io.LimitReader(tr, maxFileSize+1)
|
||||
limitedReader := io.LimitReader(tarReader, maxFileSize+1)
|
||||
content, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
return "", core.E("extractTarball", "failed to write file "+hdr.Name, err)
|
||||
return "", core.E("extractTarball", "failed to write file "+header.Name, err)
|
||||
}
|
||||
if int64(len(content)) > maxFileSize {
|
||||
filesystemDelete(fullPath)
|
||||
return "", core.E("extractTarball", "file "+hdr.Name+" exceeds maximum size", nil)
|
||||
return "", core.E("extractTarball", "file "+header.Name+" exceeds maximum size", nil)
|
||||
}
|
||||
if err := filesystemResultError(localFileSystem.WriteMode(fullPath, string(content), fs.FileMode(hdr.Mode))); err != nil {
|
||||
return "", core.E("extractTarball", "failed to create file "+hdr.Name, err)
|
||||
if err := filesystemResultError(localFileSystem.WriteMode(fullPath, string(content), fs.FileMode(header.Mode))); err != nil {
|
||||
return "", core.E("extractTarball", "failed to create file "+header.Name, err)
|
||||
}
|
||||
|
||||
// Track first executable
|
||||
if hdr.Mode&0111 != 0 && firstExecutable == "" {
|
||||
if header.Mode&0111 != 0 && firstExecutable == "" {
|
||||
firstExecutable = fullPath
|
||||
}
|
||||
// Explicitly ignore symlinks and hard links to prevent symlink attacks
|
||||
|
|
|
|||
|
|
@ -41,8 +41,8 @@ type Connection struct {
|
|||
// WriteTimeout is the deadline applied before each write call.
|
||||
WriteTimeout time.Duration
|
||||
|
||||
conn net.Conn
|
||||
writeMu sync.Mutex
|
||||
conn net.Conn
|
||||
writeMutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewConnection creates a Connection that wraps conn with sensible defaults.
|
||||
|
|
@ -61,7 +61,7 @@ func NewConnection(conn net.Conn) *Connection {
|
|||
//
|
||||
// err := conn.WritePacket(CommandPing, payload, true)
|
||||
func (c *Connection) WritePacket(cmd uint32, payload []byte, expectResponse bool) error {
|
||||
h := Header{
|
||||
header := Header{
|
||||
Signature: Signature,
|
||||
PayloadSize: uint64(len(payload)),
|
||||
ExpectResponse: expectResponse,
|
||||
|
|
@ -70,14 +70,14 @@ func (c *Connection) WritePacket(cmd uint32, payload []byte, expectResponse bool
|
|||
Flags: FlagRequest,
|
||||
ProtocolVersion: LevinProtocolVersion,
|
||||
}
|
||||
return c.writeFrame(&h, payload)
|
||||
return c.writeFrame(&header, payload)
|
||||
}
|
||||
|
||||
// WriteResponse sends a Levin response packet with the given return code.
|
||||
//
|
||||
// err := conn.WriteResponse(CommandPing, payload, ReturnOK)
|
||||
func (c *Connection) WriteResponse(cmd uint32, payload []byte, returnCode int32) error {
|
||||
h := Header{
|
||||
header := Header{
|
||||
Signature: Signature,
|
||||
PayloadSize: uint64(len(payload)),
|
||||
ExpectResponse: false,
|
||||
|
|
@ -86,15 +86,15 @@ func (c *Connection) WriteResponse(cmd uint32, payload []byte, returnCode int32)
|
|||
Flags: FlagResponse,
|
||||
ProtocolVersion: LevinProtocolVersion,
|
||||
}
|
||||
return c.writeFrame(&h, payload)
|
||||
return c.writeFrame(&header, payload)
|
||||
}
|
||||
|
||||
// writeFrame serialises header + payload and writes them atomically.
|
||||
func (c *Connection) writeFrame(h *Header, payload []byte) error {
|
||||
buf := EncodeHeader(h)
|
||||
func (c *Connection) writeFrame(header *Header, payload []byte) error {
|
||||
buf := EncodeHeader(header)
|
||||
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
c.writeMutex.Lock()
|
||||
defer c.writeMutex.Unlock()
|
||||
|
||||
if err := c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)); err != nil {
|
||||
return err
|
||||
|
|
@ -122,32 +122,32 @@ func (c *Connection) ReadPacket() (Header, []byte, error) {
|
|||
}
|
||||
|
||||
// Read header.
|
||||
var hdrBuf [HeaderSize]byte
|
||||
if _, err := io.ReadFull(c.conn, hdrBuf[:]); err != nil {
|
||||
var headerBytes [HeaderSize]byte
|
||||
if _, err := io.ReadFull(c.conn, headerBytes[:]); err != nil {
|
||||
return Header{}, nil, err
|
||||
}
|
||||
|
||||
h, err := DecodeHeader(hdrBuf)
|
||||
header, err := DecodeHeader(headerBytes)
|
||||
if err != nil {
|
||||
return Header{}, nil, err
|
||||
}
|
||||
|
||||
// Check against the connection-specific payload limit.
|
||||
if h.PayloadSize > c.MaxPayloadSize {
|
||||
if header.PayloadSize > c.MaxPayloadSize {
|
||||
return Header{}, nil, ErrorPayloadTooBig
|
||||
}
|
||||
|
||||
// Empty payload is valid — return nil data without allocation.
|
||||
if h.PayloadSize == 0 {
|
||||
return h, nil, nil
|
||||
if header.PayloadSize == 0 {
|
||||
return header, nil, nil
|
||||
}
|
||||
|
||||
payload := make([]byte, h.PayloadSize)
|
||||
payload := make([]byte, header.PayloadSize)
|
||||
if _, err := io.ReadFull(c.conn, payload); err != nil {
|
||||
return Header{}, nil, err
|
||||
}
|
||||
|
||||
return h, payload, nil
|
||||
return header, payload, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying network connection.
|
||||
|
|
|
|||
102
node/peer.go
102
node/peer.go
|
|
@ -37,8 +37,8 @@ type Peer struct {
|
|||
Connected bool `json:"-"`
|
||||
}
|
||||
|
||||
// saveDebounceInterval is the minimum time between disk writes.
|
||||
const saveDebounceInterval = 5 * time.Second
|
||||
// peerRegistrySaveDebounceInterval is the minimum time between disk writes.
|
||||
const peerRegistrySaveDebounceInterval = 5 * time.Second
|
||||
|
||||
// PeerAuthMode controls how unknown peers are handled
|
||||
//
|
||||
|
|
@ -58,8 +58,8 @@ const (
|
|||
PeerNameMaxLength = 64
|
||||
)
|
||||
|
||||
// peerNameRegex validates peer names: alphanumeric, hyphens, underscores, and spaces
|
||||
var peerNameRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9\-_ ]{0,62}[a-zA-Z0-9]$|^[a-zA-Z0-9]$`)
|
||||
// peerNamePattern validates peer names: alphanumeric, hyphens, underscores, and spaces.
|
||||
var peerNamePattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9\-_ ]{0,62}[a-zA-Z0-9]$|^[a-zA-Z0-9]$`)
|
||||
|
||||
// safeKeyPrefix returns a truncated key for logging, handling short keys safely
|
||||
func safeKeyPrefix(key string) string {
|
||||
|
|
@ -85,7 +85,7 @@ func validatePeerName(name string) error {
|
|||
if len(name) > PeerNameMaxLength {
|
||||
return core.E("validatePeerName", "peer name too long", nil)
|
||||
}
|
||||
if !peerNameRegex.MatchString(name) {
|
||||
if !peerNamePattern.MatchString(name) {
|
||||
return core.E("validatePeerName", "peer name contains invalid characters (use alphanumeric, hyphens, underscores, spaces)", nil)
|
||||
}
|
||||
return nil
|
||||
|
|
@ -106,11 +106,9 @@ type PeerRegistry struct {
|
|||
allowedPublicKeyMu sync.RWMutex // Protects allowedPublicKeys
|
||||
|
||||
// Debounce disk writes
|
||||
dirty bool // Whether there are unsaved changes
|
||||
saveTimer *time.Timer // Timer for debounced save
|
||||
saveMu sync.Mutex // Protects dirty and saveTimer
|
||||
stopChan chan struct{} // Signal to stop background save
|
||||
saveStopOnce sync.Once // Ensure stopChan is closed only once
|
||||
hasPendingChanges bool // Whether there are unsaved changes
|
||||
pendingSaveTimer *time.Timer // Timer for debounced save
|
||||
saveMutex sync.Mutex // Protects pending save state
|
||||
}
|
||||
|
||||
// Dimension weights for peer selection
|
||||
|
|
@ -144,8 +142,7 @@ func NewPeerRegistryFromPath(peersPath string) (*PeerRegistry, error) {
|
|||
pr := &PeerRegistry{
|
||||
peers: make(map[string]*Peer),
|
||||
path: peersPath,
|
||||
stopChan: make(chan struct{}),
|
||||
authMode: PeerAuthOpen, // Default to open for backward compatibility
|
||||
authMode: PeerAuthOpen, // Default to open.
|
||||
allowedPublicKeys: make(map[string]bool),
|
||||
}
|
||||
|
||||
|
|
@ -286,7 +283,8 @@ func (r *PeerRegistry) AddPeer(peer *Peer) error {
|
|||
r.rebuildKDTree()
|
||||
r.mu.Unlock()
|
||||
|
||||
return r.save()
|
||||
r.scheduleSave()
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePeer updates an existing peer's information.
|
||||
|
|
@ -303,7 +301,8 @@ func (r *PeerRegistry) UpdatePeer(peer *Peer) error {
|
|||
r.rebuildKDTree()
|
||||
r.mu.Unlock()
|
||||
|
||||
return r.save()
|
||||
r.scheduleSave()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePeer removes a peer from the registry.
|
||||
|
|
@ -320,7 +319,8 @@ func (r *PeerRegistry) RemovePeer(id string) error {
|
|||
r.rebuildKDTree()
|
||||
r.mu.Unlock()
|
||||
|
||||
return r.save()
|
||||
r.scheduleSave()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Peer returns a copy of the peer with the supplied ID.
|
||||
|
|
@ -363,7 +363,7 @@ func (r *PeerRegistry) Peers() iter.Seq[*Peer] {
|
|||
|
||||
// UpdateMetrics updates a peer's performance metrics.
|
||||
// Note: Persistence is debounced. Call Close() to flush before shutdown.
|
||||
func (r *PeerRegistry) UpdateMetrics(id string, pingMS, geoKM float64, hops int) error {
|
||||
func (r *PeerRegistry) UpdateMetrics(id string, pingMilliseconds, geoKilometres float64, hopCount int) error {
|
||||
r.mu.Lock()
|
||||
|
||||
peer, exists := r.peers[id]
|
||||
|
|
@ -372,15 +372,16 @@ func (r *PeerRegistry) UpdateMetrics(id string, pingMS, geoKM float64, hops int)
|
|||
return core.E("PeerRegistry.UpdateMetrics", "peer "+id+" not found", nil)
|
||||
}
|
||||
|
||||
peer.PingMS = pingMS
|
||||
peer.GeoKM = geoKM
|
||||
peer.Hops = hops
|
||||
peer.PingMS = pingMilliseconds
|
||||
peer.GeoKM = geoKilometres
|
||||
peer.Hops = hopCount
|
||||
peer.LastSeen = time.Now()
|
||||
|
||||
r.rebuildKDTree()
|
||||
r.mu.Unlock()
|
||||
|
||||
return r.save()
|
||||
r.scheduleSave()
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateScore updates a peer's reliability score.
|
||||
|
|
@ -401,7 +402,8 @@ func (r *PeerRegistry) UpdateScore(id string, score float64) error {
|
|||
r.rebuildKDTree()
|
||||
r.mu.Unlock()
|
||||
|
||||
return r.save()
|
||||
r.scheduleSave()
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkConnected updates a peer's connection state.
|
||||
|
|
@ -441,7 +443,7 @@ func (r *PeerRegistry) RecordSuccess(id string) {
|
|||
peer.Score = min(peer.Score+ScoreSuccessIncrement, ScoreMaximum)
|
||||
peer.LastSeen = time.Now()
|
||||
r.mu.Unlock()
|
||||
r.save()
|
||||
r.scheduleSave()
|
||||
}
|
||||
|
||||
// RecordFailure records a failed interaction with a peer, reducing their score.
|
||||
|
|
@ -456,7 +458,7 @@ func (r *PeerRegistry) RecordFailure(id string) {
|
|||
peer.Score = max(peer.Score-ScoreFailureDecrement, ScoreMinimum)
|
||||
newScore := peer.Score
|
||||
r.mu.Unlock()
|
||||
r.save()
|
||||
r.scheduleSave()
|
||||
|
||||
logging.Debug("peer score decreased", logging.Fields{
|
||||
"peer_id": id,
|
||||
|
|
@ -477,7 +479,7 @@ func (r *PeerRegistry) RecordTimeout(id string) {
|
|||
peer.Score = max(peer.Score-ScoreTimeoutDecrement, ScoreMinimum)
|
||||
newScore := peer.Score
|
||||
r.mu.Unlock()
|
||||
r.save()
|
||||
r.scheduleSave()
|
||||
|
||||
logging.Debug("peer score decreased", logging.Fields{
|
||||
"peer_id": id,
|
||||
|
|
@ -642,26 +644,26 @@ func (r *PeerRegistry) rebuildKDTree() {
|
|||
}
|
||||
|
||||
// scheduleSave schedules a debounced save operation.
|
||||
// Multiple calls within saveDebounceInterval will be coalesced into a single save.
|
||||
// Must NOT be called with r.mu held.
|
||||
// Multiple calls within peerRegistrySaveDebounceInterval will be coalesced into a single save.
|
||||
// Call it after releasing r.mu so peer state and save state do not interleave.
|
||||
func (r *PeerRegistry) scheduleSave() {
|
||||
r.saveMu.Lock()
|
||||
defer r.saveMu.Unlock()
|
||||
r.saveMutex.Lock()
|
||||
defer r.saveMutex.Unlock()
|
||||
|
||||
r.dirty = true
|
||||
r.hasPendingChanges = true
|
||||
|
||||
// If timer already running, let it handle the save
|
||||
if r.saveTimer != nil {
|
||||
if r.pendingSaveTimer != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Start a new timer
|
||||
r.saveTimer = time.AfterFunc(saveDebounceInterval, func() {
|
||||
r.saveMu.Lock()
|
||||
r.saveTimer = nil
|
||||
shouldSave := r.dirty
|
||||
r.dirty = false
|
||||
r.saveMu.Unlock()
|
||||
r.pendingSaveTimer = time.AfterFunc(peerRegistrySaveDebounceInterval, func() {
|
||||
r.saveMutex.Lock()
|
||||
r.pendingSaveTimer = nil
|
||||
shouldSave := r.hasPendingChanges
|
||||
r.hasPendingChanges = false
|
||||
r.saveMutex.Unlock()
|
||||
|
||||
if shouldSave {
|
||||
r.mu.RLock()
|
||||
|
|
@ -709,19 +711,15 @@ func (r *PeerRegistry) saveNow() error {
|
|||
|
||||
// Close flushes any pending changes and releases resources.
|
||||
func (r *PeerRegistry) Close() error {
|
||||
r.saveStopOnce.Do(func() {
|
||||
close(r.stopChan)
|
||||
})
|
||||
|
||||
// Cancel pending timer and save immediately if dirty
|
||||
r.saveMu.Lock()
|
||||
if r.saveTimer != nil {
|
||||
r.saveTimer.Stop()
|
||||
r.saveTimer = nil
|
||||
// Cancel any pending timer and save immediately if changes are queued.
|
||||
r.saveMutex.Lock()
|
||||
if r.pendingSaveTimer != nil {
|
||||
r.pendingSaveTimer.Stop()
|
||||
r.pendingSaveTimer = nil
|
||||
}
|
||||
shouldSave := r.dirty
|
||||
r.dirty = false
|
||||
r.saveMu.Unlock()
|
||||
shouldSave := r.hasPendingChanges
|
||||
r.hasPendingChanges = false
|
||||
r.saveMutex.Unlock()
|
||||
|
||||
if shouldSave {
|
||||
r.mu.RLock()
|
||||
|
|
@ -733,14 +731,6 @@ func (r *PeerRegistry) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// save is a helper that schedules a debounced save.
|
||||
// Kept for backward compatibility but now debounces writes.
|
||||
// Must NOT be called with r.mu held.
|
||||
func (r *PeerRegistry) save() error {
|
||||
r.scheduleSave()
|
||||
return nil // Errors will be logged asynchronously
|
||||
}
|
||||
|
||||
// load reads peers from disk.
|
||||
func (r *PeerRegistry) load() error {
|
||||
content, err := filesystemRead(r.path)
|
||||
|
|
|
|||
|
|
@ -21,11 +21,11 @@ import (
|
|||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// debugLogCounter tracks message counts for rate limiting debug logs
|
||||
var debugLogCounter atomic.Int64
|
||||
// messageLogSampleCounter tracks message counts for sampled debug logs.
|
||||
var messageLogSampleCounter atomic.Int64
|
||||
|
||||
// debugLogInterval controls how often we log debug messages in hot paths (1 in N)
|
||||
const debugLogInterval = 100
|
||||
// messageLogSampleInterval controls how often we log debug messages in hot paths (1 in N).
|
||||
const messageLogSampleInterval = 100
|
||||
|
||||
// DefaultMaxMessageSize is the default maximum message size (1MB)
|
||||
const DefaultMaxMessageSize int64 = 1 << 20 // 1MB
|
||||
|
|
@ -66,49 +66,49 @@ func DefaultTransportConfig() TransportConfig {
|
|||
// var handler MessageHandler = func(conn *PeerConnection, msg *Message) {}
|
||||
type MessageHandler func(conn *PeerConnection, msg *Message)
|
||||
|
||||
// MessageDeduplicator tracks seen message IDs to prevent duplicate processing
|
||||
// MessageDeduplicator tracks recent message IDs to prevent duplicate processing.
|
||||
//
|
||||
// deduplicator := NewMessageDeduplicator(5 * time.Minute)
|
||||
type MessageDeduplicator struct {
|
||||
seen map[string]time.Time
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
recentMessageTimes map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
timeToLive time.Duration
|
||||
}
|
||||
|
||||
// NewMessageDeduplicator creates a deduplicator with specified TTL
|
||||
// NewMessageDeduplicator creates a deduplicator with the supplied retention window.
|
||||
//
|
||||
// deduplicator := NewMessageDeduplicator(5 * time.Minute)
|
||||
func NewMessageDeduplicator(ttl time.Duration) *MessageDeduplicator {
|
||||
func NewMessageDeduplicator(retentionWindow time.Duration) *MessageDeduplicator {
|
||||
d := &MessageDeduplicator{
|
||||
seen: make(map[string]time.Time),
|
||||
ttl: ttl,
|
||||
recentMessageTimes: make(map[string]time.Time),
|
||||
timeToLive: retentionWindow,
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// IsDuplicate checks if a message ID has been seen recently
|
||||
// IsDuplicate checks whether a message ID is still within the retention window.
|
||||
func (d *MessageDeduplicator) IsDuplicate(msgID string) bool {
|
||||
d.mu.RLock()
|
||||
_, exists := d.seen[msgID]
|
||||
d.mu.RUnlock()
|
||||
d.mutex.RLock()
|
||||
_, exists := d.recentMessageTimes[msgID]
|
||||
d.mutex.RUnlock()
|
||||
return exists
|
||||
}
|
||||
|
||||
// Mark records a message ID as seen
|
||||
// Mark records a message ID as recently seen.
|
||||
func (d *MessageDeduplicator) Mark(msgID string) {
|
||||
d.mu.Lock()
|
||||
d.seen[msgID] = time.Now()
|
||||
d.mu.Unlock()
|
||||
d.mutex.Lock()
|
||||
d.recentMessageTimes[msgID] = time.Now()
|
||||
d.mutex.Unlock()
|
||||
}
|
||||
|
||||
// Cleanup removes expired entries
|
||||
// Cleanup removes expired entries from the deduplicator.
|
||||
func (d *MessageDeduplicator) Cleanup() {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
now := time.Now()
|
||||
for id, seen := range d.seen {
|
||||
if now.Sub(seen) > d.ttl {
|
||||
delete(d.seen, id)
|
||||
for id, seenAt := range d.recentMessageTimes {
|
||||
if now.Sub(seenAt) > d.timeToLive {
|
||||
delete(d.recentMessageTimes, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -117,61 +117,62 @@ func (d *MessageDeduplicator) Cleanup() {
|
|||
//
|
||||
// transport := NewTransport(nodeManager, peerRegistry, DefaultTransportConfig())
|
||||
type Transport struct {
|
||||
config TransportConfig
|
||||
server *http.Server
|
||||
upgrader websocket.Upgrader
|
||||
conns map[string]*PeerConnection // peer ID -> connection
|
||||
pendingConns atomic.Int32 // tracks connections during handshake
|
||||
node *NodeManager
|
||||
registry *PeerRegistry
|
||||
handler MessageHandler
|
||||
dedup *MessageDeduplicator // Message deduplication
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
config TransportConfig
|
||||
httpServer *http.Server
|
||||
upgrader websocket.Upgrader
|
||||
connections map[string]*PeerConnection // peer ID -> connection
|
||||
pendingHandshakeCount atomic.Int32 // tracks connections during handshake
|
||||
nodeManager *NodeManager
|
||||
peerRegistry *PeerRegistry
|
||||
messageHandler MessageHandler
|
||||
messageDeduplicator *MessageDeduplicator // Message deduplication
|
||||
mutex sync.RWMutex
|
||||
lifecycleContext context.Context
|
||||
cancelLifecycle context.CancelFunc
|
||||
waitGroup sync.WaitGroup
|
||||
}
|
||||
|
||||
// PeerRateLimiter implements a simple token bucket rate limiter per peer
|
||||
// PeerRateLimiter implements a simple token bucket rate limiter per peer.
|
||||
//
|
||||
// rateLimiter := NewPeerRateLimiter(100, 50)
|
||||
type PeerRateLimiter struct {
|
||||
tokens int
|
||||
maxTokens int
|
||||
refillRate int // tokens per second
|
||||
lastRefill time.Time
|
||||
mu sync.Mutex
|
||||
availableTokens int
|
||||
capacity int
|
||||
refillPerSecond int // tokens per second
|
||||
lastRefillTime time.Time
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewPeerRateLimiter creates a rate limiter with specified messages/second
|
||||
// NewPeerRateLimiter creates a token bucket seeded with maxTokens and refilled
|
||||
// at refillRate tokens per second.
|
||||
//
|
||||
// rateLimiter := NewPeerRateLimiter(100, 50)
|
||||
func NewPeerRateLimiter(maxTokens, refillRate int) *PeerRateLimiter {
|
||||
func NewPeerRateLimiter(maxTokens, refillPerSecond int) *PeerRateLimiter {
|
||||
return &PeerRateLimiter{
|
||||
tokens: maxTokens,
|
||||
maxTokens: maxTokens,
|
||||
refillRate: refillRate,
|
||||
lastRefill: time.Now(),
|
||||
availableTokens: maxTokens,
|
||||
capacity: maxTokens,
|
||||
refillPerSecond: refillPerSecond,
|
||||
lastRefillTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a message is allowed and consumes a token if so
|
||||
func (r *PeerRateLimiter) Allow() bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
// Refill tokens based on elapsed time
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(r.lastRefill)
|
||||
tokensToAdd := int(elapsed.Seconds()) * r.refillRate
|
||||
elapsed := now.Sub(r.lastRefillTime)
|
||||
tokensToAdd := int(elapsed.Seconds()) * r.refillPerSecond
|
||||
if tokensToAdd > 0 {
|
||||
r.tokens = min(r.tokens+tokensToAdd, r.maxTokens)
|
||||
r.lastRefill = now
|
||||
r.availableTokens = min(r.availableTokens+tokensToAdd, r.capacity)
|
||||
r.lastRefillTime = now
|
||||
}
|
||||
|
||||
// Check if we have tokens available
|
||||
if r.tokens > 0 {
|
||||
r.tokens--
|
||||
if r.availableTokens > 0 {
|
||||
r.availableTokens--
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
|
@ -186,7 +187,7 @@ type PeerConnection struct {
|
|||
SharedSecret []byte // Derived via X25519 ECDH, used for SMSG
|
||||
LastActivity time.Time
|
||||
UserAgent string // Request identity advertised by the peer
|
||||
writeMu sync.Mutex // Serialize WebSocket writes
|
||||
writeMutex sync.Mutex // Serialize WebSocket writes
|
||||
transport *Transport
|
||||
closeOnce sync.Once // Ensure Close() is only called once
|
||||
rateLimiter *PeerRateLimiter // Per-peer message rate limiting
|
||||
|
|
@ -196,14 +197,14 @@ type PeerConnection struct {
|
|||
//
|
||||
// transport := NewTransport(nodeManager, peerRegistry, DefaultTransportConfig())
|
||||
func NewTransport(node *NodeManager, registry *PeerRegistry, config TransportConfig) *Transport {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
lifecycleContext, cancelLifecycle := context.WithCancel(context.Background())
|
||||
|
||||
return &Transport{
|
||||
config: config,
|
||||
node: node,
|
||||
registry: registry,
|
||||
conns: make(map[string]*PeerConnection),
|
||||
dedup: NewMessageDeduplicator(5 * time.Minute), // 5 minute TTL for dedup
|
||||
config: config,
|
||||
nodeManager: node,
|
||||
peerRegistry: registry,
|
||||
connections: make(map[string]*PeerConnection),
|
||||
messageDeduplicator: NewMessageDeduplicator(5 * time.Minute), // 5 minute TTL for dedup
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
|
|
@ -222,8 +223,8 @@ func NewTransport(node *NodeManager, registry *PeerRegistry, config TransportCon
|
|||
return host == "localhost" || host == "127.0.0.1" || host == "::1"
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
lifecycleContext: lifecycleContext,
|
||||
cancelLifecycle: cancelLifecycle,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -263,7 +264,7 @@ func agentHeaderToken(value string) string {
|
|||
|
||||
// agentUserAgent returns a transparent identity string for request headers.
|
||||
func (t *Transport) agentUserAgent() string {
|
||||
identity := t.node.Identity()
|
||||
identity := t.nodeManager.Identity()
|
||||
if identity == nil {
|
||||
return core.Sprintf("%s proto=%s", agentUserAgentPrefix, ProtocolVersion)
|
||||
}
|
||||
|
|
@ -283,7 +284,7 @@ func (t *Transport) Start() error {
|
|||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(t.config.WSPath, t.handleWSUpgrade)
|
||||
|
||||
t.server = &http.Server{
|
||||
t.httpServer = &http.Server{
|
||||
Addr: t.config.ListenAddr,
|
||||
Handler: mux,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
|
|
@ -294,7 +295,7 @@ func (t *Transport) Start() error {
|
|||
|
||||
// Apply TLS hardening if TLS is enabled
|
||||
if t.config.TLSCertPath != "" && t.config.TLSKeyPath != "" {
|
||||
t.server.TLSConfig = &tls.Config{
|
||||
t.httpServer.TLSConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CipherSuites: []uint16{
|
||||
// TLS 1.3 ciphers (automatically used when available)
|
||||
|
|
@ -316,12 +317,12 @@ func (t *Transport) Start() error {
|
|||
}
|
||||
}
|
||||
|
||||
t.wg.Go(func() {
|
||||
t.waitGroup.Go(func() {
|
||||
var err error
|
||||
if t.config.TLSCertPath != "" && t.config.TLSKeyPath != "" {
|
||||
err = t.server.ListenAndServeTLS(t.config.TLSCertPath, t.config.TLSKeyPath)
|
||||
err = t.httpServer.ListenAndServeTLS(t.config.TLSCertPath, t.config.TLSKeyPath)
|
||||
} else {
|
||||
err = t.server.ListenAndServe()
|
||||
err = t.httpServer.ListenAndServe()
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
logging.Error("HTTP server error", logging.Fields{"error": err, "addr": t.config.ListenAddr})
|
||||
|
|
@ -329,15 +330,15 @@ func (t *Transport) Start() error {
|
|||
})
|
||||
|
||||
// Start message deduplication cleanup goroutine
|
||||
t.wg.Go(func() {
|
||||
t.waitGroup.Go(func() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
case <-t.lifecycleContext.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
t.dedup.Cleanup()
|
||||
t.messageDeduplicator.Cleanup()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
@ -347,28 +348,28 @@ func (t *Transport) Start() error {
|
|||
|
||||
// Stop closes active connections and shuts the transport down cleanly.
|
||||
func (t *Transport) Stop() error {
|
||||
t.cancel()
|
||||
t.cancelLifecycle()
|
||||
|
||||
// Gracefully close all connections with shutdown message
|
||||
t.mu.RLock()
|
||||
conns := slices.Collect(maps.Values(t.conns))
|
||||
t.mu.RUnlock()
|
||||
t.mutex.RLock()
|
||||
connections := slices.Collect(maps.Values(t.connections))
|
||||
t.mutex.RUnlock()
|
||||
|
||||
for _, pc := range conns {
|
||||
for _, pc := range connections {
|
||||
pc.GracefulClose("server shutdown", DisconnectShutdown)
|
||||
}
|
||||
|
||||
// Shutdown HTTP server if it was started
|
||||
if t.server != nil {
|
||||
if t.httpServer != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := t.server.Shutdown(ctx); err != nil {
|
||||
if err := t.httpServer.Shutdown(ctx); err != nil {
|
||||
return core.E("Transport.Stop", "server shutdown error", err)
|
||||
}
|
||||
}
|
||||
|
||||
t.wg.Wait()
|
||||
t.waitGroup.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -376,9 +377,9 @@ func (t *Transport) Stop() error {
|
|||
//
|
||||
// transport.OnMessage(worker.HandleMessage)
|
||||
func (t *Transport) OnMessage(handler MessageHandler) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.handler = handler
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
t.messageHandler = handler
|
||||
}
|
||||
|
||||
// Connect dials a peer, completes the handshake, and starts the session loops.
|
||||
|
|
@ -419,9 +420,9 @@ func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) {
|
|||
}
|
||||
|
||||
// Store connection using the real peer ID from handshake
|
||||
t.mu.Lock()
|
||||
t.conns[pc.Peer.ID] = pc
|
||||
t.mu.Unlock()
|
||||
t.mutex.Lock()
|
||||
t.connections[pc.Peer.ID] = pc
|
||||
t.mutex.Unlock()
|
||||
|
||||
logging.Debug("connected to peer", logging.Fields{"peer_id": pc.Peer.ID, "secret_len": len(pc.SharedSecret)})
|
||||
logging.Debug("connected peer metadata", logging.Fields{
|
||||
|
|
@ -430,16 +431,16 @@ func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) {
|
|||
})
|
||||
|
||||
// Update registry
|
||||
t.registry.MarkConnected(pc.Peer.ID, true)
|
||||
t.peerRegistry.MarkConnected(pc.Peer.ID, true)
|
||||
|
||||
// Start read loop
|
||||
t.wg.Add(1)
|
||||
t.waitGroup.Add(1)
|
||||
go t.readLoop(pc)
|
||||
|
||||
logging.Debug("started readLoop for peer", logging.Fields{"peer_id": pc.Peer.ID})
|
||||
|
||||
// Start keepalive
|
||||
t.wg.Add(1)
|
||||
t.waitGroup.Add(1)
|
||||
go t.keepalive(pc)
|
||||
|
||||
return pc, nil
|
||||
|
|
@ -447,9 +448,9 @@ func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) {
|
|||
|
||||
// Send transmits an encrypted message to a connected peer.
|
||||
func (t *Transport) Send(peerID string, msg *Message) error {
|
||||
t.mu.RLock()
|
||||
pc, exists := t.conns[peerID]
|
||||
t.mu.RUnlock()
|
||||
t.mutex.RLock()
|
||||
pc, exists := t.connections[peerID]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return core.E("Transport.Send", "peer "+peerID+" not connected", nil)
|
||||
|
|
@ -461,10 +462,10 @@ func (t *Transport) Send(peerID string, msg *Message) error {
|
|||
// Connections returns an iterator over all active peer connections.
|
||||
func (t *Transport) Connections() iter.Seq[*PeerConnection] {
|
||||
return func(yield func(*PeerConnection) bool) {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
for _, pc := range t.conns {
|
||||
for _, pc := range t.connections {
|
||||
if !yield(pc) {
|
||||
return
|
||||
}
|
||||
|
|
@ -494,9 +495,9 @@ func (t *Transport) Broadcast(msg *Message) error {
|
|||
//
|
||||
// connection := transport.Connection("worker-1")
|
||||
func (t *Transport) Connection(peerID string) *PeerConnection {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.conns[peerID]
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
return t.connections[peerID]
|
||||
}
|
||||
|
||||
// handleWSUpgrade handles incoming WebSocket connections.
|
||||
|
|
@ -504,20 +505,20 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
|
|||
userAgent := r.Header.Get("User-Agent")
|
||||
|
||||
// Enforce MaxConns limit (including pending connections during handshake)
|
||||
t.mu.RLock()
|
||||
currentConns := len(t.conns)
|
||||
t.mu.RUnlock()
|
||||
pendingConns := int(t.pendingConns.Load())
|
||||
t.mutex.RLock()
|
||||
currentConnections := len(t.connections)
|
||||
t.mutex.RUnlock()
|
||||
pendingHandshakeCount := int(t.pendingHandshakeCount.Load())
|
||||
|
||||
totalConns := currentConns + pendingConns
|
||||
if totalConns >= t.config.MaxConns {
|
||||
totalConnections := currentConnections + pendingHandshakeCount
|
||||
if totalConnections >= t.config.MaxConns {
|
||||
http.Error(w, "Too many connections", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Track this connection as pending during handshake
|
||||
t.pendingConns.Add(1)
|
||||
defer t.pendingConns.Add(-1)
|
||||
t.pendingHandshakeCount.Add(1)
|
||||
defer t.pendingHandshakeCount.Add(-1)
|
||||
|
||||
conn, err := t.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
|
|
@ -568,7 +569,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
|
|||
"peer_id": payload.Identity.ID,
|
||||
"user_agent": userAgent,
|
||||
})
|
||||
identity := t.node.Identity()
|
||||
identity := t.nodeManager.Identity()
|
||||
if identity != nil {
|
||||
rejectPayload := HandshakeAckPayload{
|
||||
Identity: *identity,
|
||||
|
|
@ -585,14 +586,14 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Derive shared secret from peer's public key
|
||||
sharedSecret, err := t.node.DeriveSharedSecret(payload.Identity.PublicKey)
|
||||
sharedSecret, err := t.nodeManager.DeriveSharedSecret(payload.Identity.PublicKey)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if peer is allowed to connect (allowlist check)
|
||||
if !t.registry.IsPeerAllowed(payload.Identity.ID, payload.Identity.PublicKey) {
|
||||
if !t.peerRegistry.IsPeerAllowed(payload.Identity.ID, payload.Identity.PublicKey) {
|
||||
logging.Warn("peer connection rejected: not in allowlist", logging.Fields{
|
||||
"peer_id": payload.Identity.ID,
|
||||
"peer_name": payload.Identity.Name,
|
||||
|
|
@ -600,7 +601,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
|
|||
"user_agent": userAgent,
|
||||
})
|
||||
// Send rejection before closing
|
||||
identity := t.node.Identity()
|
||||
identity := t.nodeManager.Identity()
|
||||
if identity != nil {
|
||||
rejectPayload := HandshakeAckPayload{
|
||||
Identity: *identity,
|
||||
|
|
@ -617,7 +618,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Create peer if not exists (only if auth passed)
|
||||
peer := t.registry.Peer(payload.Identity.ID)
|
||||
peer := t.peerRegistry.Peer(payload.Identity.ID)
|
||||
if peer == nil {
|
||||
// Auto-register the peer since they passed allowlist check
|
||||
peer = &Peer{
|
||||
|
|
@ -628,7 +629,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
|
|||
AddedAt: time.Now(),
|
||||
Score: 50,
|
||||
}
|
||||
t.registry.AddPeer(peer)
|
||||
t.peerRegistry.AddPeer(peer)
|
||||
logging.Info("auto-registered new peer", logging.Fields{
|
||||
"peer_id": peer.ID,
|
||||
"peer_name": peer.Name,
|
||||
|
|
@ -646,7 +647,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Send handshake acknowledgment
|
||||
identity := t.node.Identity()
|
||||
identity := t.nodeManager.Identity()
|
||||
if identity == nil {
|
||||
conn.Close()
|
||||
return
|
||||
|
|
@ -683,12 +684,12 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Store connection
|
||||
t.mu.Lock()
|
||||
t.conns[peer.ID] = pc
|
||||
t.mu.Unlock()
|
||||
t.mutex.Lock()
|
||||
t.connections[peer.ID] = pc
|
||||
t.mutex.Unlock()
|
||||
|
||||
// Update registry
|
||||
t.registry.MarkConnected(peer.ID, true)
|
||||
t.peerRegistry.MarkConnected(peer.ID, true)
|
||||
|
||||
logging.Debug("accepted peer connection", logging.Fields{
|
||||
"peer_id": peer.ID,
|
||||
|
|
@ -697,11 +698,11 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
|
|||
})
|
||||
|
||||
// Start read loop
|
||||
t.wg.Add(1)
|
||||
t.waitGroup.Add(1)
|
||||
go t.readLoop(pc)
|
||||
|
||||
// Start keepalive
|
||||
t.wg.Add(1)
|
||||
t.waitGroup.Add(1)
|
||||
go t.keepalive(pc)
|
||||
}
|
||||
|
||||
|
|
@ -717,7 +718,7 @@ func (t *Transport) performHandshake(pc *PeerConnection) error {
|
|||
pc.Conn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
|
||||
identity := t.node.Identity()
|
||||
identity := t.nodeManager.Identity()
|
||||
if identity == nil {
|
||||
return ErrorIdentityNotInitialized
|
||||
}
|
||||
|
|
@ -780,7 +781,7 @@ func (t *Transport) performHandshake(pc *PeerConnection) error {
|
|||
pc.Peer.Role = ackPayload.Identity.Role
|
||||
|
||||
// Verify challenge response - derive shared secret first using the peer's public key
|
||||
sharedSecret, err := t.node.DeriveSharedSecret(pc.Peer.PublicKey)
|
||||
sharedSecret, err := t.nodeManager.DeriveSharedSecret(pc.Peer.PublicKey)
|
||||
if err != nil {
|
||||
return core.E("Transport.performHandshake", "derive shared secret for challenge verification", err)
|
||||
}
|
||||
|
|
@ -797,9 +798,9 @@ func (t *Transport) performHandshake(pc *PeerConnection) error {
|
|||
pc.SharedSecret = sharedSecret
|
||||
|
||||
// Update the peer in registry with the real identity
|
||||
if err := t.registry.UpdatePeer(pc.Peer); err != nil {
|
||||
if err := t.peerRegistry.UpdatePeer(pc.Peer); err != nil {
|
||||
// If update fails (peer not found with old ID), add as new
|
||||
t.registry.AddPeer(pc.Peer)
|
||||
t.peerRegistry.AddPeer(pc.Peer)
|
||||
}
|
||||
|
||||
logging.Debug("handshake completed with challenge-response verification", logging.Fields{
|
||||
|
|
@ -813,7 +814,7 @@ func (t *Transport) performHandshake(pc *PeerConnection) error {
|
|||
|
||||
// readLoop reads messages from a peer connection.
|
||||
func (t *Transport) readLoop(pc *PeerConnection) {
|
||||
defer t.wg.Done()
|
||||
defer t.waitGroup.Done()
|
||||
defer t.removeConnection(pc)
|
||||
|
||||
// Apply message size limit to prevent memory exhaustion attacks
|
||||
|
|
@ -825,7 +826,7 @@ func (t *Transport) readLoop(pc *PeerConnection) {
|
|||
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
case <-t.lifecycleContext.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
|
@ -859,21 +860,21 @@ func (t *Transport) readLoop(pc *PeerConnection) {
|
|||
}
|
||||
|
||||
// Check for duplicate messages (prevents amplification attacks)
|
||||
if t.dedup.IsDuplicate(msg.ID) {
|
||||
if t.messageDeduplicator.IsDuplicate(msg.ID) {
|
||||
logging.Debug("dropping duplicate message", logging.Fields{"msg_id": msg.ID, "peer_id": pc.Peer.ID})
|
||||
continue
|
||||
}
|
||||
t.dedup.Mark(msg.ID)
|
||||
t.messageDeduplicator.Mark(msg.ID)
|
||||
|
||||
// Rate limit debug logs in hot path to reduce noise (log 1 in N messages)
|
||||
if debugLogCounter.Add(1)%debugLogInterval == 0 {
|
||||
if messageLogSampleCounter.Add(1)%messageLogSampleInterval == 0 {
|
||||
logging.Debug("received message from peer", logging.Fields{"type": msg.Type, "peer_id": pc.Peer.ID, "reply_to": msg.ReplyTo, "sample": "1/100"})
|
||||
}
|
||||
|
||||
// Dispatch to handler (read handler under lock to avoid race)
|
||||
t.mu.RLock()
|
||||
handler := t.handler
|
||||
t.mu.RUnlock()
|
||||
t.mutex.RLock()
|
||||
handler := t.messageHandler
|
||||
t.mutex.RUnlock()
|
||||
if handler != nil {
|
||||
handler(pc, msg)
|
||||
}
|
||||
|
|
@ -882,14 +883,14 @@ func (t *Transport) readLoop(pc *PeerConnection) {
|
|||
|
||||
// keepalive sends periodic pings.
|
||||
func (t *Transport) keepalive(pc *PeerConnection) {
|
||||
defer t.wg.Done()
|
||||
defer t.waitGroup.Done()
|
||||
|
||||
ticker := time.NewTicker(t.config.PingInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
case <-t.lifecycleContext.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Check if connection is still alive
|
||||
|
|
@ -899,7 +900,7 @@ func (t *Transport) keepalive(pc *PeerConnection) {
|
|||
}
|
||||
|
||||
// Send ping
|
||||
identity := t.node.Identity()
|
||||
identity := t.nodeManager.Identity()
|
||||
pingMsg, err := NewMessage(MessagePing, identity.ID, pc.Peer.ID, PingPayload{
|
||||
SentAt: time.Now().UnixMilli(),
|
||||
})
|
||||
|
|
@ -917,18 +918,18 @@ func (t *Transport) keepalive(pc *PeerConnection) {
|
|||
|
||||
// removeConnection removes and cleans up a connection.
|
||||
func (t *Transport) removeConnection(pc *PeerConnection) {
|
||||
t.mu.Lock()
|
||||
delete(t.conns, pc.Peer.ID)
|
||||
t.mu.Unlock()
|
||||
t.mutex.Lock()
|
||||
delete(t.connections, pc.Peer.ID)
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.registry.MarkConnected(pc.Peer.ID, false)
|
||||
t.peerRegistry.MarkConnected(pc.Peer.ID, false)
|
||||
pc.Close()
|
||||
}
|
||||
|
||||
// Send sends an encrypted message over the connection.
|
||||
func (pc *PeerConnection) Send(msg *Message) error {
|
||||
pc.writeMu.Lock()
|
||||
defer pc.writeMu.Unlock()
|
||||
pc.writeMutex.Lock()
|
||||
defer pc.writeMutex.Unlock()
|
||||
|
||||
// Encrypt message using SMSG
|
||||
data, err := pc.transport.encryptMessage(msg, pc.SharedSecret)
|
||||
|
|
@ -976,11 +977,11 @@ func (pc *PeerConnection) GracefulClose(reason string, code int) error {
|
|||
var err error
|
||||
pc.closeOnce.Do(func() {
|
||||
// Try to send disconnect message (best effort).
|
||||
// Note: we must NOT call SetWriteDeadline outside writeMu — Send()
|
||||
// Note: we must NOT call SetWriteDeadline outside writeMutex — Send()
|
||||
// already manages write deadlines under the lock. Setting it here
|
||||
// without the lock races with concurrent Send() calls (P2P-RACE-1).
|
||||
if pc.transport != nil && pc.SharedSecret != nil {
|
||||
identity := pc.transport.node.Identity()
|
||||
identity := pc.transport.nodeManager.Identity()
|
||||
if identity != nil {
|
||||
payload := DisconnectPayload{
|
||||
Reason: reason,
|
||||
|
|
@ -1042,7 +1043,7 @@ func (t *Transport) decryptMessage(data []byte, sharedSecret []byte) (*Message,
|
|||
//
|
||||
// count := transport.ConnectedPeerCount()
|
||||
func (t *Transport) ConnectedPeerCount() int {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return len(t.conns)
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
return len(t.connections)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue