Mining/pkg/node/transport.go
Claude 4e7d4d906b
ax(batch): remove redundant prose comments in source files
Delete inline comments that restate what the next line of code does
(AX Principle 2). Affected: circuit_breaker.go, service.go, transport.go,
worker.go, identity.go, peer.go, bufpool.go, settings_manager.go,
node_service.go.

Co-Authored-By: Charon <charon@lethean.io>
2026-04-02 18:48:08 +01:00

936 lines
29 KiB
Go

package node
import (
"context"
"crypto/tls"
"encoding/base64"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"
"forge.lthn.ai/Snider/Borg/pkg/smsg"
"forge.lthn.ai/Snider/Mining/pkg/logging"
"github.com/gorilla/websocket"
)
// debugLogCounter is incremented by readLoop for every message that passes
// deduplication, so the hot-path debug log can be sampled rather than emitted
// for each message.
//
//	if debugLogCounter.Add(1)%debugLogInterval == 0 { logging.Debug("received message", ...) }
var debugLogCounter atomic.Int64

// debugLogInterval is the sampling period for the hot-path debug log: one log
// line is written for every debugLogInterval counted messages (1 in 100).
const debugLogInterval = 100
// DefaultMaxMessageSize is the read limit (1MB) applied to a connection when
// TransportConfig.MaxMessageSize is zero or negative.
//
//	conn.SetReadLimit(DefaultMaxMessageSize)
const DefaultMaxMessageSize int64 = 1 << 20 // 1MB

// TransportConfig carries the listener, TLS, connection-limit and keepalive
// settings of a Transport.
//
//	transportConfig := node.DefaultTransportConfig()
//	transportConfig.ListenAddr = ":9095"
//	transportConfig.MaxConns = 50
type TransportConfig struct {
	ListenAddr     string        // ":9091" default
	WSPath         string        // "/ws" - WebSocket endpoint path
	TLSCertPath    string        // Optional TLS for wss://
	TLSKeyPath     string        // Key matching TLSCertPath; TLS is enabled only when both are set
	MaxConns       int           // Maximum concurrent connections
	MaxMessageSize int64         // Maximum message size in bytes (0 = 1MB default)
	PingInterval   time.Duration // WebSocket keepalive interval
	PongTimeout    time.Duration // Timeout waiting for pong
}

// DefaultTransportConfig returns the stock settings: plain ws:// on ":9091"
// at "/ws", 100 connections, 1MB messages, 30s ping interval and 10s pong
// timeout. TLS paths are left empty.
//
//	transportConfig := node.DefaultTransportConfig()
//	transportConfig.ListenAddr = ":9095"
//	transportConfig.TLSCertPath = "/etc/lethean/tls.crt"
func DefaultTransportConfig() TransportConfig {
	var defaults TransportConfig

	// Endpoint.
	defaults.ListenAddr = ":9091"
	defaults.WSPath = "/ws"

	// Limits.
	defaults.MaxConns = 100
	defaults.MaxMessageSize = DefaultMaxMessageSize

	// Keepalive.
	defaults.PingInterval = 30 * time.Second
	defaults.PongTimeout = 10 * time.Second

	return defaults
}
// MessageHandler is the callback type installed with Transport.OnMessage.
// readLoop invokes it with the connection a message arrived on and the
// decrypted, de-duplicated message.
//
//	transport.OnMessage(func(conn *PeerConnection, msg *Message) { worker.HandleMessage(conn, msg) })
type MessageHandler func(conn *PeerConnection, msg *Message)
// MessageDeduplicator records the IDs of messages that have already been
// processed, together with the time each was recorded, so repeats can be
// dropped. All methods are guarded by an internal RWMutex.
//
//	dedup := node.NewMessageDeduplicator(5 * time.Minute)
//	if dedup.IsDuplicate(msg.ID) { continue }
//	dedup.Mark(msg.ID)
type MessageDeduplicator struct {
	seen       map[string]time.Time // message ID -> time it was marked
	mutex      sync.RWMutex
	timeToLive time.Duration // age after which Cleanup forgets an entry
}

// NewMessageDeduplicator returns an empty deduplicator whose entries become
// eligible for removal by Cleanup once they are older than timeToLive.
//
//	dedup := node.NewMessageDeduplicator(5 * time.Minute)
func NewMessageDeduplicator(timeToLive time.Duration) *MessageDeduplicator {
	return &MessageDeduplicator{
		timeToLive: timeToLive,
		seen:       map[string]time.Time{},
	}
}

// IsDuplicate reports whether msgID is currently recorded. It does not look
// at the entry's age; expiry happens only in Cleanup.
//
//	if dedup.IsDuplicate(msg.ID) { continue } // drop already-processed message
func (deduplicator *MessageDeduplicator) IsDuplicate(msgID string) bool {
	deduplicator.mutex.RLock()
	defer deduplicator.mutex.RUnlock()

	_, known := deduplicator.seen[msgID]
	return known
}

// Mark records msgID with the current time, overwriting any earlier record.
//
//	dedup.Mark(msg.ID) // call after IsDuplicate returns false
func (deduplicator *MessageDeduplicator) Mark(msgID string) {
	deduplicator.mutex.Lock()
	defer deduplicator.mutex.Unlock()

	deduplicator.seen[msgID] = time.Now()
}

// Cleanup drops every entry that was marked more than timeToLive ago.
//
//	dedup.Cleanup() // call periodically
func (deduplicator *MessageDeduplicator) Cleanup() {
	deduplicator.mutex.Lock()
	defer deduplicator.mutex.Unlock()

	now := time.Now()
	for msgID, markedAt := range deduplicator.seen {
		if now.Sub(markedAt) <= deduplicator.timeToLive {
			continue // still fresh
		}
		delete(deduplicator.seen, msgID)
	}
}
// Transport owns the WebSocket listener and the set of live peer connections.
// It accepts inbound peers (handleWSUpgrade), dials outbound peers (Connect),
// and dispatches decrypted messages to the handler set with OnMessage.
//
// transport := node.NewTransport(nodeManager, peerRegistry, node.DefaultTransportConfig())
// if err := transport.Start(); err != nil { return err }
// defer transport.Stop()
type Transport struct {
config TransportConfig
server *http.Server // assigned in Start; nil before that
upgrader websocket.Upgrader
connections map[string]*PeerConnection // peer ID -> connection; guarded by mutex
pendingConns atomic.Int32 // inbound connections still in handshake; counted against MaxConns
node *NodeManager
registry *PeerRegistry
handler MessageHandler // set via OnMessage; guarded by mutex
deduplicator *MessageDeduplicator // drops messages whose ID was already seen
mutex sync.RWMutex // guards connections and handler
cancelContext context.Context // cancelled by Stop to end background goroutines
cancel context.CancelFunc
waitGroup sync.WaitGroup // tracks server, cleanup, readLoop and keepalive goroutines
}
// PeerRateLimiter is a token bucket: it starts full with maxTokens tokens,
// each allowed message consumes one, and tokens are earned back at refillRate
// per second up to maxTokens. Safe for concurrent use.
//
//	limiter := node.NewPeerRateLimiter(100, 50) // 100 burst capacity, 50 tokens/sec refill
//	if !limiter.Allow() { continue }            // drop message from rate-limited peer
type PeerRateLimiter struct {
	tokens     int
	maxTokens  int
	refillRate int       // tokens per second
	lastRefill time.Time // instant up to which earned tokens have been credited
	mutex      sync.Mutex
}

// NewPeerRateLimiter returns a limiter that starts with a full bucket of
// maxTokens and refills at refillRate tokens per second.
//
//	limiter := node.NewPeerRateLimiter(100, 50) // 100 burst, 50 messages/second refill
//	if limiter.Allow() { process(msg) } else { drop(msg) }
func NewPeerRateLimiter(maxTokens, refillRate int) *PeerRateLimiter {
	return &PeerRateLimiter{
		tokens:     maxTokens,
		maxTokens:  maxTokens,
		refillRate: refillRate,
		lastRefill: time.Now(),
	}
}

// Allow credits any tokens earned since the last refill, then consumes one
// token and returns true, or returns false when the bucket is empty.
//
// Refill is proportional to the elapsed time, including fractions of a
// second: truncating to whole seconds would starve a drained peer for up to a
// second and then discard the fractional remainder on every refill.
//
//	if limiter.Allow() { process(msg) } else { drop(msg) }
func (limiter *PeerRateLimiter) Allow() bool {
	limiter.mutex.Lock()
	defer limiter.mutex.Unlock()

	now := time.Now()
	if elapsed := now.Sub(limiter.lastRefill); elapsed > 0 && limiter.refillRate > 0 {
		earned := elapsed.Seconds() * float64(limiter.refillRate)
		deficit := limiter.maxTokens - limiter.tokens
		if earned >= float64(deficit) {
			// Bucket is full again; surplus time earns nothing, so drop it.
			limiter.tokens = limiter.maxTokens
			limiter.lastRefill = now
		} else if whole := int(earned); whole > 0 {
			// Credit whole tokens and advance lastRefill only by the time
			// they cost, so the fractional remainder keeps accruing.
			limiter.tokens += whole
			spent := time.Duration(whole) * time.Second / time.Duration(limiter.refillRate)
			limiter.lastRefill = limiter.lastRefill.Add(spent)
		}
	}

	if limiter.tokens <= 0 {
		return false
	}
	limiter.tokens--
	return true
}
// PeerConnection is one established, handshaken WebSocket link to a peer,
// created by Transport.Connect (outbound) or handleWSUpgrade (inbound).
//
// connection := transport.connections[peer.ID]
// if err := connection.Send(msg); err != nil { logging.Error("send failed", ...) }
// connection.GracefulClose("shutdown", DisconnectShutdown)
type PeerConnection struct {
Peer *Peer
Conn *websocket.Conn
SharedSecret []byte // Derived via X25519 ECDH, used for SMSG
// NOTE(review): LastActivity is written by readLoop and read by keepalive
// from different goroutines without synchronisation — confirm whether this
// needs a lock or an atomic.
LastActivity time.Time
writeMutex sync.Mutex // Serialize WebSocket writes
transport *Transport
closeOnce sync.Once // Shared by Close and GracefulClose so the socket is closed only once
rateLimiter *PeerRateLimiter // Per-peer message rate limiting
}
// NewTransport builds a Transport around the given node identity manager,
// peer registry and configuration. Nothing is started: call Start to listen.
// The transport gets a 5 minute message deduplicator and a WebSocket upgrader
// that accepts requests with no Origin header or with a loopback Origin.
//
//	transportConfig := node.DefaultTransportConfig()
//	transportConfig.ListenAddr = ":9095"
//	transport := node.NewTransport(nodeManager, peerRegistry, transportConfig)
//	if err := transport.Start(); err != nil { return err }
func NewTransport(node *NodeManager, registry *PeerRegistry, config TransportConfig) *Transport {
	// Browsers always send Origin; only loopback pages may connect. Clients
	// that send no Origin at all (non-browser peers) are accepted.
	allowOrigin := func(request *http.Request) bool {
		origin := request.Header.Get("Origin")
		if origin == "" {
			return true
		}
		parsedURL, err := url.Parse(origin)
		if err != nil {
			return false
		}
		switch parsedURL.Hostname() {
		case "localhost", "127.0.0.1", "::1":
			return true
		default:
			return false
		}
	}

	cancelContext, cancel := context.WithCancel(context.Background())

	transport := &Transport{
		config:       config,
		node:         node,
		registry:     registry,
		connections:  map[string]*PeerConnection{},
		deduplicator: NewMessageDeduplicator(5 * time.Minute), // 5 minute TTL for dedup
		upgrader: websocket.Upgrader{
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
			CheckOrigin:     allowOrigin,
		},
		cancelContext: cancelContext,
		cancel:        cancel,
	}
	return transport
}
// Start creates the HTTP server for the WebSocket endpoint and launches two
// background goroutines: one serving HTTP(S) and one pruning the message
// deduplicator every minute. TLS is used only when both TLSCertPath and
// TLSKeyPath are set.
//
// Start itself always returns nil: listen/serve failures happen inside the
// server goroutine and are only logged.
//
//	if err := transport.Start(); err != nil { return err }
func (transport *Transport) Start() error {
	tlsEnabled := transport.config.TLSCertPath != "" && transport.config.TLSKeyPath != ""

	routes := http.NewServeMux()
	routes.HandleFunc(transport.config.WSPath, transport.handleWSUpgrade)

	server := &http.Server{
		Addr:              transport.config.ListenAddr,
		Handler:           routes,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       60 * time.Second,
		ReadHeaderTimeout: 10 * time.Second,
	}
	if tlsEnabled {
		// TLS hardening: 1.2 minimum, AEAD suites only, modern curves.
		server.TLSConfig = &tls.Config{
			MinVersion: tls.VersionTLS12,
			CipherSuites: []uint16{
				// TLS 1.3 ciphers (automatically used when available)
				tls.TLS_AES_128_GCM_SHA256,
				tls.TLS_AES_256_GCM_SHA384,
				tls.TLS_CHACHA20_POLY1305_SHA256,
				// TLS 1.2 secure ciphers
				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
				tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
			},
			CurvePreferences: []tls.CurveID{
				tls.X25519,
				tls.CurveP256,
			},
		}
	}
	transport.server = server

	// Pick the serve function once so the goroutine body stays linear.
	serve := server.ListenAndServe
	if tlsEnabled {
		serve = func() error {
			return server.ListenAndServeTLS(transport.config.TLSCertPath, transport.config.TLSKeyPath)
		}
	}

	transport.waitGroup.Add(1)
	go func() {
		defer transport.waitGroup.Done()
		if err := serve(); err != nil && err != http.ErrServerClosed {
			logging.Error("HTTP server error", logging.Fields{"error": err, "addr": transport.config.ListenAddr})
		}
	}()

	// Periodically forget message IDs older than the deduplicator's TTL.
	transport.waitGroup.Add(1)
	go func() {
		defer transport.waitGroup.Done()

		ticker := time.NewTicker(time.Minute)
		defer ticker.Stop()

		stopped := transport.cancelContext.Done()
		for {
			select {
			case <-ticker.C:
				transport.deduplicator.Cleanup()
			case <-stopped:
				return
			}
		}
	}()

	return nil
}
// Stop shuts the transport down: it cancels the background context, sends a
// best-effort disconnect to every peer and closes its socket, shuts the HTTP
// server down (5s limit) and then waits for all transport goroutines.
//
// The connection set is snapshotted and the lock released before any socket
// is touched: GracefulClose performs blocking network writes, and holding the
// transport lock across them would stall every readLoop and handshake for the
// whole shutdown. The goroutine wait runs even when the HTTP shutdown fails,
// so a failed shutdown does not leave goroutines running behind the caller.
//
// Returns a *ProtocolError if the HTTP server could not be shut down cleanly.
//
//	defer transport.Stop() // call on shutdown; drains connections before returning
func (transport *Transport) Stop() error {
	transport.cancel()

	transport.mutex.RLock()
	connections := make([]*PeerConnection, 0, len(transport.connections))
	for _, connection := range transport.connections {
		connections = append(connections, connection)
	}
	transport.mutex.RUnlock()

	for _, connection := range connections {
		// Best effort: the peer may already be gone, so the error is ignored.
		connection.GracefulClose("server shutdown", DisconnectShutdown)
	}

	var shutdownErr error
	if transport.server != nil {
		shutdownContext, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := transport.server.Shutdown(shutdownContext); err != nil {
			shutdownErr = &ProtocolError{Code: ErrCodeOperationFailed, Message: "server shutdown error: " + err.Error()}
		}
	}

	transport.waitGroup.Wait()
	return shutdownErr
}
// OnMessage installs the callback that readLoop invokes for every accepted
// message, replacing any previous one. The assignment is made under the
// transport lock.
//
//	transport.OnMessage(worker.HandleMessage) // call before Start() to avoid races
func (transport *Transport) OnMessage(handler MessageHandler) {
	transport.mutex.Lock()
	transport.handler = handler
	transport.mutex.Unlock()
}
// Connect dials peer.Address over WebSocket, performs the client side of the
// handshake, registers the connection under the peer ID reported by the
// remote side, marks the peer connected and starts its readLoop and keepalive
// goroutines. The scheme is wss when this transport has a TLS certificate
// configured, ws otherwise.
//
// Returns a *ProtocolError if the dial or the handshake fails; the socket is
// closed in the latter case.
//
//	connection, err := transport.Connect(peer)
//	if err != nil { return fmt.Errorf("connect failed: %w", err) }
func (transport *Transport) Connect(peer *Peer) (*PeerConnection, error) {
	peerURL := url.URL{Scheme: "ws", Host: peer.Address, Path: transport.config.WSPath}
	if transport.config.TLSCertPath != "" {
		peerURL.Scheme = "wss"
	}

	dialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second}
	socket, _, dialErr := dialer.Dial(peerURL.String(), nil)
	if dialErr != nil {
		return nil, &ProtocolError{Code: ErrCodeOperationFailed, Message: "failed to connect to peer: " + dialErr.Error()}
	}

	connection := &PeerConnection{
		Peer:         peer,
		Conn:         socket,
		LastActivity: time.Now(),
		transport:    transport,
		rateLimiter:  NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill
	}

	if handshakeErr := transport.performHandshake(connection); handshakeErr != nil {
		socket.Close()
		return nil, &ProtocolError{Code: ErrCodeOperationFailed, Message: "handshake failed: " + handshakeErr.Error()}
	}

	// The handshake may have rewritten the peer ID, so read it only now.
	transport.mutex.Lock()
	transport.connections[connection.Peer.ID] = connection
	transport.mutex.Unlock()
	logging.Debug("connected to peer", logging.Fields{"peer_id": connection.Peer.ID, "secret_len": len(connection.SharedSecret)})

	transport.registry.SetConnected(connection.Peer.ID, true)

	// Two goroutines per connection, both tracked by the wait group.
	transport.waitGroup.Add(1)
	go transport.readLoop(connection)
	logging.Debug("started readLoop for peer", logging.Fields{"peer_id": connection.Peer.ID})

	transport.waitGroup.Add(1)
	go transport.keepalive(connection)

	return connection, nil
}
// Send encrypts msg and writes it to the connection registered for peerID.
// Returns a *ProtocolError with ErrCodeNotFound when no such connection
// exists, otherwise whatever PeerConnection.Send returns.
//
//	if err := transport.Send(peer.ID, msg); err != nil { return err }
func (transport *Transport) Send(peerID string, msg *Message) error {
	// Look the connection up under the read lock, but write outside it.
	lookup := func() (*PeerConnection, bool) {
		transport.mutex.RLock()
		defer transport.mutex.RUnlock()
		found, ok := transport.connections[peerID]
		return found, ok
	}

	if connection, connected := lookup(); connected {
		return connection.Send(msg)
	}
	return &ProtocolError{Code: ErrCodeNotFound, Message: "peer " + peerID + " not connected"}
}
// Broadcast sends msg to every connected peer except the one whose ID equals
// msg.From (so a relayed message is not echoed back to its sender). Sending
// continues past failures; the error of the last failed send is returned, or
// nil if every send succeeded.
//
//	if err := transport.Broadcast(msg); err != nil { logging.Warn("broadcast partial failure", ...) }
func (transport *Transport) Broadcast(msg *Message) error {
	// Collect recipients under the read lock; write to them outside it.
	transport.mutex.RLock()
	recipients := make([]*PeerConnection, 0, len(transport.connections))
	for _, connection := range transport.connections {
		// Exclude sender from broadcast to prevent echo (P2P-MED-6)
		isSender := connection.Peer != nil && connection.Peer.ID == msg.From
		if !isSender {
			recipients = append(recipients, connection)
		}
	}
	transport.mutex.RUnlock()

	var lastErr error
	for i := range recipients {
		if sendErr := recipients[i].Send(msg); sendErr != nil {
			lastErr = sendErr
		}
	}
	return lastErr
}
// GetConnection returns the connection registered for peerID, or nil when
// the peer has no registered connection.
//
//	conn := transport.GetConnection(peer.ID)
//	if conn == nil { return fmt.Errorf("peer not connected") }
func (transport *Transport) GetConnection(peerID string) *PeerConnection {
	transport.mutex.RLock()
	connection := transport.connections[peerID]
	transport.mutex.RUnlock()
	return connection
}
// handleWSUpgrade is the HTTP handler for the WebSocket endpoint. It enforces
// MaxConns, upgrades the request and runs the server side of the handshake:
// read the peer's unencrypted handshake, check the protocol version, derive
// the shared secret, check the allowlist, auto-register unknown peers, and
// reply with an acknowledgment carrying the challenge response. On success
// the connection is registered and its readLoop and keepalive are started.
// On any failure the socket is closed; version and allowlist rejections are
// reported to the peer first, on a best-effort basis.
//
//	serveMux.HandleFunc(transport.config.WSPath, transport.handleWSUpgrade) // registered in Start()
func (transport *Transport) handleWSUpgrade(responseWriter http.ResponseWriter, request *http.Request) {
	// Reserve a pending slot BEFORE checking the limit. Loading the counter
	// and incrementing it afterwards lets concurrent upgrades all observe
	// the same free slot and exceed MaxConns together.
	pendingConns := int(transport.pendingConns.Add(1))
	defer transport.pendingConns.Add(-1)

	transport.mutex.RLock()
	currentConns := len(transport.connections)
	transport.mutex.RUnlock()

	// pendingConns includes this request, hence ">" rather than ">=".
	if currentConns+pendingConns > transport.config.MaxConns {
		http.Error(responseWriter, "Too many connections", http.StatusServiceUnavailable)
		return
	}

	conn, err := transport.upgrader.Upgrade(responseWriter, request, nil)
	if err != nil {
		return
	}

	// Every exit path before registration must release the socket.
	registered := false
	defer func() {
		if !registered {
			conn.Close()
		}
	}()

	// Apply message size limit during handshake to prevent memory exhaustion
	maxSize := transport.config.MaxMessageSize
	if maxSize <= 0 {
		maxSize = DefaultMaxMessageSize
	}
	conn.SetReadLimit(maxSize)

	// Bound both directions of the handshake so a slow or malicious client
	// cannot pin this goroutine, neither while we wait for its handshake nor
	// while we write our reply to a peer that has stopped reading.
	handshakeTimeout := 10 * time.Second
	conn.SetReadDeadline(time.Now().Add(handshakeTimeout))
	conn.SetWriteDeadline(time.Now().Add(handshakeTimeout))

	// Wait for handshake from client
	_, data, err := conn.ReadMessage()
	if err != nil {
		return
	}

	// Decode handshake message (not encrypted yet, contains public key)
	var message Message
	if err := UnmarshalJSON(data, &message); err != nil {
		return
	}
	if message.Type != MsgHandshake {
		return
	}
	var payload HandshakePayload
	if err := message.ParsePayload(&payload); err != nil {
		return
	}

	// reject tells the peer why it is being turned away. Best effort: if the
	// local identity is missing or the message cannot be built or encoded,
	// nothing is sent and the socket is simply closed by the caller.
	reject := func(reason string) {
		identity := transport.node.GetIdentity()
		if identity == nil {
			return
		}
		rejectPayload := HandshakeAckPayload{
			Identity: *identity,
			Accepted: false,
			Reason:   reason,
		}
		rejectMsg, err := NewMessage(MsgHandshakeAck, identity.ID, payload.Identity.ID, rejectPayload)
		if err != nil {
			return
		}
		rejectData, err := MarshalJSON(rejectMsg)
		if err != nil {
			return
		}
		conn.WriteMessage(websocket.TextMessage, rejectData)
	}

	// Check protocol version compatibility (P2P-MED-1)
	if !IsProtocolVersionSupported(payload.Version) {
		logging.Warn("peer connection rejected: incompatible protocol version", logging.Fields{
			"peer_version":       payload.Version,
			"supported_versions": SupportedProtocolVersions,
			"peer_id":            payload.Identity.ID,
		})
		reject("incompatible protocol version " + payload.Version + ", supported: " + joinVersions(SupportedProtocolVersions))
		return
	}

	// Derive shared secret from peer's public key
	sharedSecret, err := transport.node.DeriveSharedSecret(payload.Identity.PublicKey)
	if err != nil {
		return
	}

	if !transport.registry.IsPeerAllowed(payload.Identity.ID, payload.Identity.PublicKey) {
		logging.Warn("peer connection rejected: not in allowlist", logging.Fields{
			"peer_id":    payload.Identity.ID,
			"peer_name":  payload.Identity.Name,
			"public_key": safeKeyPrefix(payload.Identity.PublicKey),
		})
		reject("peer not authorized")
		return
	}

	// Create peer if not exists (only if auth passed)
	peer := transport.registry.GetPeer(payload.Identity.ID)
	if peer == nil {
		// Auto-register the peer since they passed allowlist check
		peer = &Peer{
			ID:        payload.Identity.ID,
			Name:      payload.Identity.Name,
			PublicKey: payload.Identity.PublicKey,
			Role:      payload.Identity.Role,
			AddedAt:   time.Now(),
			Score:     50,
		}
		transport.registry.AddPeer(peer)
		logging.Info("auto-registered new peer", logging.Fields{
			"peer_id":   peer.ID,
			"peer_name": peer.Name,
		})
	}

	connection := &PeerConnection{
		Peer:         peer,
		Conn:         conn,
		SharedSecret: sharedSecret,
		LastActivity: time.Now(),
		transport:    transport,
		rateLimiter:  NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill
	}

	// Send handshake acknowledgment
	identity := transport.node.GetIdentity()
	if identity == nil {
		return
	}

	// Signing the peer's challenge with the shared secret proves that this
	// node holds the private key matching its advertised public key.
	var challengeResponse []byte
	if len(payload.Challenge) > 0 {
		challengeResponse = SignChallenge(payload.Challenge, sharedSecret)
	}

	ackPayload := HandshakeAckPayload{
		Identity:          *identity,
		ChallengeResponse: challengeResponse,
		Accepted:          true,
	}
	acknowledgementMessage, err := NewMessage(MsgHandshakeAck, identity.ID, peer.ID, ackPayload)
	if err != nil {
		return
	}
	acknowledgementData, err := MarshalJSON(acknowledgementMessage)
	if err != nil {
		return
	}
	if err := conn.WriteMessage(websocket.TextMessage, acknowledgementData); err != nil {
		return
	}

	// Handshake done: later writes set their own deadline in Send.
	conn.SetWriteDeadline(time.Time{})

	transport.mutex.Lock()
	transport.connections[peer.ID] = connection
	transport.mutex.Unlock()
	registered = true

	transport.registry.SetConnected(peer.ID, true)

	// Start read loop
	transport.waitGroup.Add(1)
	go transport.readLoop(connection)

	// Start keepalive
	transport.waitGroup.Add(1)
	go transport.keepalive(connection)
}
// performHandshake runs the client side of the handshake on a freshly dialed
// connection: it sends this node's identity, protocol version and a random
// challenge in the clear, reads the acknowledgment, and verifies that the
// remote side answered the challenge with the shared secret derived from its
// advertised public key.
//
// Only after that verification succeeds are pc.Peer (ID, public key, name,
// role) and pc.SharedSecret updated and the peer written to the registry.
// Applying the remote identity earlier would leave the caller's Peer
// overwritten with unauthenticated data whenever verification fails.
//
// The whole exchange is bounded by a 10s read/write deadline, which is
// cleared again on return. Every failure is returned as a *ProtocolError.
//
//	if err := transport.performHandshake(pc); err != nil { conn.Close(); return nil, err }
func (transport *Transport) performHandshake(pc *PeerConnection) error {
	// Set handshake timeout
	handshakeTimeout := 10 * time.Second
	pc.Conn.SetWriteDeadline(time.Now().Add(handshakeTimeout))
	pc.Conn.SetReadDeadline(time.Now().Add(handshakeTimeout))
	defer func() {
		// Reset deadlines after handshake
		pc.Conn.SetWriteDeadline(time.Time{})
		pc.Conn.SetReadDeadline(time.Time{})
	}()

	identity := transport.node.GetIdentity()
	if identity == nil {
		return &ProtocolError{Code: ErrCodeUnauthorized, Message: "node identity not initialized"}
	}

	// Generate challenge for the server to prove it has the matching private key
	challenge, err := GenerateChallenge()
	if err != nil {
		return &ProtocolError{Code: ErrCodeOperationFailed, Message: "generate challenge: " + err.Error()}
	}

	payload := HandshakePayload{
		Identity:  *identity,
		Challenge: challenge,
		Version:   ProtocolVersion,
	}
	msg, err := NewMessage(MsgHandshake, identity.ID, pc.Peer.ID, payload)
	if err != nil {
		return &ProtocolError{Code: ErrCodeOperationFailed, Message: "create handshake message: " + err.Error()}
	}

	// First message is unencrypted (peer needs our public key)
	data, err := MarshalJSON(msg)
	if err != nil {
		return &ProtocolError{Code: ErrCodeOperationFailed, Message: "marshal handshake message: " + err.Error()}
	}
	if err := pc.Conn.WriteMessage(websocket.TextMessage, data); err != nil {
		return &ProtocolError{Code: ErrCodeOperationFailed, Message: "send handshake: " + err.Error()}
	}

	// Wait for ack
	_, acknowledgementData, err := pc.Conn.ReadMessage()
	if err != nil {
		return &ProtocolError{Code: ErrCodeOperationFailed, Message: "read handshake ack: " + err.Error()}
	}
	var acknowledgementMessage Message
	if err := UnmarshalJSON(acknowledgementData, &acknowledgementMessage); err != nil {
		return &ProtocolError{Code: ErrCodeInvalidMessage, Message: "unmarshal handshake ack: " + err.Error()}
	}
	if acknowledgementMessage.Type != MsgHandshakeAck {
		return &ProtocolError{Code: ErrCodeInvalidMessage, Message: "expected handshake_ack, got " + string(acknowledgementMessage.Type)}
	}
	var ackPayload HandshakeAckPayload
	if err := acknowledgementMessage.ParsePayload(&ackPayload); err != nil {
		return &ProtocolError{Code: ErrCodeInvalidMessage, Message: "parse handshake ack payload: " + err.Error()}
	}
	if !ackPayload.Accepted {
		return &ProtocolError{Code: ErrCodeUnauthorized, Message: "handshake rejected: " + ackPayload.Reason}
	}

	// Verify the challenge response against the secret derived from the
	// public key the server advertised, before trusting any of its identity.
	sharedSecret, err := transport.node.DeriveSharedSecret(ackPayload.Identity.PublicKey)
	if err != nil {
		return &ProtocolError{Code: ErrCodeOperationFailed, Message: "derive shared secret for challenge verification: " + err.Error()}
	}
	if len(ackPayload.ChallengeResponse) == 0 {
		return &ProtocolError{Code: ErrCodeUnauthorized, Message: "server did not provide challenge response"}
	}
	if !VerifyChallenge(challenge, ackPayload.ChallengeResponse, sharedSecret) {
		return &ProtocolError{Code: ErrCodeUnauthorized, Message: "challenge response verification failed: server may not have matching private key"}
	}

	// Verified: adopt the identity the server reported.
	pc.Peer.ID = ackPayload.Identity.ID
	pc.Peer.PublicKey = ackPayload.Identity.PublicKey
	pc.Peer.Name = ackPayload.Identity.Name
	pc.Peer.Role = ackPayload.Identity.Role
	pc.SharedSecret = sharedSecret

	if err := transport.registry.UpdatePeer(pc.Peer); err != nil {
		// If update fails (peer not found with old ID), add as new
		transport.registry.AddPeer(pc.Peer)
	}

	logging.Debug("handshake completed with challenge-response verification", logging.Fields{
		"peer_id":   pc.Peer.ID,
		"peer_name": pc.Peer.Name,
	})
	return nil
}
// readLoop receives frames from pc until the transport is stopped or the
// socket fails, then unregisters and closes the connection. Each frame is
// rate limited, decrypted with the connection's shared secret and checked for
// duplicates before it reaches the handler; frames that fail any of those
// steps are dropped and the loop continues.
//
// The caller must have added one to the transport wait group.
//
//	go transport.readLoop(pc) // started by Connect and handleWSUpgrade after handshake
func (transport *Transport) readLoop(pc *PeerConnection) {
	defer transport.waitGroup.Done()
	defer transport.removeConnection(pc)

	// Cap frame size so a peer cannot exhaust memory with one huge message.
	readLimit := transport.config.MaxMessageSize
	if readLimit <= 0 {
		readLimit = DefaultMaxMessageSize
	}
	pc.Conn.SetReadLimit(readLimit)

	for transport.cancelContext.Err() == nil {
		// A silent peer must fail the read once a ping and its pong are overdue.
		idleWindow := transport.config.PingInterval + transport.config.PongTimeout
		if deadlineErr := pc.Conn.SetReadDeadline(time.Now().Add(idleWindow)); deadlineErr != nil {
			logging.Error("SetReadDeadline error", logging.Fields{"peer_id": pc.Peer.ID, "error": deadlineErr})
			return
		}

		_, frame, readErr := pc.Conn.ReadMessage()
		if readErr != nil {
			logging.Debug("read error from peer", logging.Fields{"peer_id": pc.Peer.ID, "error": readErr})
			return
		}
		pc.LastActivity = time.Now()

		// Rate limit first so a flooding peer costs no decryption work.
		if pc.rateLimiter != nil && !pc.rateLimiter.Allow() {
			logging.Warn("peer rate limited, dropping message", logging.Fields{"peer_id": pc.Peer.ID})
			continue
		}

		msg, decryptErr := transport.decryptMessage(frame, pc.SharedSecret)
		if decryptErr != nil {
			logging.Debug("decrypt error from peer", logging.Fields{"peer_id": pc.Peer.ID, "error": decryptErr, "data_len": len(frame)})
			continue
		}

		// Drop replays so a message is handled at most once.
		if transport.deduplicator.IsDuplicate(msg.ID) {
			logging.Debug("dropping duplicate message", logging.Fields{"msg_id": msg.ID, "peer_id": pc.Peer.ID})
			continue
		}
		transport.deduplicator.Mark(msg.ID)

		// Sampled logging: only 1 in debugLogInterval messages is logged here.
		if debugLogCounter.Add(1)%debugLogInterval == 0 {
			logging.Debug("received message from peer", logging.Fields{"type": msg.Type, "peer_id": pc.Peer.ID, "reply_to": msg.ReplyTo, "sample": "1/100"})
		}

		// Snapshot the handler under the lock; invoke it outside the lock.
		transport.mutex.RLock()
		dispatch := transport.handler
		transport.mutex.RUnlock()
		if dispatch == nil {
			continue
		}
		dispatch(pc, msg)
	}
}
// keepalive pings pc once per PingInterval until the transport is stopped.
// The connection is removed when nothing has been received from the peer for
// longer than the ping interval plus PongTimeout, or when a ping cannot be
// sent. A tick is skipped when the node identity is not available or the
// ping message cannot be built.
//
// A non-positive PingInterval falls back to 30s, because time.NewTicker
// panics on such values and a panic in this goroutine would take the whole
// process down.
//
// The caller must have added one to the transport wait group.
//
//	go transport.keepalive(pc) // started by Connect and handleWSUpgrade
func (transport *Transport) keepalive(pc *PeerConnection) {
	defer transport.waitGroup.Done()

	interval := transport.config.PingInterval
	if interval <= 0 {
		interval = 30 * time.Second
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-transport.cancelContext.Done():
			return
		case <-ticker.C:
			if time.Since(pc.LastActivity) > interval+transport.config.PongTimeout {
				transport.removeConnection(pc)
				return
			}

			// GetIdentity can return nil (every other caller checks for it);
			// without an identity there is no sender ID to ping with.
			identity := transport.node.GetIdentity()
			if identity == nil {
				continue
			}

			pingMsg, err := NewMessage(MsgPing, identity.ID, pc.Peer.ID, PingPayload{
				SentAt: time.Now().UnixMilli(),
			})
			if err != nil {
				continue
			}
			if err := pc.Send(pingMsg); err != nil {
				transport.removeConnection(pc)
				return
			}
		}
	}
}
// removeConnection unregisters pc and closes its socket.
//
// The map entry is deleted, and the peer marked disconnected, only when the
// entry still refers to pc. When a peer reconnects, its new connection
// replaces the old one under the same peer ID; the old connection's teardown
// must not remove the replacement or report a live peer as disconnected.
// The socket of pc itself is always closed.
//
//	transport.removeConnection(pc) // called from readLoop on exit and keepalive on timeout
func (transport *Transport) removeConnection(pc *PeerConnection) {
	transport.mutex.Lock()
	current, found := transport.connections[pc.Peer.ID]
	owned := found && current == pc
	if owned {
		delete(transport.connections, pc.Peer.ID)
	}
	transport.mutex.Unlock()

	if owned {
		transport.registry.SetConnected(pc.Peer.ID, false)
	}
	pc.Close()
}
// Send encrypts msg with the connection's shared secret and writes it as one
// binary WebSocket frame. Writes are serialised by writeMutex and bounded by
// a 10s write deadline, which is cleared again afterwards.
//
// Returns the encryption error, a *ProtocolError if the deadline cannot be
// set, or the write error.
//
//	if err := conn.Send(msg); err != nil { logging.Error("send failed", ...) }
func (pc *PeerConnection) Send(msg *Message) error {
	pc.writeMutex.Lock()
	defer pc.writeMutex.Unlock()

	ciphertext, encryptErr := pc.transport.encryptMessage(msg, pc.SharedSecret)
	if encryptErr != nil {
		return encryptErr
	}

	// Never block forever on a peer that has stopped reading.
	writeBy := time.Now().Add(10 * time.Second)
	if deadlineErr := pc.Conn.SetWriteDeadline(writeBy); deadlineErr != nil {
		return &ProtocolError{Code: ErrCodeOperationFailed, Message: "failed to set write deadline: " + deadlineErr.Error()}
	}
	defer pc.Conn.SetWriteDeadline(time.Time{})

	return pc.Conn.WriteMessage(websocket.BinaryMessage, ciphertext)
}
// Close closes the underlying WebSocket without notifying the peer. It shares
// its sync.Once with GracefulClose, so only the first close of either kind
// touches the socket; later calls return nil.
//
//	if err := conn.Close(); err != nil { logging.Warn("close failed", logging.Fields{"error": err}) }
func (pc *PeerConnection) Close() error {
	var closeErr error
	closeSocket := func() {
		closeErr = pc.Conn.Close()
	}
	pc.closeOnce.Do(closeSocket)
	return closeErr
}
// DisconnectPayload is the payload of a MsgDisconnect message, sent by
// GracefulClose to tell the peer why the connection is being closed.
//
// pc.GracefulClose("server shutdown", DisconnectShutdown)
// msg, _ := NewMessage(MsgDisconnect, from, to, DisconnectPayload{Reason: "shutdown", Code: DisconnectShutdown})
type DisconnectPayload struct {
Reason string `json:"reason"` // Human-readable explanation
Code int `json:"code"` // Optional disconnect code; see the Disconnect* constants
}
// Disconnect codes carried in DisconnectPayload.Code.
const (
DisconnectNormal = 1000 // Normal closure
DisconnectGoingAway = 1001 // Server/peer going away
DisconnectProtocolErr = 1002 // Protocol error
DisconnectTimeout = 1003 // Idle timeout
DisconnectShutdown = 1004 // Server shutdown
)
// GracefulClose sends the peer an encrypted MsgDisconnect carrying reason and
// code, then closes the socket. The notification is best effort: it is
// skipped when the connection has no transport or shared secret, when the
// node identity is unavailable, or when the message cannot be built or
// encrypted, and a failed write is ignored. The socket is closed regardless.
//
// The disconnect frame is written here rather than through Send so that the
// short 2s deadline actually applies; Send would replace it with its own 10s
// deadline. The write is made under writeMutex like every other write.
//
// It shares its sync.Once with Close: only the first close of either kind
// does anything, and later calls return nil.
//
//	pc.GracefulClose("server shutdown", DisconnectShutdown)
//	pc.GracefulClose("idle timeout", DisconnectTimeout)
func (pc *PeerConnection) GracefulClose(reason string, code int) error {
	// notifyPeer makes one bounded attempt to deliver the disconnect message.
	notifyPeer := func() {
		if pc.transport == nil || pc.SharedSecret == nil {
			return
		}
		identity := pc.transport.node.GetIdentity()
		if identity == nil {
			return
		}
		payload := DisconnectPayload{
			Reason: reason,
			Code:   code,
		}
		msg, err := NewMessage(MsgDisconnect, identity.ID, pc.Peer.ID, payload)
		if err != nil {
			return
		}
		data, err := pc.transport.encryptMessage(msg, pc.SharedSecret)
		if err != nil {
			return
		}

		pc.writeMutex.Lock()
		defer pc.writeMutex.Unlock()
		// Short deadline: shutdown must not hang on an unresponsive peer.
		if err := pc.Conn.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil {
			return
		}
		pc.Conn.WriteMessage(websocket.BinaryMessage, data)
	}

	var err error
	pc.closeOnce.Do(func() {
		notifyPeer()
		err = pc.Conn.Close()
	})
	return err
}
// encryptMessage serialises msg to JSON and encrypts it with SMSG, using the
// base64 encoding of sharedSecret as the password. It returns the encrypted
// bytes, or the JSON or encryption error.
//
//	data, err := transport.encryptMessage(msg, pc.SharedSecret)
//	if err != nil { return err }
//	pc.Conn.WriteMessage(websocket.BinaryMessage, data)
func (transport *Transport) encryptMessage(msg *Message, sharedSecret []byte) ([]byte, error) {
	plaintext, marshalErr := MarshalJSON(msg)
	if marshalErr != nil {
		return nil, marshalErr
	}

	passphrase := base64.StdEncoding.EncodeToString(sharedSecret)
	envelope := smsg.NewMessage(string(plaintext))

	ciphertext, encryptErr := smsg.Encrypt(envelope, passphrase)
	if encryptErr != nil {
		return nil, encryptErr
	}
	return ciphertext, nil
}
// decryptMessage is the inverse of encryptMessage: it decrypts data with
// SMSG, using the base64 encoding of sharedSecret as the password, and parses
// the decrypted body as a JSON Message. It returns the decryption or parse
// error on failure.
//
//	msg, err := transport.decryptMessage(data, pc.SharedSecret)
//	if err != nil { continue } // skip invalid/unauthenticated messages
func (transport *Transport) decryptMessage(data []byte, sharedSecret []byte) (*Message, error) {
	passphrase := base64.StdEncoding.EncodeToString(sharedSecret)

	envelope, decryptErr := smsg.Decrypt(data, passphrase)
	if decryptErr != nil {
		return nil, decryptErr
	}

	parsed := new(Message)
	if parseErr := UnmarshalJSON([]byte(envelope.Body), parsed); parseErr != nil {
		return nil, parseErr
	}
	return parsed, nil
}
// ConnectedPeers returns the number of connections currently registered with
// the transport.
//
//	connectionCount := transport.ConnectedPeers()
//	if connectionCount == 0 { return ErrNoPeers }
func (transport *Transport) ConnectedPeers() int {
	transport.mutex.RLock()
	count := len(transport.connections)
	transport.mutex.RUnlock()
	return count
}