Critical fixes (6): - CRIT-001/002: Add safeKeyPrefix() to prevent panic on short public keys - CRIT-003/004: Add sync.Once pattern for thread-safe singleton initialization - CRIT-005: Harden console ANSI parser with length limits and stricter validation - CRIT-006: Add client-side input validation for profile creation High priority fixes (10): - HIGH-001: Add secondary timeout in TTMiner to prevent goroutine leak - HIGH-002: Verify atomic flag prevents timeout middleware race - HIGH-004: Add LimitReader (100MB) to prevent decompression bombs - HIGH-005: Add Lines parameter validation (max 10000) in worker - HIGH-006: Add TLS 1.2+ config with secure cipher suites - HIGH-007: Add pool URL format and wallet length validation - HIGH-008: Add SIGHUP handling and force cleanup on Stop() failure - HIGH-009: Add WebSocket message size limit and event type validation - HIGH-010: Refactor to use takeUntil(destroy$) for observable cleanup - HIGH-011: Add sanitizeErrorDetails() with debug mode control 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
934 lines
25 KiB
Go
934 lines
25 KiB
Go
package node
|
|
|
|
import (
	"context"
	"crypto/tls"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"sync"
	"sync/atomic"
	"time"

	"github.com/Snider/Borg/pkg/smsg"
	"github.com/Snider/Mining/pkg/logging"
	"github.com/gorilla/websocket"
)
|
|
|
|
// debugLogInterval is the sampling period for debug logs on hot paths:
// one message out of every debugLogInterval is logged.
const debugLogInterval = 100

// debugLogCounter counts messages on hot paths so that debug logging can be
// sampled instead of being emitted for every message.
var debugLogCounter atomic.Int64
|
|
|
|
// DefaultMaxMessageSize is the message size cap, in bytes, used when a
// configuration does not provide a positive MaxMessageSize.
const DefaultMaxMessageSize int64 = 1 << 20 // 1MB

// TransportConfig holds the tunable settings of the WebSocket transport.
type TransportConfig struct {
	// ListenAddr is the address the HTTP server binds to (":9091" by default).
	ListenAddr string
	// WSPath is the URL path of the WebSocket endpoint ("/ws" by default).
	WSPath string
	// TLSCertPath and TLSKeyPath are optional; TLS (wss://) is served only
	// when both are set.
	TLSCertPath string
	TLSKeyPath  string
	// MaxConns caps the number of concurrent connections.
	MaxConns int
	// MaxMessageSize caps a single message in bytes; values <= 0 fall back to
	// DefaultMaxMessageSize.
	MaxMessageSize int64
	// PingInterval is the period between keepalive pings.
	PingInterval time.Duration
	// PongTimeout is the grace period allowed for a reply to a ping.
	PongTimeout time.Duration
}

// DefaultTransportConfig returns a configuration populated with the package
// defaults. TLS paths are left empty, so the result serves plain ws://.
func DefaultTransportConfig() TransportConfig {
	var cfg TransportConfig
	cfg.ListenAddr = ":9091"
	cfg.WSPath = "/ws"
	cfg.MaxConns = 100
	cfg.MaxMessageSize = DefaultMaxMessageSize
	cfg.PingInterval = 30 * time.Second
	cfg.PongTimeout = 10 * time.Second
	return cfg
}
|
|
|
|
// MessageHandler processes incoming messages.
|
|
type MessageHandler func(conn *PeerConnection, msg *Message)
|
|
|
|
// MessageDeduplicator remembers recently seen message IDs so that the same
// message is not processed twice. Entries are considered valid for ttl after
// they were marked; Cleanup physically removes expired entries.
//
// IsDuplicate and Mark each take the lock separately, so a
// "check then mark" sequence performed by a caller is not atomic.
type MessageDeduplicator struct {
	seen map[string]time.Time // message ID -> time it was marked
	mu   sync.RWMutex
	ttl  time.Duration
}

// NewMessageDeduplicator creates an empty deduplicator whose entries expire
// after ttl.
func NewMessageDeduplicator(ttl time.Duration) *MessageDeduplicator {
	return &MessageDeduplicator{
		seen: make(map[string]time.Time),
		ttl:  ttl,
	}
}

// IsDuplicate reports whether msgID was marked within the last ttl.
//
// An entry that has outlived ttl but has not yet been swept by Cleanup is
// treated as unseen, using the same expiry rule as Cleanup, so the effective
// window does not depend on how often Cleanup runs.
func (d *MessageDeduplicator) IsDuplicate(msgID string) bool {
	d.mu.RLock()
	markedAt, found := d.seen[msgID]
	d.mu.RUnlock()

	return found && time.Since(markedAt) <= d.ttl
}

// Mark records msgID as seen at the current time, refreshing any previous
// entry for the same ID.
func (d *MessageDeduplicator) Mark(msgID string) {
	d.mu.Lock()
	d.seen[msgID] = time.Now()
	d.mu.Unlock()
}

// Cleanup deletes every entry that was marked more than ttl ago.
func (d *MessageDeduplicator) Cleanup() {
	d.mu.Lock()
	defer d.mu.Unlock()

	now := time.Now()
	for id, markedAt := range d.seen {
		if now.Sub(markedAt) > d.ttl {
			delete(d.seen, id)
		}
	}
}
|
|
|
|
// Transport manages WebSocket connections with SMSG encryption.
|
|
type Transport struct {
|
|
config TransportConfig
|
|
server *http.Server
|
|
upgrader websocket.Upgrader
|
|
conns map[string]*PeerConnection // peer ID -> connection
|
|
pendingConns atomic.Int32 // tracks connections during handshake
|
|
node *NodeManager
|
|
registry *PeerRegistry
|
|
handler MessageHandler
|
|
dedup *MessageDeduplicator // Message deduplication
|
|
mu sync.RWMutex
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
// PeerRateLimiter is a token-bucket rate limiter intended for one peer.
// The bucket holds at most maxTokens and is refilled at refillRate tokens per
// second; each allowed message consumes one token.
type PeerRateLimiter struct {
	tokens     int
	maxTokens  int
	refillRate int       // tokens per second
	lastRefill time.Time // instant up to which refill credit has been granted
	mu         sync.Mutex
}

// NewPeerRateLimiter creates a limiter that starts with a full bucket of
// maxTokens and refills at refillRate tokens per second.
func NewPeerRateLimiter(maxTokens, refillRate int) *PeerRateLimiter {
	return &PeerRateLimiter{
		tokens:     maxTokens,
		maxTokens:  maxTokens,
		refillRate: refillRate,
		lastRefill: time.Now(),
	}
}

// Allow reports whether a message may be processed now and, if so, consumes
// one token.
//
// Refill credit is computed from the exact elapsed time rather than from
// whole seconds, so tokens become available smoothly (e.g. one token every
// 20ms at 50 tokens/s). Only the time that produced whole tokens is consumed
// from lastRefill; the fractional remainder carries over to the next call.
// A non-positive refillRate never refills.
func (r *PeerRateLimiter) Allow() bool {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.refillRate > 0 {
		now := time.Now()
		earned := now.Sub(r.lastRefill).Seconds() * float64(r.refillRate)
		// Clamp before converting so a very long idle period cannot overflow int.
		if whole := int(min(earned, float64(r.maxTokens))); whole > 0 {
			if r.tokens+whole >= r.maxTokens {
				// Bucket is full: surplus credit is discarded.
				r.tokens = r.maxTokens
				r.lastRefill = now
			} else {
				r.tokens += whole
				consumed := time.Duration(float64(whole) / float64(r.refillRate) * float64(time.Second))
				r.lastRefill = r.lastRefill.Add(consumed)
			}
		}
	}

	if r.tokens > 0 {
		r.tokens--
		return true
	}
	return false
}
|
|
|
|
// PeerConnection represents an active connection to a peer.
|
|
type PeerConnection struct {
|
|
Peer *Peer
|
|
Conn *websocket.Conn
|
|
SharedSecret []byte // Derived via X25519 ECDH, used for SMSG
|
|
LastActivity time.Time
|
|
writeMu sync.Mutex // Serialize WebSocket writes
|
|
transport *Transport
|
|
closeOnce sync.Once // Ensure Close() is only called once
|
|
rateLimiter *PeerRateLimiter // Per-peer message rate limiting
|
|
}
|
|
|
|
// NewTransport creates a new WebSocket transport.
|
|
func NewTransport(node *NodeManager, registry *PeerRegistry, config TransportConfig) *Transport {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
return &Transport{
|
|
config: config,
|
|
node: node,
|
|
registry: registry,
|
|
conns: make(map[string]*PeerConnection),
|
|
dedup: NewMessageDeduplicator(5 * time.Minute), // 5 minute TTL for dedup
|
|
upgrader: websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
// Allow local connections only for security
|
|
origin := r.Header.Get("Origin")
|
|
if origin == "" {
|
|
return true // No origin header (non-browser client)
|
|
}
|
|
// Allow localhost and 127.0.0.1 origins
|
|
u, err := url.Parse(origin)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
host := u.Hostname()
|
|
return host == "localhost" || host == "127.0.0.1" || host == "::1"
|
|
},
|
|
},
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
}
|
|
|
|
// Start begins listening for incoming connections.
|
|
func (t *Transport) Start() error {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc(t.config.WSPath, t.handleWSUpgrade)
|
|
|
|
t.server = &http.Server{
|
|
Addr: t.config.ListenAddr,
|
|
Handler: mux,
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: 30 * time.Second,
|
|
IdleTimeout: 60 * time.Second,
|
|
ReadHeaderTimeout: 10 * time.Second,
|
|
}
|
|
|
|
// Apply TLS hardening if TLS is enabled
|
|
if t.config.TLSCertPath != "" && t.config.TLSKeyPath != "" {
|
|
t.server.TLSConfig = &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
CipherSuites: []uint16{
|
|
// TLS 1.3 ciphers (automatically used when available)
|
|
tls.TLS_AES_128_GCM_SHA256,
|
|
tls.TLS_AES_256_GCM_SHA384,
|
|
tls.TLS_CHACHA20_POLY1305_SHA256,
|
|
// TLS 1.2 secure ciphers
|
|
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
|
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
|
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
|
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
|
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
|
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
|
},
|
|
CurvePreferences: []tls.CurveID{
|
|
tls.X25519,
|
|
tls.CurveP256,
|
|
},
|
|
}
|
|
}
|
|
|
|
t.wg.Add(1)
|
|
go func() {
|
|
defer t.wg.Done()
|
|
var err error
|
|
if t.config.TLSCertPath != "" && t.config.TLSKeyPath != "" {
|
|
err = t.server.ListenAndServeTLS(t.config.TLSCertPath, t.config.TLSKeyPath)
|
|
} else {
|
|
err = t.server.ListenAndServe()
|
|
}
|
|
if err != nil && err != http.ErrServerClosed {
|
|
logging.Error("HTTP server error", logging.Fields{"error": err, "addr": t.config.ListenAddr})
|
|
}
|
|
}()
|
|
|
|
// Start message deduplication cleanup goroutine
|
|
t.wg.Add(1)
|
|
go func() {
|
|
defer t.wg.Done()
|
|
ticker := time.NewTicker(time.Minute)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-t.ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
t.dedup.Cleanup()
|
|
}
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Stop gracefully shuts down the transport.
|
|
func (t *Transport) Stop() error {
|
|
t.cancel()
|
|
|
|
// Gracefully close all connections with shutdown message
|
|
t.mu.Lock()
|
|
for _, pc := range t.conns {
|
|
pc.GracefulClose("server shutdown", DisconnectShutdown)
|
|
}
|
|
t.mu.Unlock()
|
|
|
|
// Shutdown HTTP server if it was started
|
|
if t.server != nil {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
if err := t.server.Shutdown(ctx); err != nil {
|
|
return fmt.Errorf("server shutdown error: %w", err)
|
|
}
|
|
}
|
|
|
|
t.wg.Wait()
|
|
return nil
|
|
}
|
|
|
|
// OnMessage sets the handler for incoming messages.
|
|
// Must be called before Start() to avoid races.
|
|
func (t *Transport) OnMessage(handler MessageHandler) {
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
t.handler = handler
|
|
}
|
|
|
|
// Connect establishes a connection to a peer.
|
|
func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) {
|
|
// Build WebSocket URL
|
|
scheme := "ws"
|
|
if t.config.TLSCertPath != "" {
|
|
scheme = "wss"
|
|
}
|
|
u := url.URL{Scheme: scheme, Host: peer.Address, Path: t.config.WSPath}
|
|
|
|
// Dial the peer with timeout to prevent hanging on unresponsive peers
|
|
dialer := websocket.Dialer{
|
|
HandshakeTimeout: 10 * time.Second,
|
|
}
|
|
conn, _, err := dialer.Dial(u.String(), nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to peer: %w", err)
|
|
}
|
|
|
|
pc := &PeerConnection{
|
|
Peer: peer,
|
|
Conn: conn,
|
|
LastActivity: time.Now(),
|
|
transport: t,
|
|
rateLimiter: NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill
|
|
}
|
|
|
|
// Perform handshake with challenge-response authentication
|
|
// This also derives and stores the shared secret in pc.SharedSecret
|
|
if err := t.performHandshake(pc); err != nil {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("handshake failed: %w", err)
|
|
}
|
|
|
|
// Store connection using the real peer ID from handshake
|
|
t.mu.Lock()
|
|
t.conns[pc.Peer.ID] = pc
|
|
t.mu.Unlock()
|
|
|
|
logging.Debug("connected to peer", logging.Fields{"peer_id": pc.Peer.ID, "secret_len": len(pc.SharedSecret)})
|
|
|
|
// Update registry
|
|
t.registry.SetConnected(pc.Peer.ID, true)
|
|
|
|
// Start read loop
|
|
t.wg.Add(1)
|
|
go t.readLoop(pc)
|
|
|
|
logging.Debug("started readLoop for peer", logging.Fields{"peer_id": pc.Peer.ID})
|
|
|
|
// Start keepalive
|
|
t.wg.Add(1)
|
|
go t.keepalive(pc)
|
|
|
|
return pc, nil
|
|
}
|
|
|
|
// Send sends a message to a specific peer.
|
|
func (t *Transport) Send(peerID string, msg *Message) error {
|
|
t.mu.RLock()
|
|
pc, exists := t.conns[peerID]
|
|
t.mu.RUnlock()
|
|
|
|
if !exists {
|
|
return fmt.Errorf("peer %s not connected", peerID)
|
|
}
|
|
|
|
return pc.Send(msg)
|
|
}
|
|
|
|
// Broadcast sends a message to all connected peers except the sender.
|
|
// The sender is identified by msg.From and excluded to prevent echo.
|
|
func (t *Transport) Broadcast(msg *Message) error {
|
|
t.mu.RLock()
|
|
conns := make([]*PeerConnection, 0, len(t.conns))
|
|
for _, pc := range t.conns {
|
|
// Exclude sender from broadcast to prevent echo (P2P-MED-6)
|
|
if pc.Peer != nil && pc.Peer.ID == msg.From {
|
|
continue
|
|
}
|
|
conns = append(conns, pc)
|
|
}
|
|
t.mu.RUnlock()
|
|
|
|
var lastErr error
|
|
for _, pc := range conns {
|
|
if err := pc.Send(msg); err != nil {
|
|
lastErr = err
|
|
}
|
|
}
|
|
return lastErr
|
|
}
|
|
|
|
// GetConnection returns an active connection to a peer.
|
|
func (t *Transport) GetConnection(peerID string) *PeerConnection {
|
|
t.mu.RLock()
|
|
defer t.mu.RUnlock()
|
|
return t.conns[peerID]
|
|
}
|
|
|
|
// handleWSUpgrade handles incoming WebSocket connections.
|
|
func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
|
|
// Enforce MaxConns limit (including pending connections during handshake)
|
|
t.mu.RLock()
|
|
currentConns := len(t.conns)
|
|
t.mu.RUnlock()
|
|
pendingConns := int(t.pendingConns.Load())
|
|
|
|
totalConns := currentConns + pendingConns
|
|
if totalConns >= t.config.MaxConns {
|
|
http.Error(w, "Too many connections", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
// Track this connection as pending during handshake
|
|
t.pendingConns.Add(1)
|
|
defer t.pendingConns.Add(-1)
|
|
|
|
conn, err := t.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Apply message size limit during handshake to prevent memory exhaustion
|
|
maxSize := t.config.MaxMessageSize
|
|
if maxSize <= 0 {
|
|
maxSize = DefaultMaxMessageSize
|
|
}
|
|
conn.SetReadLimit(maxSize)
|
|
|
|
// Set handshake timeout to prevent slow/malicious clients from blocking
|
|
handshakeTimeout := 10 * time.Second
|
|
conn.SetReadDeadline(time.Now().Add(handshakeTimeout))
|
|
|
|
// Wait for handshake from client
|
|
_, data, err := conn.ReadMessage()
|
|
if err != nil {
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
// Decode handshake message (not encrypted yet, contains public key)
|
|
var msg Message
|
|
if err := json.Unmarshal(data, &msg); err != nil {
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
if msg.Type != MsgHandshake {
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
var payload HandshakePayload
|
|
if err := msg.ParsePayload(&payload); err != nil {
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
// Check protocol version compatibility (P2P-MED-1)
|
|
if !IsProtocolVersionSupported(payload.Version) {
|
|
logging.Warn("peer connection rejected: incompatible protocol version", logging.Fields{
|
|
"peer_version": payload.Version,
|
|
"supported_versions": SupportedProtocolVersions,
|
|
"peer_id": payload.Identity.ID,
|
|
})
|
|
identity := t.node.GetIdentity()
|
|
if identity != nil {
|
|
rejectPayload := HandshakeAckPayload{
|
|
Identity: *identity,
|
|
Accepted: false,
|
|
Reason: fmt.Sprintf("incompatible protocol version %s, supported: %v", payload.Version, SupportedProtocolVersions),
|
|
}
|
|
rejectMsg, _ := NewMessage(MsgHandshakeAck, identity.ID, payload.Identity.ID, rejectPayload)
|
|
if rejectData, err := MarshalJSON(rejectMsg); err == nil {
|
|
conn.WriteMessage(websocket.TextMessage, rejectData)
|
|
}
|
|
}
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
// Derive shared secret from peer's public key
|
|
sharedSecret, err := t.node.DeriveSharedSecret(payload.Identity.PublicKey)
|
|
if err != nil {
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
// Check if peer is allowed to connect (allowlist check)
|
|
if !t.registry.IsPeerAllowed(payload.Identity.ID, payload.Identity.PublicKey) {
|
|
logging.Warn("peer connection rejected: not in allowlist", logging.Fields{
|
|
"peer_id": payload.Identity.ID,
|
|
"peer_name": payload.Identity.Name,
|
|
"public_key": safeKeyPrefix(payload.Identity.PublicKey),
|
|
})
|
|
// Send rejection before closing
|
|
identity := t.node.GetIdentity()
|
|
if identity != nil {
|
|
rejectPayload := HandshakeAckPayload{
|
|
Identity: *identity,
|
|
Accepted: false,
|
|
Reason: "peer not authorized",
|
|
}
|
|
rejectMsg, _ := NewMessage(MsgHandshakeAck, identity.ID, payload.Identity.ID, rejectPayload)
|
|
if rejectData, err := MarshalJSON(rejectMsg); err == nil {
|
|
conn.WriteMessage(websocket.TextMessage, rejectData)
|
|
}
|
|
}
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
// Create peer if not exists (only if auth passed)
|
|
peer := t.registry.GetPeer(payload.Identity.ID)
|
|
if peer == nil {
|
|
// Auto-register the peer since they passed allowlist check
|
|
peer = &Peer{
|
|
ID: payload.Identity.ID,
|
|
Name: payload.Identity.Name,
|
|
PublicKey: payload.Identity.PublicKey,
|
|
Role: payload.Identity.Role,
|
|
AddedAt: time.Now(),
|
|
Score: 50,
|
|
}
|
|
t.registry.AddPeer(peer)
|
|
logging.Info("auto-registered new peer", logging.Fields{
|
|
"peer_id": peer.ID,
|
|
"peer_name": peer.Name,
|
|
})
|
|
}
|
|
|
|
pc := &PeerConnection{
|
|
Peer: peer,
|
|
Conn: conn,
|
|
SharedSecret: sharedSecret,
|
|
LastActivity: time.Now(),
|
|
transport: t,
|
|
rateLimiter: NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill
|
|
}
|
|
|
|
// Send handshake acknowledgment
|
|
identity := t.node.GetIdentity()
|
|
if identity == nil {
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
// Sign the client's challenge to prove we have the matching private key
|
|
var challengeResponse []byte
|
|
if len(payload.Challenge) > 0 {
|
|
challengeResponse = SignChallenge(payload.Challenge, sharedSecret)
|
|
}
|
|
|
|
ackPayload := HandshakeAckPayload{
|
|
Identity: *identity,
|
|
ChallengeResponse: challengeResponse,
|
|
Accepted: true,
|
|
}
|
|
|
|
ackMsg, err := NewMessage(MsgHandshakeAck, identity.ID, peer.ID, ackPayload)
|
|
if err != nil {
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
// First ack is unencrypted (peer needs to know our public key)
|
|
ackData, err := MarshalJSON(ackMsg)
|
|
if err != nil {
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
if err := conn.WriteMessage(websocket.TextMessage, ackData); err != nil {
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
// Store connection
|
|
t.mu.Lock()
|
|
t.conns[peer.ID] = pc
|
|
t.mu.Unlock()
|
|
|
|
// Update registry
|
|
t.registry.SetConnected(peer.ID, true)
|
|
|
|
// Start read loop
|
|
t.wg.Add(1)
|
|
go t.readLoop(pc)
|
|
|
|
// Start keepalive
|
|
t.wg.Add(1)
|
|
go t.keepalive(pc)
|
|
}
|
|
|
|
// performHandshake initiates handshake with a peer.
|
|
func (t *Transport) performHandshake(pc *PeerConnection) error {
|
|
// Set handshake timeout
|
|
handshakeTimeout := 10 * time.Second
|
|
pc.Conn.SetWriteDeadline(time.Now().Add(handshakeTimeout))
|
|
pc.Conn.SetReadDeadline(time.Now().Add(handshakeTimeout))
|
|
defer func() {
|
|
// Reset deadlines after handshake
|
|
pc.Conn.SetWriteDeadline(time.Time{})
|
|
pc.Conn.SetReadDeadline(time.Time{})
|
|
}()
|
|
|
|
identity := t.node.GetIdentity()
|
|
if identity == nil {
|
|
return fmt.Errorf("node identity not initialized")
|
|
}
|
|
|
|
// Generate challenge for the server to prove it has the matching private key
|
|
challenge, err := GenerateChallenge()
|
|
if err != nil {
|
|
return fmt.Errorf("generate challenge: %w", err)
|
|
}
|
|
|
|
payload := HandshakePayload{
|
|
Identity: *identity,
|
|
Challenge: challenge,
|
|
Version: ProtocolVersion,
|
|
}
|
|
|
|
msg, err := NewMessage(MsgHandshake, identity.ID, pc.Peer.ID, payload)
|
|
if err != nil {
|
|
return fmt.Errorf("create handshake message: %w", err)
|
|
}
|
|
|
|
// First message is unencrypted (peer needs our public key)
|
|
data, err := MarshalJSON(msg)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal handshake message: %w", err)
|
|
}
|
|
|
|
if err := pc.Conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
|
return fmt.Errorf("send handshake: %w", err)
|
|
}
|
|
|
|
// Wait for ack
|
|
_, ackData, err := pc.Conn.ReadMessage()
|
|
if err != nil {
|
|
return fmt.Errorf("read handshake ack: %w", err)
|
|
}
|
|
|
|
var ackMsg Message
|
|
if err := json.Unmarshal(ackData, &ackMsg); err != nil {
|
|
return fmt.Errorf("unmarshal handshake ack: %w", err)
|
|
}
|
|
|
|
if ackMsg.Type != MsgHandshakeAck {
|
|
return fmt.Errorf("expected handshake_ack, got %s", ackMsg.Type)
|
|
}
|
|
|
|
var ackPayload HandshakeAckPayload
|
|
if err := ackMsg.ParsePayload(&ackPayload); err != nil {
|
|
return fmt.Errorf("parse handshake ack payload: %w", err)
|
|
}
|
|
|
|
if !ackPayload.Accepted {
|
|
return fmt.Errorf("handshake rejected: %s", ackPayload.Reason)
|
|
}
|
|
|
|
// Update peer with the received identity info
|
|
pc.Peer.ID = ackPayload.Identity.ID
|
|
pc.Peer.PublicKey = ackPayload.Identity.PublicKey
|
|
pc.Peer.Name = ackPayload.Identity.Name
|
|
pc.Peer.Role = ackPayload.Identity.Role
|
|
|
|
// Verify challenge response - derive shared secret first using the peer's public key
|
|
sharedSecret, err := t.node.DeriveSharedSecret(pc.Peer.PublicKey)
|
|
if err != nil {
|
|
return fmt.Errorf("derive shared secret for challenge verification: %w", err)
|
|
}
|
|
|
|
// Verify the server's response to our challenge
|
|
if len(ackPayload.ChallengeResponse) == 0 {
|
|
return fmt.Errorf("server did not provide challenge response")
|
|
}
|
|
if !VerifyChallenge(challenge, ackPayload.ChallengeResponse, sharedSecret) {
|
|
return fmt.Errorf("challenge response verification failed: server may not have matching private key")
|
|
}
|
|
|
|
// Store the shared secret for later use
|
|
pc.SharedSecret = sharedSecret
|
|
|
|
// Update the peer in registry with the real identity
|
|
if err := t.registry.UpdatePeer(pc.Peer); err != nil {
|
|
// If update fails (peer not found with old ID), add as new
|
|
t.registry.AddPeer(pc.Peer)
|
|
}
|
|
|
|
logging.Debug("handshake completed with challenge-response verification", logging.Fields{
|
|
"peer_id": pc.Peer.ID,
|
|
"peer_name": pc.Peer.Name,
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
// readLoop reads messages from a peer connection.
|
|
func (t *Transport) readLoop(pc *PeerConnection) {
|
|
defer t.wg.Done()
|
|
defer t.removeConnection(pc)
|
|
|
|
// Apply message size limit to prevent memory exhaustion attacks
|
|
maxSize := t.config.MaxMessageSize
|
|
if maxSize <= 0 {
|
|
maxSize = DefaultMaxMessageSize
|
|
}
|
|
pc.Conn.SetReadLimit(maxSize)
|
|
|
|
for {
|
|
select {
|
|
case <-t.ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
// Set read deadline to prevent blocking forever on unresponsive connections
|
|
readDeadline := t.config.PingInterval + t.config.PongTimeout
|
|
if err := pc.Conn.SetReadDeadline(time.Now().Add(readDeadline)); err != nil {
|
|
logging.Error("SetReadDeadline error", logging.Fields{"peer_id": pc.Peer.ID, "error": err})
|
|
return
|
|
}
|
|
|
|
_, data, err := pc.Conn.ReadMessage()
|
|
if err != nil {
|
|
logging.Debug("read error from peer", logging.Fields{"peer_id": pc.Peer.ID, "error": err})
|
|
return
|
|
}
|
|
|
|
pc.LastActivity = time.Now()
|
|
|
|
// Check rate limit before processing
|
|
if pc.rateLimiter != nil && !pc.rateLimiter.Allow() {
|
|
logging.Warn("peer rate limited, dropping message", logging.Fields{"peer_id": pc.Peer.ID})
|
|
continue // Drop message from rate-limited peer
|
|
}
|
|
|
|
// Decrypt message using SMSG with shared secret
|
|
msg, err := t.decryptMessage(data, pc.SharedSecret)
|
|
if err != nil {
|
|
logging.Debug("decrypt error from peer", logging.Fields{"peer_id": pc.Peer.ID, "error": err, "data_len": len(data)})
|
|
continue // Skip invalid messages
|
|
}
|
|
|
|
// Check for duplicate messages (prevents amplification attacks)
|
|
if t.dedup.IsDuplicate(msg.ID) {
|
|
logging.Debug("dropping duplicate message", logging.Fields{"msg_id": msg.ID, "peer_id": pc.Peer.ID})
|
|
continue
|
|
}
|
|
t.dedup.Mark(msg.ID)
|
|
|
|
// Rate limit debug logs in hot path to reduce noise (log 1 in N messages)
|
|
if debugLogCounter.Add(1)%debugLogInterval == 0 {
|
|
logging.Debug("received message from peer", logging.Fields{"type": msg.Type, "peer_id": pc.Peer.ID, "reply_to": msg.ReplyTo, "sample": "1/100"})
|
|
}
|
|
|
|
// Dispatch to handler (read handler under lock to avoid race)
|
|
t.mu.RLock()
|
|
handler := t.handler
|
|
t.mu.RUnlock()
|
|
if handler != nil {
|
|
handler(pc, msg)
|
|
}
|
|
}
|
|
}
|
|
|
|
// keepalive sends periodic pings.
|
|
func (t *Transport) keepalive(pc *PeerConnection) {
|
|
defer t.wg.Done()
|
|
|
|
ticker := time.NewTicker(t.config.PingInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-t.ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
// Check if connection is still alive
|
|
if time.Since(pc.LastActivity) > t.config.PingInterval+t.config.PongTimeout {
|
|
t.removeConnection(pc)
|
|
return
|
|
}
|
|
|
|
// Send ping
|
|
identity := t.node.GetIdentity()
|
|
pingMsg, err := NewMessage(MsgPing, identity.ID, pc.Peer.ID, PingPayload{
|
|
SentAt: time.Now().UnixMilli(),
|
|
})
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
if err := pc.Send(pingMsg); err != nil {
|
|
t.removeConnection(pc)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// removeConnection removes and cleans up a connection.
|
|
func (t *Transport) removeConnection(pc *PeerConnection) {
|
|
t.mu.Lock()
|
|
delete(t.conns, pc.Peer.ID)
|
|
t.mu.Unlock()
|
|
|
|
t.registry.SetConnected(pc.Peer.ID, false)
|
|
pc.Close()
|
|
}
|
|
|
|
// Send sends an encrypted message over the connection.
|
|
func (pc *PeerConnection) Send(msg *Message) error {
|
|
pc.writeMu.Lock()
|
|
defer pc.writeMu.Unlock()
|
|
|
|
// Encrypt message using SMSG
|
|
data, err := pc.transport.encryptMessage(msg, pc.SharedSecret)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Set write deadline to prevent blocking forever
|
|
if err := pc.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
|
return fmt.Errorf("failed to set write deadline: %w", err)
|
|
}
|
|
defer pc.Conn.SetWriteDeadline(time.Time{}) // Reset deadline after send
|
|
|
|
return pc.Conn.WriteMessage(websocket.BinaryMessage, data)
|
|
}
|
|
|
|
// Close closes the connection.
|
|
func (pc *PeerConnection) Close() error {
|
|
var err error
|
|
pc.closeOnce.Do(func() {
|
|
err = pc.Conn.Close()
|
|
})
|
|
return err
|
|
}
|
|
|
|
// DisconnectPayload is the body of a disconnect message: a human-readable
// reason plus an optional numeric code.
type DisconnectPayload struct {
	// Reason describes why the connection is being closed.
	Reason string `json:"reason"`
	// Code is an optional disconnect code.
	Code int `json:"code"`
}
|
|
|
|
// Disconnect codes carried in DisconnectPayload.Code.
const (
	// DisconnectNormal indicates a normal closure.
	DisconnectNormal = 1000
	// DisconnectGoingAway indicates the server or peer is going away.
	DisconnectGoingAway = 1001
	// DisconnectProtocolErr indicates a protocol error.
	DisconnectProtocolErr = 1002
	// DisconnectTimeout indicates an idle timeout.
	DisconnectTimeout = 1003
	// DisconnectShutdown indicates a server shutdown.
	DisconnectShutdown = 1004
)
|
|
|
|
// GracefulClose sends a disconnect message before closing the connection.
|
|
func (pc *PeerConnection) GracefulClose(reason string, code int) error {
|
|
var err error
|
|
pc.closeOnce.Do(func() {
|
|
// Try to send disconnect message (best effort)
|
|
if pc.transport != nil && pc.SharedSecret != nil {
|
|
identity := pc.transport.node.GetIdentity()
|
|
if identity != nil {
|
|
payload := DisconnectPayload{
|
|
Reason: reason,
|
|
Code: code,
|
|
}
|
|
msg, msgErr := NewMessage(MsgDisconnect, identity.ID, pc.Peer.ID, payload)
|
|
if msgErr == nil {
|
|
// Set short deadline for disconnect message
|
|
pc.Conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
|
|
pc.Send(msg)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Close the underlying connection
|
|
err = pc.Conn.Close()
|
|
})
|
|
return err
|
|
}
|
|
|
|
// encryptMessage encrypts a message using SMSG with the shared secret.
|
|
func (t *Transport) encryptMessage(msg *Message, sharedSecret []byte) ([]byte, error) {
|
|
// Serialize message to JSON (using pooled buffer for efficiency)
|
|
msgData, err := MarshalJSON(msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create SMSG message
|
|
smsgMsg := smsg.NewMessage(string(msgData))
|
|
|
|
// Encrypt using shared secret as password (base64 encoded)
|
|
password := base64.StdEncoding.EncodeToString(sharedSecret)
|
|
encrypted, err := smsg.Encrypt(smsgMsg, password)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return encrypted, nil
|
|
}
|
|
|
|
// decryptMessage decrypts a message using SMSG with the shared secret.
|
|
func (t *Transport) decryptMessage(data []byte, sharedSecret []byte) (*Message, error) {
|
|
// Decrypt using shared secret as password
|
|
password := base64.StdEncoding.EncodeToString(sharedSecret)
|
|
smsgMsg, err := smsg.Decrypt(data, password)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse message from JSON
|
|
var msg Message
|
|
if err := json.Unmarshal([]byte(smsgMsg.Body), &msg); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &msg, nil
|
|
}
|
|
|
|
// ConnectedPeers returns the number of connected peers.
|
|
func (t *Transport) ConnectedPeers() int {
|
|
t.mu.RLock()
|
|
defer t.mu.RUnlock()
|
|
return len(t.conns)
|
|
}
|