Delete inline comments that restate what the next line of code does (AX Principle 2). Affected: circuit_breaker.go, service.go, transport.go, worker.go, identity.go, peer.go, bufpool.go, settings_manager.go, node_service.go. Co-Authored-By: Charon <charon@lethean.io>
936 lines · 29 KiB · Go
package node
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/base64"
|
|
"net/http"
|
|
"net/url"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"forge.lthn.ai/Snider/Borg/pkg/smsg"
|
|
"forge.lthn.ai/Snider/Mining/pkg/logging"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// if debugLogCounter.Add(1)%debugLogInterval == 0 { logging.Debug("received message", ...) }
|
|
var debugLogCounter atomic.Int64
|
|
|
|
// if debugLogCounter.Add(1)%debugLogInterval == 0 { logging.Debug("...") } // logs 1 in 100 messages
|
|
const debugLogInterval = 100
|
|
|
|
// DefaultMaxMessageSize is 1MB; conn.SetReadLimit(DefaultMaxMessageSize) protects against oversized messages.
|
|
const DefaultMaxMessageSize int64 = 1 << 20 // 1MB
|
|
|
|
// transportConfig := node.DefaultTransportConfig()
|
|
// transportConfig.ListenAddr = ":9095"
|
|
// transportConfig.MaxConns = 50
|
|
type TransportConfig struct {
|
|
ListenAddr string // ":9091" default
|
|
WSPath string // "/ws" - WebSocket endpoint path
|
|
TLSCertPath string // Optional TLS for wss://
|
|
TLSKeyPath string
|
|
MaxConns int // Maximum concurrent connections
|
|
MaxMessageSize int64 // Maximum message size in bytes (0 = 1MB default)
|
|
PingInterval time.Duration // WebSocket keepalive interval
|
|
PongTimeout time.Duration // Timeout waiting for pong
|
|
}
|
|
|
|
// transportConfig := node.DefaultTransportConfig()
|
|
// transportConfig.ListenAddr = ":9095"
|
|
// transportConfig.TLSCertPath = "/etc/lethean/tls.crt"
|
|
func DefaultTransportConfig() TransportConfig {
|
|
return TransportConfig{
|
|
ListenAddr: ":9091",
|
|
WSPath: "/ws",
|
|
MaxConns: 100,
|
|
MaxMessageSize: DefaultMaxMessageSize,
|
|
PingInterval: 30 * time.Second,
|
|
PongTimeout: 10 * time.Second,
|
|
}
|
|
}
|
|
|
|
// MessageHandler is the callback invoked by readLoop for every decrypted,
// deduplicated inbound message. Register it with Transport.OnMessage:
//
//	transport.OnMessage(func(conn *PeerConnection, msg *Message) { worker.HandleMessage(conn, msg) })
type MessageHandler func(conn *PeerConnection, msg *Message)
|
|
|
|
// MessageDeduplicator remembers recently seen message IDs so the transport can
// drop replays. Entries expire after timeToLive via periodic Cleanup calls.
//
//	dedup := NewMessageDeduplicator(5 * time.Minute)
//	if dedup.IsDuplicate(msg.ID) { continue }
//	dedup.Mark(msg.ID)
type MessageDeduplicator struct {
	seen       map[string]time.Time
	mutex      sync.RWMutex
	timeToLive time.Duration
}

// NewMessageDeduplicator returns a deduplicator whose entries expire after
// timeToLive. Expiry is enforced by Cleanup, not lazily on lookup.
func NewMessageDeduplicator(timeToLive time.Duration) *MessageDeduplicator {
	return &MessageDeduplicator{
		seen:       make(map[string]time.Time),
		timeToLive: timeToLive,
	}
}

// IsDuplicate reports whether msgID has been Marked and not yet cleaned up.
func (d *MessageDeduplicator) IsDuplicate(msgID string) bool {
	d.mutex.RLock()
	defer d.mutex.RUnlock()
	_, seen := d.seen[msgID]
	return seen
}

// Mark records msgID as seen now. Call it after IsDuplicate returns false.
func (d *MessageDeduplicator) Mark(msgID string) {
	d.mutex.Lock()
	defer d.mutex.Unlock()
	d.seen[msgID] = time.Now()
}

// Cleanup drops every entry older than timeToLive. Run it periodically from a
// background goroutine (Start does this once a minute).
func (d *MessageDeduplicator) Cleanup() {
	d.mutex.Lock()
	defer d.mutex.Unlock()
	cutoff := time.Now()
	for id, seenAt := range d.seen {
		if cutoff.Sub(seenAt) > d.timeToLive {
			delete(d.seen, id)
		}
	}
}
|
|
|
|
// Transport owns the WebSocket server, the per-peer connections, and the
// background goroutines (read loops, keepalives, dedup cleanup) servicing them.
//
//	transport := node.NewTransport(nodeManager, peerRegistry, node.DefaultTransportConfig())
//	if err := transport.Start(); err != nil { return err }
//	defer transport.Stop()
type Transport struct {
	config        TransportConfig
	server        *http.Server               // nil until Start is called
	upgrader      websocket.Upgrader         // configured in NewTransport with a local-origin check
	connections   map[string]*PeerConnection // peer ID -> connection; guarded by mutex
	pendingConns  atomic.Int32               // connections mid-handshake, counted toward MaxConns
	node          *NodeManager
	registry      *PeerRegistry
	handler       MessageHandler       // dispatch target for inbound messages; guarded by mutex
	deduplicator  *MessageDeduplicator // drops replayed message IDs
	mutex         sync.RWMutex         // guards connections and handler
	cancelContext context.Context      // cancelled by Stop to halt background goroutines
	cancel        context.CancelFunc
	waitGroup     sync.WaitGroup // tracks server, cleanup, read-loop and keepalive goroutines
}
|
|
|
|
// PeerRateLimiter is a token-bucket rate limiter applied per peer connection.
//
//	limiter := NewPeerRateLimiter(100, 50) // 100 burst capacity, 50 tokens/sec refill
//	if !limiter.Allow() { continue }       // drop message from rate-limited peer
type PeerRateLimiter struct {
	tokens     int       // tokens currently available
	maxTokens  int       // bucket capacity (burst size)
	refillRate int       // tokens added per second
	lastRefill time.Time // start of the refill window not yet credited
	mutex      sync.Mutex
}

// NewPeerRateLimiter returns a limiter with a full bucket of maxTokens and a
// steady refill of refillRate tokens per second.
func NewPeerRateLimiter(maxTokens, refillRate int) *PeerRateLimiter {
	return &PeerRateLimiter{
		tokens:     maxTokens,
		maxTokens:  maxTokens,
		refillRate: refillRate,
		lastRefill: time.Now(),
	}
}

// Allow consumes one token if available and reports whether the caller may
// proceed. Tokens are refilled lazily on each call.
func (limiter *PeerRateLimiter) Allow() bool {
	limiter.mutex.Lock()
	defer limiter.mutex.Unlock()

	// Credit whole seconds of refill. Bug fix: advance lastRefill only by the
	// seconds actually credited (not all the way to now), so the fractional
	// remainder carries into the next call — the previous code reset
	// lastRefill to now and silently lost up to 1s of refill per credit.
	now := time.Now()
	elapsed := now.Sub(limiter.lastRefill)
	wholeSeconds := int64(elapsed / time.Second)
	if wholeSeconds > 0 {
		limiter.tokens = min(limiter.tokens+int(wholeSeconds)*limiter.refillRate, limiter.maxTokens)
		limiter.lastRefill = limiter.lastRefill.Add(time.Duration(wholeSeconds) * time.Second)
	}

	if limiter.tokens > 0 {
		limiter.tokens--
		return true
	}
	return false
}
|
|
|
|
// PeerConnection is one live WebSocket link to a peer, carrying the SMSG
// shared secret negotiated during the handshake.
//
//	connection := transport.GetConnection(peer.ID)
//	if err := connection.Send(msg); err != nil { logging.Error("send failed", ...) }
//	connection.GracefulClose("shutdown", DisconnectShutdown)
type PeerConnection struct {
	Peer         *Peer
	Conn         *websocket.Conn
	SharedSecret []byte     // derived via X25519 ECDH, used for SMSG encryption
	LastActivity time.Time  // refreshed by readLoop on every inbound frame
	writeMutex   sync.Mutex // serializes WebSocket writes
	transport    *Transport
	closeOnce    sync.Once        // ensures Close/GracefulClose run at most once
	rateLimiter  *PeerRateLimiter // per-peer inbound message rate limiting
}
|
|
|
|
// transportConfig := node.DefaultTransportConfig()
|
|
// transportConfig.ListenAddr = ":9095"
|
|
// transport := node.NewTransport(nodeManager, peerRegistry, transportConfig)
|
|
// if err := transport.Start(); err != nil { return err }
|
|
func NewTransport(node *NodeManager, registry *PeerRegistry, config TransportConfig) *Transport {
|
|
cancelContext, cancel := context.WithCancel(context.Background())
|
|
|
|
return &Transport{
|
|
config: config,
|
|
node: node,
|
|
registry: registry,
|
|
connections: make(map[string]*PeerConnection),
|
|
deduplicator: NewMessageDeduplicator(5 * time.Minute), // 5 minute TTL for dedup
|
|
upgrader: websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
CheckOrigin: func(request *http.Request) bool {
|
|
// Allow local connections only for security
|
|
origin := request.Header.Get("Origin")
|
|
if origin == "" {
|
|
return true // No origin header (non-browser client)
|
|
}
|
|
// Allow localhost and 127.0.0.1 origins
|
|
parsedURL, err := url.Parse(origin)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
host := parsedURL.Hostname()
|
|
return host == "localhost" || host == "127.0.0.1" || host == "::1"
|
|
},
|
|
},
|
|
cancelContext: cancelContext,
|
|
cancel: cancel,
|
|
}
|
|
}
|
|
|
|
// Start registers the WebSocket endpoint and begins serving on
// config.ListenAddr (with TLS when both cert and key paths are set), and
// launches the deduplicator cleanup goroutine. Always returns nil; listener
// errors surface asynchronously in the log.
func (transport *Transport) Start() error {
	serveMux := http.NewServeMux()
	serveMux.HandleFunc(transport.config.WSPath, transport.handleWSUpgrade)

	transport.server = &http.Server{
		Addr:              transport.config.ListenAddr,
		Handler:           serveMux,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       60 * time.Second,
		ReadHeaderTimeout: 10 * time.Second, // bounds slow-header (Slowloris-style) clients
	}

	// Harden TLS when enabled: TLS 1.2 minimum, AEAD suites only, modern curves.
	if transport.config.TLSCertPath != "" && transport.config.TLSKeyPath != "" {
		transport.server.TLSConfig = &tls.Config{
			MinVersion: tls.VersionTLS12,
			CipherSuites: []uint16{
				// TLS 1.3 ciphers (automatically used when available)
				tls.TLS_AES_128_GCM_SHA256,
				tls.TLS_AES_256_GCM_SHA384,
				tls.TLS_CHACHA20_POLY1305_SHA256,
				// TLS 1.2 secure ciphers
				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
				tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
			},
			CurvePreferences: []tls.CurveID{
				tls.X25519,
				tls.CurveP256,
			},
		}
	}

	// Serve in the background; http.ErrServerClosed is the normal Stop outcome
	// and is not logged as an error.
	transport.waitGroup.Add(1)
	go func() {
		defer transport.waitGroup.Done()
		var err error
		if transport.config.TLSCertPath != "" && transport.config.TLSKeyPath != "" {
			err = transport.server.ListenAndServeTLS(transport.config.TLSCertPath, transport.config.TLSKeyPath)
		} else {
			err = transport.server.ListenAndServe()
		}
		if err != nil && err != http.ErrServerClosed {
			logging.Error("HTTP server error", logging.Fields{"error": err, "addr": transport.config.ListenAddr})
		}
	}()

	// Expire deduplicator entries once a minute until Stop cancels the context.
	transport.waitGroup.Add(1)
	go func() {
		defer transport.waitGroup.Done()
		ticker := time.NewTicker(time.Minute)
		defer ticker.Stop()
		for {
			select {
			case <-transport.cancelContext.Done():
				return
			case <-ticker.C:
				transport.deduplicator.Cleanup()
			}
		}
	}()

	return nil
}
|
|
|
|
// Stop cancels all background goroutines, sends a best-effort disconnect to
// every peer, shuts the HTTP server down with a 5-second grace period, and
// waits for the goroutines to drain before returning.
//
//	defer transport.Stop()
func (transport *Transport) Stop() error {
	transport.cancel()

	// Notify every connected peer before tearing the sockets down.
	transport.mutex.Lock()
	for _, connection := range transport.connections {
		connection.GracefulClose("server shutdown", DisconnectShutdown)
	}
	transport.mutex.Unlock()

	// server is nil when Stop is called without a prior Start.
	if transport.server != nil {
		shutdownContext, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		if err := transport.server.Shutdown(shutdownContext); err != nil {
			return &ProtocolError{Code: ErrCodeOperationFailed, Message: "server shutdown error: " + err.Error()}
		}
	}

	transport.waitGroup.Wait()
	return nil
}
|
|
|
|
// transport.OnMessage(worker.HandleMessage) // call before Start() to avoid races
|
|
func (transport *Transport) OnMessage(handler MessageHandler) {
|
|
transport.mutex.Lock()
|
|
defer transport.mutex.Unlock()
|
|
transport.handler = handler
|
|
}
|
|
|
|
// Connect dials peer.Address as a WebSocket client, performs the
// challenge-response handshake, registers the connection, and starts its read
// and keepalive loops. The returned connection is ready for Send.
//
//	connection, err := transport.Connect(peer)
//	if err != nil { return fmt.Errorf("connect failed: %w", err) }
func (transport *Transport) Connect(peer *Peer) (*PeerConnection, error) {
	// ws:// unless this transport is itself TLS-configured, in which case the
	// remote peer is assumed to serve wss:// too.
	scheme := "ws"
	if transport.config.TLSCertPath != "" {
		scheme = "wss"
	}
	peerURL := url.URL{Scheme: scheme, Host: peer.Address, Path: transport.config.WSPath}

	dialer := websocket.Dialer{
		HandshakeTimeout: 10 * time.Second,
	}
	conn, _, err := dialer.Dial(peerURL.String(), nil)
	if err != nil {
		return nil, &ProtocolError{Code: ErrCodeOperationFailed, Message: "failed to connect to peer: " + err.Error()}
	}

	connection := &PeerConnection{
		Peer:         peer,
		Conn:         conn,
		LastActivity: time.Now(),
		transport:    transport,
		rateLimiter:  NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill
	}

	// performHandshake fills in connection.SharedSecret (and the peer's
	// authoritative identity) on success.
	if err := transport.performHandshake(connection); err != nil {
		conn.Close()
		return nil, &ProtocolError{Code: ErrCodeOperationFailed, Message: "handshake failed: " + err.Error()}
	}

	transport.mutex.Lock()
	transport.connections[connection.Peer.ID] = connection
	transport.mutex.Unlock()

	logging.Debug("connected to peer", logging.Fields{"peer_id": connection.Peer.ID, "secret_len": len(connection.SharedSecret)})

	transport.registry.SetConnected(connection.Peer.ID, true)

	// Inbound traffic for this connection is serviced by readLoop.
	transport.waitGroup.Add(1)
	go transport.readLoop(connection)

	logging.Debug("started readLoop for peer", logging.Fields{"peer_id": connection.Peer.ID})

	// keepalive pings the peer and reaps the connection on idle timeout.
	transport.waitGroup.Add(1)
	go transport.keepalive(connection)

	return connection, nil
}
|
|
|
|
// if err := transport.Send(peer.ID, msg); err != nil { return err }
|
|
func (transport *Transport) Send(peerID string, msg *Message) error {
|
|
transport.mutex.RLock()
|
|
connection, exists := transport.connections[peerID]
|
|
transport.mutex.RUnlock()
|
|
|
|
if !exists {
|
|
return &ProtocolError{Code: ErrCodeNotFound, Message: "peer " + peerID + " not connected"}
|
|
}
|
|
|
|
return connection.Send(msg)
|
|
}
|
|
|
|
// if err := transport.Broadcast(msg); err != nil { logging.Warn("broadcast partial failure", ...) }
|
|
// Sender (msg.From) is excluded automatically to prevent echo.
|
|
func (transport *Transport) Broadcast(msg *Message) error {
|
|
transport.mutex.RLock()
|
|
connections := make([]*PeerConnection, 0, len(transport.connections))
|
|
for _, connection := range transport.connections {
|
|
// Exclude sender from broadcast to prevent echo (P2P-MED-6)
|
|
if connection.Peer != nil && connection.Peer.ID == msg.From {
|
|
continue
|
|
}
|
|
connections = append(connections, connection)
|
|
}
|
|
transport.mutex.RUnlock()
|
|
|
|
var lastErr error
|
|
for _, connection := range connections {
|
|
if err := connection.Send(msg); err != nil {
|
|
lastErr = err
|
|
}
|
|
}
|
|
return lastErr
|
|
}
|
|
|
|
// conn := transport.GetConnection(peer.ID)
|
|
// if conn == nil { return fmt.Errorf("peer not connected") }
|
|
func (transport *Transport) GetConnection(peerID string) *PeerConnection {
|
|
transport.mutex.RLock()
|
|
defer transport.mutex.RUnlock()
|
|
return transport.connections[peerID]
|
|
}
|
|
|
|
// handleWSUpgrade is the server side of the peer handshake, registered on
// config.WSPath by Start. It enforces the connection cap, upgrades the HTTP
// request, validates the client's handshake (protocol version, allowlist),
// derives the SMSG shared secret, replies with a handshake ack, and finally
// starts the read and keepalive loops for the new connection.
func (transport *Transport) handleWSUpgrade(responseWriter http.ResponseWriter, request *http.Request) {
	// Enforce MaxConns, counting connections still mid-handshake so a burst
	// of half-open sockets cannot exceed the cap.
	transport.mutex.RLock()
	currentConns := len(transport.connections)
	transport.mutex.RUnlock()
	pendingConns := int(transport.pendingConns.Load())

	totalConns := currentConns + pendingConns
	if totalConns >= transport.config.MaxConns {
		http.Error(responseWriter, "Too many connections", http.StatusServiceUnavailable)
		return
	}

	transport.pendingConns.Add(1)
	defer transport.pendingConns.Add(-1)

	conn, err := transport.upgrader.Upgrade(responseWriter, request, nil)
	if err != nil {
		return
	}

	// Cap frame size before reading anything to prevent memory exhaustion.
	maxSize := transport.config.MaxMessageSize
	if maxSize <= 0 {
		maxSize = DefaultMaxMessageSize
	}
	conn.SetReadLimit(maxSize)

	// Bound the whole handshake so slow or malicious clients cannot park here.
	handshakeTimeout := 10 * time.Second
	conn.SetReadDeadline(time.Now().Add(handshakeTimeout))

	// First frame must be the client's handshake.
	_, data, err := conn.ReadMessage()
	if err != nil {
		conn.Close()
		return
	}

	// The handshake is plaintext JSON: no shared secret exists yet, and it
	// carries the public key needed to derive one.
	var message Message
	if err := UnmarshalJSON(data, &message); err != nil {
		conn.Close()
		return
	}

	if message.Type != MsgHandshake {
		conn.Close()
		return
	}

	var payload HandshakePayload
	if err := message.ParsePayload(&payload); err != nil {
		conn.Close()
		return
	}

	// Reject incompatible protocol versions with an explanatory ack (P2P-MED-1).
	if !IsProtocolVersionSupported(payload.Version) {
		logging.Warn("peer connection rejected: incompatible protocol version", logging.Fields{
			"peer_version":       payload.Version,
			"supported_versions": SupportedProtocolVersions,
			"peer_id":            payload.Identity.ID,
		})
		identity := transport.node.GetIdentity()
		if identity != nil {
			rejectPayload := HandshakeAckPayload{
				Identity: *identity,
				Accepted: false,
				Reason:   "incompatible protocol version " + payload.Version + ", supported: " + joinVersions(SupportedProtocolVersions),
			}
			rejectMsg, _ := NewMessage(MsgHandshakeAck, identity.ID, payload.Identity.ID, rejectPayload)
			if rejectData, err := MarshalJSON(rejectMsg); err == nil {
				conn.WriteMessage(websocket.TextMessage, rejectData)
			}
		}
		conn.Close()
		return
	}

	// Derive the shared secret from the client's public key; it backs both
	// the challenge response below and SMSG encryption afterwards.
	sharedSecret, err := transport.node.DeriveSharedSecret(payload.Identity.PublicKey)
	if err != nil {
		conn.Close()
		return
	}

	// Allowlist check: the claimed identity and key must be authorized.
	if !transport.registry.IsPeerAllowed(payload.Identity.ID, payload.Identity.PublicKey) {
		logging.Warn("peer connection rejected: not in allowlist", logging.Fields{
			"peer_id":    payload.Identity.ID,
			"peer_name":  payload.Identity.Name,
			"public_key": safeKeyPrefix(payload.Identity.PublicKey),
		})
		// Best-effort rejection notice before closing.
		identity := transport.node.GetIdentity()
		if identity != nil {
			rejectPayload := HandshakeAckPayload{
				Identity: *identity,
				Accepted: false,
				Reason:   "peer not authorized",
			}
			rejectMsg, _ := NewMessage(MsgHandshakeAck, identity.ID, payload.Identity.ID, rejectPayload)
			if rejectData, err := MarshalJSON(rejectMsg); err == nil {
				conn.WriteMessage(websocket.TextMessage, rejectData)
			}
		}
		conn.Close()
		return
	}

	// Auto-register unknown peers: they already passed the allowlist check.
	peer := transport.registry.GetPeer(payload.Identity.ID)
	if peer == nil {
		peer = &Peer{
			ID:        payload.Identity.ID,
			Name:      payload.Identity.Name,
			PublicKey: payload.Identity.PublicKey,
			Role:      payload.Identity.Role,
			AddedAt:   time.Now(),
			Score:     50, // initial reputation score for a freshly registered peer
		}
		transport.registry.AddPeer(peer)
		logging.Info("auto-registered new peer", logging.Fields{
			"peer_id":   peer.ID,
			"peer_name": peer.Name,
		})
	}

	connection := &PeerConnection{
		Peer:         peer,
		Conn:         conn,
		SharedSecret: sharedSecret,
		LastActivity: time.Now(),
		transport:    transport,
		rateLimiter:  NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill
	}

	// Build the handshake acknowledgment carrying our identity.
	identity := transport.node.GetIdentity()
	if identity == nil {
		conn.Close()
		return
	}

	// Prove possession of the matching private key by signing the client's
	// challenge with the derived shared secret.
	var challengeResponse []byte
	if len(payload.Challenge) > 0 {
		challengeResponse = SignChallenge(payload.Challenge, sharedSecret)
	}

	ackPayload := HandshakeAckPayload{
		Identity:          *identity,
		ChallengeResponse: challengeResponse,
		Accepted:          true,
	}

	acknowledgementMessage, err := NewMessage(MsgHandshakeAck, identity.ID, peer.ID, ackPayload)
	if err != nil {
		conn.Close()
		return
	}

	acknowledgementData, err := MarshalJSON(acknowledgementMessage)
	if err != nil {
		conn.Close()
		return
	}

	if err := conn.WriteMessage(websocket.TextMessage, acknowledgementData); err != nil {
		conn.Close()
		return
	}

	transport.mutex.Lock()
	transport.connections[peer.ID] = connection
	transport.mutex.Unlock()

	transport.registry.SetConnected(peer.ID, true)

	// Connection established: service it exactly like an outbound one.
	transport.waitGroup.Add(1)
	go transport.readLoop(connection)

	transport.waitGroup.Add(1)
	go transport.keepalive(connection)
}
|
|
|
|
// performHandshake runs the client side of the handshake on a freshly dialed
// connection: it sends our identity plus a random challenge in plaintext,
// waits for the ack, verifies the server's challenge response against the
// derived shared secret, and stores that secret on pc for SMSG encryption.
// On any failure it returns a ProtocolError and pc must not be used.
func (transport *Transport) performHandshake(pc *PeerConnection) error {
	// Bound the whole exchange; deadlines are cleared again on exit.
	handshakeTimeout := 10 * time.Second
	pc.Conn.SetWriteDeadline(time.Now().Add(handshakeTimeout))
	pc.Conn.SetReadDeadline(time.Now().Add(handshakeTimeout))
	defer func() {
		pc.Conn.SetWriteDeadline(time.Time{})
		pc.Conn.SetReadDeadline(time.Time{})
	}()

	identity := transport.node.GetIdentity()
	if identity == nil {
		return &ProtocolError{Code: ErrCodeUnauthorized, Message: "node identity not initialized"}
	}

	// The challenge forces the server to prove it holds the private key
	// matching the public key it advertises.
	challenge, err := GenerateChallenge()
	if err != nil {
		return &ProtocolError{Code: ErrCodeOperationFailed, Message: "generate challenge: " + err.Error()}
	}

	payload := HandshakePayload{
		Identity:  *identity,
		Challenge: challenge,
		Version:   ProtocolVersion,
	}

	msg, err := NewMessage(MsgHandshake, identity.ID, pc.Peer.ID, payload)
	if err != nil {
		return &ProtocolError{Code: ErrCodeOperationFailed, Message: "create handshake message: " + err.Error()}
	}

	// The first message goes out unencrypted: the peer needs our public key
	// before either side can derive the shared secret.
	data, err := MarshalJSON(msg)
	if err != nil {
		return &ProtocolError{Code: ErrCodeOperationFailed, Message: "marshal handshake message: " + err.Error()}
	}

	if err := pc.Conn.WriteMessage(websocket.TextMessage, data); err != nil {
		return &ProtocolError{Code: ErrCodeOperationFailed, Message: "send handshake: " + err.Error()}
	}

	// The server replies with a handshake_ack (possibly a rejection).
	_, acknowledgementData, err := pc.Conn.ReadMessage()
	if err != nil {
		return &ProtocolError{Code: ErrCodeOperationFailed, Message: "read handshake ack: " + err.Error()}
	}

	var acknowledgementMessage Message
	if err := UnmarshalJSON(acknowledgementData, &acknowledgementMessage); err != nil {
		return &ProtocolError{Code: ErrCodeInvalidMessage, Message: "unmarshal handshake ack: " + err.Error()}
	}

	if acknowledgementMessage.Type != MsgHandshakeAck {
		return &ProtocolError{Code: ErrCodeInvalidMessage, Message: "expected handshake_ack, got " + string(acknowledgementMessage.Type)}
	}

	var ackPayload HandshakeAckPayload
	if err := acknowledgementMessage.ParsePayload(&ackPayload); err != nil {
		return &ProtocolError{Code: ErrCodeInvalidMessage, Message: "parse handshake ack payload: " + err.Error()}
	}

	if !ackPayload.Accepted {
		return &ProtocolError{Code: ErrCodeUnauthorized, Message: "handshake rejected: " + ackPayload.Reason}
	}

	// Adopt the identity the server reported; the dialer may only have known
	// an address (and a provisional ID) beforehand.
	pc.Peer.ID = ackPayload.Identity.ID
	pc.Peer.PublicKey = ackPayload.Identity.PublicKey
	pc.Peer.Name = ackPayload.Identity.Name
	pc.Peer.Role = ackPayload.Identity.Role

	// Derive the shared secret from the server's public key so the challenge
	// response can be verified.
	sharedSecret, err := transport.node.DeriveSharedSecret(pc.Peer.PublicKey)
	if err != nil {
		return &ProtocolError{Code: ErrCodeOperationFailed, Message: "derive shared secret for challenge verification: " + err.Error()}
	}

	// A missing or invalid challenge response means the server could not
	// prove key possession; treat both as authentication failures.
	if len(ackPayload.ChallengeResponse) == 0 {
		return &ProtocolError{Code: ErrCodeUnauthorized, Message: "server did not provide challenge response"}
	}
	if !VerifyChallenge(challenge, ackPayload.ChallengeResponse, sharedSecret) {
		return &ProtocolError{Code: ErrCodeUnauthorized, Message: "challenge response verification failed: server may not have matching private key"}
	}

	pc.SharedSecret = sharedSecret

	if err := transport.registry.UpdatePeer(pc.Peer); err != nil {
		// Update fails when the registry has no entry under the old ID;
		// register the peer under its authoritative identity instead.
		transport.registry.AddPeer(pc.Peer)
	}

	logging.Debug("handshake completed with challenge-response verification", logging.Fields{
		"peer_id":   pc.Peer.ID,
		"peer_name": pc.Peer.Name,
	})

	return nil
}
|
|
|
|
// readLoop services inbound traffic for one connection until the transport is
// stopped or a read fails: each frame is rate-limited, decrypted (SMSG),
// deduplicated, and dispatched to the registered handler. On exit the
// connection is unregistered and closed. Started by Connect and
// handleWSUpgrade after a successful handshake.
func (transport *Transport) readLoop(pc *PeerConnection) {
	defer transport.waitGroup.Done()
	defer transport.removeConnection(pc)

	// Cap frame size to prevent memory-exhaustion attacks.
	maxSize := transport.config.MaxMessageSize
	if maxSize <= 0 {
		maxSize = DefaultMaxMessageSize
	}
	pc.Conn.SetReadLimit(maxSize)

	for {
		// Exit promptly once Stop cancels the transport context.
		select {
		case <-transport.cancelContext.Done():
			return
		default:
		}

		// A per-read deadline keeps the loop from blocking forever on an
		// unresponsive connection; keepalive traffic should arrive within
		// PingInterval+PongTimeout.
		readDeadline := transport.config.PingInterval + transport.config.PongTimeout
		if err := pc.Conn.SetReadDeadline(time.Now().Add(readDeadline)); err != nil {
			logging.Error("SetReadDeadline error", logging.Fields{"peer_id": pc.Peer.ID, "error": err})
			return
		}

		_, data, err := pc.Conn.ReadMessage()
		if err != nil {
			logging.Debug("read error from peer", logging.Fields{"peer_id": pc.Peer.ID, "error": err})
			return
		}

		// Any inbound frame counts as activity for keepalive's idle check.
		pc.LastActivity = time.Now()

		// Per-peer token bucket; over-limit messages are dropped, not queued.
		if pc.rateLimiter != nil && !pc.rateLimiter.Allow() {
			logging.Warn("peer rate limited, dropping message", logging.Fields{"peer_id": pc.Peer.ID})
			continue // Drop message from rate-limited peer
		}

		// Decrypt with the handshake-derived shared secret; undecryptable
		// frames are skipped rather than killing the connection.
		msg, err := transport.decryptMessage(data, pc.SharedSecret)
		if err != nil {
			logging.Debug("decrypt error from peer", logging.Fields{"peer_id": pc.Peer.ID, "error": err, "data_len": len(data)})
			continue // Skip invalid messages
		}

		// Drop replayed message IDs (prevents amplification attacks).
		if transport.deduplicator.IsDuplicate(msg.ID) {
			logging.Debug("dropping duplicate message", logging.Fields{"msg_id": msg.ID, "peer_id": pc.Peer.ID})
			continue
		}
		transport.deduplicator.Mark(msg.ID)

		// Sample debug logging on the hot path: 1 in debugLogInterval messages.
		if debugLogCounter.Add(1)%debugLogInterval == 0 {
			logging.Debug("received message from peer", logging.Fields{"type": msg.Type, "peer_id": pc.Peer.ID, "reply_to": msg.ReplyTo, "sample": "1/100"})
		}

		// Snapshot the handler under the lock; OnMessage may replace it
		// concurrently.
		transport.mutex.RLock()
		handler := transport.handler
		transport.mutex.RUnlock()
		if handler != nil {
			handler(pc, msg)
		}
	}
}
|
|
|
|
// go transport.keepalive(pc) // started by Connect and handleWSUpgrade; pings every transport.config.PingInterval
|
|
func (transport *Transport) keepalive(pc *PeerConnection) {
|
|
defer transport.waitGroup.Done()
|
|
|
|
ticker := time.NewTicker(transport.config.PingInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-transport.cancelContext.Done():
|
|
return
|
|
case <-ticker.C:
|
|
if time.Since(pc.LastActivity) > transport.config.PingInterval+transport.config.PongTimeout {
|
|
transport.removeConnection(pc)
|
|
return
|
|
}
|
|
|
|
// Send ping
|
|
identity := transport.node.GetIdentity()
|
|
pingMsg, err := NewMessage(MsgPing, identity.ID, pc.Peer.ID, PingPayload{
|
|
SentAt: time.Now().UnixMilli(),
|
|
})
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
if err := pc.Send(pingMsg); err != nil {
|
|
transport.removeConnection(pc)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// removeConnection unregisters pc, marks the peer disconnected in the
// registry, and closes the socket. Called from readLoop on read failure and
// from keepalive on idle/ping failure; the map delete and Close (closeOnce)
// are both idempotent, so concurrent calls are harmless.
func (transport *Transport) removeConnection(pc *PeerConnection) {
	transport.mutex.Lock()
	delete(transport.connections, pc.Peer.ID)
	transport.mutex.Unlock()

	transport.registry.SetConnected(pc.Peer.ID, false)
	pc.Close()
}
|
|
|
|
// if err := conn.Send(msg); err != nil { logging.Error("send failed", ...) }
|
|
func (pc *PeerConnection) Send(msg *Message) error {
|
|
pc.writeMutex.Lock()
|
|
defer pc.writeMutex.Unlock()
|
|
|
|
// Encrypt message using SMSG
|
|
data, err := pc.transport.encryptMessage(msg, pc.SharedSecret)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Set write deadline to prevent blocking forever
|
|
if err := pc.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
|
return &ProtocolError{Code: ErrCodeOperationFailed, Message: "failed to set write deadline: " + err.Error()}
|
|
}
|
|
defer pc.Conn.SetWriteDeadline(time.Time{}) // Reset deadline after send
|
|
|
|
return pc.Conn.WriteMessage(websocket.BinaryMessage, data)
|
|
}
|
|
|
|
// if err := conn.Close(); err != nil { logging.Warn("close failed", logging.Fields{"error": err}) }
|
|
func (pc *PeerConnection) Close() error {
|
|
var err error
|
|
pc.closeOnce.Do(func() {
|
|
err = pc.Conn.Close()
|
|
})
|
|
return err
|
|
}
|
|
|
|
// DisconnectPayload is the body of a MsgDisconnect, telling the remote side
// why the connection is closing.
//
//	msg, _ := NewMessage(MsgDisconnect, from, to, DisconnectPayload{Reason: "shutdown", Code: DisconnectShutdown})
type DisconnectPayload struct {
	Reason string `json:"reason"`
	Code   int    `json:"code"` // optional disconnect code; see the Disconnect* constants
}

// Disconnect codes carried in DisconnectPayload.Code.
const (
	DisconnectNormal      = 1000 // normal closure
	DisconnectGoingAway   = 1001 // server/peer going away
	DisconnectProtocolErr = 1002 // protocol error
	DisconnectTimeout     = 1003 // idle timeout
	DisconnectShutdown    = 1004 // server shutdown
)
|
|
|
|
// GracefulClose sends a best-effort MsgDisconnect with the given reason and
// code, then closes the socket. Like Close, it runs at most once per
// connection (closeOnce); later calls return nil.
//
//	pc.GracefulClose("server shutdown", DisconnectShutdown)
//	pc.GracefulClose("idle timeout", DisconnectTimeout)
func (pc *PeerConnection) GracefulClose(reason string, code int) error {
	var err error
	pc.closeOnce.Do(func() {
		// Only attempt the disconnect notice when the handshake completed
		// (SharedSecret set) — Send could not encrypt otherwise.
		if pc.transport != nil && pc.SharedSecret != nil {
			identity := pc.transport.node.GetIdentity()
			if identity != nil {
				payload := DisconnectPayload{
					Reason: reason,
					Code:   code,
				}
				msg, msgErr := NewMessage(MsgDisconnect, identity.ID, pc.Peer.ID, payload)
				if msgErr == nil {
					// Short deadline so shutdown is not held up by a stuck
					// peer. NOTE(review): Send sets its own 10s write
					// deadline internally, overriding this 2s one — confirm
					// whether the shorter deadline is actually intended to
					// take effect.
					pc.Conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
					pc.Send(msg) // best effort: error deliberately ignored
				}
			}
		}

		err = pc.Conn.Close()
	})
	return err
}
|
|
|
|
// data, err := transport.encryptMessage(msg, pc.SharedSecret)
|
|
// if err != nil { return err }
|
|
// pc.Conn.WriteMessage(websocket.BinaryMessage, data)
|
|
func (transport *Transport) encryptMessage(msg *Message, sharedSecret []byte) ([]byte, error) {
|
|
msgData, err := MarshalJSON(msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
smsgMsg := smsg.NewMessage(string(msgData))
|
|
password := base64.StdEncoding.EncodeToString(sharedSecret)
|
|
encrypted, err := smsg.Encrypt(smsgMsg, password)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return encrypted, nil
|
|
}
|
|
|
|
// msg, err := transport.decryptMessage(data, pc.SharedSecret)
|
|
// if err != nil { continue } // skip invalid/unauthenticated messages
|
|
func (transport *Transport) decryptMessage(data []byte, sharedSecret []byte) (*Message, error) {
|
|
// Decrypt using shared secret as password
|
|
password := base64.StdEncoding.EncodeToString(sharedSecret)
|
|
decryptedMessage, err := smsg.Decrypt(data, password)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse message from JSON
|
|
var message Message
|
|
if err := UnmarshalJSON([]byte(decryptedMessage.Body), &message); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &message, nil
|
|
}
|
|
|
|
// connectionCount := transport.ConnectedPeers()
|
|
// if connectionCount == 0 { return ErrNoPeers }
|
|
func (transport *Transport) ConnectedPeers() int {
|
|
transport.mutex.RLock()
|
|
defer transport.mutex.RUnlock()
|
|
return len(transport.connections)
|
|
}
|