Mining/pkg/node/transport.go
snider 74bbf14de4 fix: Address networking, memory leak, and segfault issues from code review
Networking/Protocol fixes:
- Add HTTP server timeouts (Read/Write/Idle/ReadHeader) in service.go
- Fix CORS address parsing to use net.SplitHostPort safely
- Add request body size limit middleware (1MB max)
- Enforce MaxConns limit in WebSocket upgrade handler
- Fix WebSocket origin validation to only allow localhost
- Add read/write deadlines to WebSocket connections

Memory leak fixes:
- Add sync.Once to Manager.Stop() to prevent double-close panic
- Fix controller pending map leak by closing response channel
- Add memory reallocation for hashrate history slices when oversized
- Fix LogBuffer to truncate long lines and force reallocation on trim
- Add process wait timeout to prevent goroutine leaks on zombie processes
- Drain HTTP response body on copy error to allow connection reuse

Segfault/panic prevention:
- Add nil check in GetTotalHashrate for stats pointer
- Fix hashrate history slice reallocation to prevent capacity bloat

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-31 02:26:46 +00:00

585 lines
14 KiB
Go

package node
import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"log"
	"net"
	"net/http"
	"net/url"
	"sync"
	"time"

	"github.com/Snider/Borg/pkg/smsg"
	"github.com/gorilla/websocket"
)
// TransportConfig configures the WebSocket transport.
//
// Use DefaultTransportConfig for a ready-made value: NewTransport stores the
// config exactly as given and does not fill in zero-valued fields.
type TransportConfig struct {
ListenAddr string // ":9091" default; used as the HTTP server address by Start
WSPath string // "/ws" - WebSocket endpoint path; served by Start and used as the dial path by Connect
TLSCertPath string // Optional TLS for wss://; inbound TLS needs both cert and key, but Connect dials wss:// whenever this field alone is set
TLSKeyPath string // Optional TLS private key; only used together with TLSCertPath
MaxConns int // Maximum concurrent connections; checked only for inbound upgrades (handleWSUpgrade), not by Connect
PingInterval time.Duration // WebSocket keepalive interval: keepalive sends a ping this often
PongTimeout time.Duration // Grace period added to PingInterval to form the read deadline / idle cutoff; this file does not wait for a specific pong
}
// DefaultTransportConfig returns the baseline transport settings: plain ws://
// on port 9091 at "/ws", at most 100 inbound connections, a ping every 30s
// and a 10s grace period on top of that before a peer is considered idle.
// TLS is left disabled (both TLS paths empty).
func DefaultTransportConfig() TransportConfig {
	var cfg TransportConfig
	cfg.ListenAddr = ":9091"
	cfg.WSPath = "/ws"
	cfg.MaxConns = 100
	cfg.PingInterval = 30 * time.Second
	cfg.PongTimeout = 10 * time.Second
	return cfg
}
// MessageHandler processes incoming messages.
//
// It is invoked synchronously from a connection's readLoop with each
// successfully decrypted message, so the next message from that peer is not
// read until the handler returns. conn is the connection the message arrived on.
type MessageHandler func(conn *PeerConnection, msg *Message)
// Transport manages WebSocket connections with SMSG encryption.
//
// It accepts inbound peers on config.ListenAddr + config.WSPath (Start /
// handleWSUpgrade), dials outbound peers with Connect, and runs one readLoop
// and one keepalive goroutine per established connection.
type Transport struct {
config TransportConfig
server *http.Server // created by Start; nil until Start has run
upgrader websocket.Upgrader
conns map[string]*PeerConnection // peer ID -> connection; every access holds mu
node *NodeManager
registry *PeerRegistry
handler MessageHandler // set by OnMessage; read by readLoop WITHOUT holding mu
mu sync.RWMutex // guards conns
ctx context.Context // cancelled by Stop; readLoop and keepalive exit once it is done
cancel context.CancelFunc // cancels ctx; called by Stop
wg sync.WaitGroup // counts the HTTP serve goroutine plus every readLoop/keepalive goroutine; Stop waits on it
}
// PeerConnection represents an active connection to a peer.
//
// After the initial handshake, every frame on Conn is an SMSG-encrypted
// Message keyed by SharedSecret.
type PeerConnection struct {
Peer *Peer // remote identity; on outbound connections it is rewritten from the handshake ack
Conn *websocket.Conn
SharedSecret []byte // Derived via X25519 ECDH, used for SMSG (base64-encoded into the SMSG password)
LastActivity time.Time // set by readLoop on every received frame and read by keepalive; NOTE(review): unsynchronized — data race, consider an atomic
writeMu sync.Mutex // Serialize WebSocket writes
transport *Transport // owning transport; provides encryptMessage for Send
}
// NewTransport builds a Transport bound to the given node identity and peer
// registry. Nothing is started: call Start to accept inbound peers and
// Connect to dial outbound ones.
//
// The WebSocket upgrader uses 1 KiB read/write buffers and accepts only
// requests without an Origin header (non-browser clients) or whose Origin
// host is localhost, 127.0.0.1 or ::1; unparsable origins are rejected.
func NewTransport(node *NodeManager, registry *PeerRegistry, config TransportConfig) *Transport {
	ctx, cancel := context.WithCancel(context.Background())

	// Allow local connections only for security.
	allowOrigin := func(r *http.Request) bool {
		origin := r.Header.Get("Origin")
		if origin == "" {
			// No origin header (non-browser client).
			return true
		}
		parsed, err := url.Parse(origin)
		if err != nil {
			return false
		}
		switch parsed.Hostname() {
		case "localhost", "127.0.0.1", "::1":
			return true
		default:
			return false
		}
	}

	t := &Transport{
		config:   config,
		node:     node,
		registry: registry,
		conns:    make(map[string]*PeerConnection),
		ctx:      ctx,
		cancel:   cancel,
	}
	t.upgrader = websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin:     allowOrigin,
	}
	return t
}
// Start begins listening for incoming connections.
//
// The listening socket is opened synchronously, so bind failures (address in
// use, permission denied, malformed address) are returned to the caller rather
// than being lost inside the serve goroutine. Serving then continues in a
// background goroutine tracked by t.wg. When both TLSCertPath and TLSKeyPath
// are set it serves TLS (wss://), otherwise plain ws://. Errors that occur
// later while serving (for example an unreadable certificate) are logged.
func (t *Transport) Start() error {
	mux := http.NewServeMux()
	mux.HandleFunc(t.config.WSPath, t.handleWSUpgrade)

	useTLS := t.config.TLSCertPath != "" && t.config.TLSKeyPath != ""

	addr := t.config.ListenAddr
	if addr == "" {
		// Match ListenAndServe / ListenAndServeTLS, which default an empty
		// Addr; net.Listen would otherwise pick a random port.
		if useTLS {
			addr = ":https"
		} else {
			addr = ":http"
		}
	}

	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("listen on %q: %w", addr, err)
	}

	srv := &http.Server{
		Addr:    t.config.ListenAddr,
		Handler: mux,
		// Bound how long a client may take to send request headers so idle
		// sockets cannot pin server goroutines (Slowloris). Read/Write
		// timeouts are deliberately not set: upgraded WebSocket connections
		// are long-lived and readLoop/Send manage their own deadlines.
		ReadHeaderTimeout: 10 * time.Second,
	}
	t.server = srv

	t.wg.Add(1)
	go func() {
		defer t.wg.Done()
		var serveErr error
		if useTLS {
			serveErr = srv.ServeTLS(ln, t.config.TLSCertPath, t.config.TLSKeyPath)
		} else {
			serveErr = srv.Serve(ln)
		}
		// ErrServerClosed is the normal result of Stop calling Shutdown.
		if serveErr != nil && serveErr != http.ErrServerClosed {
			log.Printf("[Transport] Server error on %s: %v", addr, serveErr)
		}
	}()
	return nil
}
// Stop gracefully shuts down the transport.
//
// It cancels the transport context (which stops every keepalive and makes
// readLoop exit), closes every registered peer connection, and shuts the HTTP
// server down with a 5-second grace period. It then waits for all transport
// goroutines even if Shutdown failed, so none are left running. Calling Stop
// without a successful Start is safe: there is simply no server to shut down.
// The returned error, if any, comes from http.Server.Shutdown.
func (t *Transport) Stop() error {
	t.cancel()

	// Snapshot under the lock and close outside it; each closed connection's
	// readLoop will unregister itself via removeConnection.
	t.mu.RLock()
	open := make([]*PeerConnection, 0, len(t.conns))
	for _, pc := range t.conns {
		open = append(open, pc)
	}
	t.mu.RUnlock()
	for _, pc := range open {
		_ = pc.Close() // best-effort; the socket may already be closed
	}

	var shutdownErr error
	if t.server != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := t.server.Shutdown(ctx); err != nil {
			shutdownErr = fmt.Errorf("server shutdown error: %w", err)
		}
	}

	// Shutdown closes the listener before waiting for idle connections, so the
	// serve goroutine has returned even when Shutdown itself timed out.
	t.wg.Wait()
	return shutdownErr
}
// OnMessage sets the handler for incoming messages.
//
// The handler field is read by every readLoop without any locking, so install
// the handler before Start or Connect. Replacing it while connections are
// active is a data race. With a nil handler, received messages are decrypted
// and then dropped.
func (t *Transport) OnMessage(handler MessageHandler) {
t.handler = handler
}
// Connect establishes a connection to a peer.
//
// It dials peer.Address at config.WSPath and runs the client-side handshake.
// The scheme is wss:// whenever this node has TLSCertPath configured,
// otherwise ws://. It then derives the SMSG shared secret from the public key
// the peer reported, registers the connection under the peer ID from the
// handshake, and starts its readLoop and keepalive goroutines. The handshake
// rewrites pc.Peer, which is the caller's *peer. On any failure after dialing,
// the socket is closed before the error is returned.
//
// NOTE(review): an existing entry in t.conns for the same peer ID is
// overwritten without being closed, and MaxConns is not enforced here.
func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) {
	scheme := "ws"
	if t.config.TLSCertPath != "" {
		scheme = "wss"
	}
	target := url.URL{Scheme: scheme, Host: peer.Address, Path: t.config.WSPath}

	ws, _, err := websocket.DefaultDialer.Dial(target.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to peer: %w", err)
	}

	// Close the socket on every early return; flipped once fully set up.
	established := false
	defer func() {
		if !established {
			ws.Close()
		}
	}()

	pc := &PeerConnection{
		Peer:         peer,
		Conn:         ws,
		LastActivity: time.Now(),
		transport:    t,
	}

	// The peer's public key is only known after the handshake, so the shared
	// secret can only be derived afterwards.
	if err := t.performHandshake(pc); err != nil {
		return nil, fmt.Errorf("handshake failed: %w", err)
	}
	secret, err := t.node.DeriveSharedSecret(pc.Peer.PublicKey)
	if err != nil {
		return nil, fmt.Errorf("failed to derive shared secret: %w", err)
	}
	pc.SharedSecret = secret

	peerID := pc.Peer.ID
	t.mu.Lock()
	t.conns[peerID] = pc
	t.mu.Unlock()
	log.Printf("[Connect] Connected to %s, SharedSecret len: %d", peerID, len(pc.SharedSecret))

	t.registry.SetConnected(peerID, true)

	t.wg.Add(1)
	go t.readLoop(pc)
	log.Printf("[Connect] Started readLoop for %s", peerID)

	t.wg.Add(1)
	go t.keepalive(pc)

	established = true
	return pc, nil
}
// Send sends a message to a specific peer.
//
// It looks up the registered connection for peerID and delegates to
// PeerConnection.Send. It returns an error if no connection is registered
// for that ID.
func (t *Transport) Send(peerID string, msg *Message) error {
	t.mu.RLock()
	pc, ok := t.conns[peerID]
	t.mu.RUnlock()

	if ok {
		return pc.Send(msg)
	}
	return fmt.Errorf("peer %s not connected", peerID)
}
// Broadcast sends a message to all connected peers.
//
// The connection set is snapshotted under the read lock, and sends happen
// outside it. Every peer is attempted. Only the error from the last failing
// send is returned; the order follows map iteration and is unspecified.
// Earlier errors are discarded.
func (t *Transport) Broadcast(msg *Message) error {
	t.mu.RLock()
	targets := make([]*PeerConnection, len(t.conns))
	i := 0
	for _, pc := range t.conns {
		targets[i] = pc
		i++
	}
	t.mu.RUnlock()

	var lastErr error
	for _, pc := range targets {
		if sendErr := pc.Send(msg); sendErr != nil {
			lastErr = sendErr
		}
	}
	return lastErr
}
// GetConnection returns an active connection to a peer, or nil when no
// connection is registered under peerID.
func (t *Transport) GetConnection(peerID string) *PeerConnection {
	t.mu.RLock()
	pc := t.conns[peerID]
	t.mu.RUnlock()
	return pc
}
// handleWSUpgrade handles incoming WebSocket connections.
//
// It rejects the request with 503 when the transport is stopping or already
// holds MaxConns connections. Otherwise it upgrades to WebSocket and runs the
// server side of the handshake (acceptHandshake). It then registers the
// connection and starts its readLoop and keepalive goroutines. Any failure
// after the upgrade closes the socket.
func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
	if t.ctx.Err() != nil {
		http.Error(w, "Transport stopped", http.StatusServiceUnavailable)
		return
	}

	// Cheap pre-upgrade capacity check. It is repeated atomically with the
	// insert below, because other handshakes may complete in between.
	t.mu.RLock()
	currentConns := len(t.conns)
	t.mu.RUnlock()
	if currentConns >= t.config.MaxConns {
		http.Error(w, "Too many connections", http.StatusServiceUnavailable)
		return
	}

	conn, err := t.upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error.
		return
	}

	pc, err := t.acceptHandshake(conn)
	if err != nil {
		log.Printf("[handleWSUpgrade] Handshake with %s failed: %v", r.RemoteAddr, err)
		conn.Close()
		return
	}
	peerID := pc.Peer.ID

	// Re-check shutdown and capacity under the write lock, together with the
	// insert. Replacing an existing entry for the same peer does not grow the
	// map. wg.Add also happens under t.mu: Stop cancels ctx and then takes
	// t.mu, so either we see the cancellation here, or Stop sees this
	// connection (and closes it) and our Add precedes its Wait.
	t.mu.Lock()
	_, replacing := t.conns[peerID]
	if t.ctx.Err() != nil || (!replacing && len(t.conns) >= t.config.MaxConns) {
		t.mu.Unlock()
		log.Printf("[handleWSUpgrade] Rejecting %s: transport stopped or at capacity", peerID)
		conn.Close()
		return
	}
	t.conns[peerID] = pc
	t.wg.Add(2) // readLoop + keepalive
	t.mu.Unlock()

	t.registry.SetConnected(peerID, true)
	go t.readLoop(pc)
	go t.keepalive(pc)
}

// acceptHandshake runs the server side of the handshake on a freshly upgraded
// connection and returns the (not yet registered) PeerConnection.
//
// It reads the client's unencrypted MsgHandshake (carrying its identity and
// public key) and derives the shared secret from that key. Unknown peers are
// auto-registered. It then replies with an unencrypted MsgHandshakeAck
// carrying this node's identity, so the client can derive the same secret.
// Reads and writes are bounded by a handshake deadline, cleared on success.
// The caller owns conn and must close it on error.
func (t *Transport) acceptHandshake(conn *websocket.Conn) (*PeerConnection, error) {
	// Bounds how long an unauthenticated client may hold the socket. Without
	// it, a client that upgrades and never sends would pin this goroutine and
	// socket forever without counting against MaxConns.
	const handshakeTimeout = 10 * time.Second

	if err := conn.SetReadDeadline(time.Now().Add(handshakeTimeout)); err != nil {
		return nil, fmt.Errorf("set handshake read deadline: %w", err)
	}
	_, data, err := conn.ReadMessage()
	if err != nil {
		return nil, fmt.Errorf("read handshake: %w", err)
	}

	// Not encrypted yet: the client cannot encrypt before it knows our key.
	var msg Message
	if err := json.Unmarshal(data, &msg); err != nil {
		return nil, fmt.Errorf("decode handshake: %w", err)
	}
	if msg.Type != MsgHandshake {
		return nil, fmt.Errorf("expected handshake, got %s", msg.Type)
	}
	var payload HandshakePayload
	if err := msg.ParsePayload(&payload); err != nil {
		return nil, fmt.Errorf("parse handshake payload: %w", err)
	}

	sharedSecret, err := t.node.DeriveSharedSecret(payload.Identity.PublicKey)
	if err != nil {
		return nil, fmt.Errorf("derive shared secret: %w", err)
	}

	// NOTE(review): an already-registered peer is accepted without checking
	// that payload.Identity.PublicKey matches the key on record, and the
	// secret above comes from the key the client claims. A client can thus
	// present any known peer ID. Consider rejecting a key mismatch.
	peer := t.registry.GetPeer(payload.Identity.ID)
	if peer == nil {
		// Auto-register for now (could require pre-registration).
		peer = &Peer{
			ID:        payload.Identity.ID,
			Name:      payload.Identity.Name,
			PublicKey: payload.Identity.PublicKey,
			Role:      payload.Identity.Role,
			AddedAt:   time.Now(),
			Score:     50,
		}
		t.registry.AddPeer(peer)
	}

	identity := t.node.GetIdentity()
	ackMsg, err := NewMessage(MsgHandshakeAck, identity.ID, peer.ID, HandshakeAckPayload{
		Identity: *identity,
		Accepted: true,
	})
	if err != nil {
		return nil, fmt.Errorf("build handshake ack: %w", err)
	}
	// The ack is unencrypted too: the peer needs our public key first.
	ackData, err := json.Marshal(ackMsg)
	if err != nil {
		return nil, fmt.Errorf("encode handshake ack: %w", err)
	}
	if err := conn.SetWriteDeadline(time.Now().Add(handshakeTimeout)); err != nil {
		return nil, fmt.Errorf("set handshake write deadline: %w", err)
	}
	if err := conn.WriteMessage(websocket.TextMessage, ackData); err != nil {
		return nil, fmt.Errorf("send handshake ack: %w", err)
	}

	// Clear the handshake deadlines; readLoop and PeerConnection.Send set
	// their own deadline before every read/write.
	if err := conn.SetReadDeadline(time.Time{}); err != nil {
		return nil, fmt.Errorf("clear read deadline: %w", err)
	}
	if err := conn.SetWriteDeadline(time.Time{}); err != nil {
		return nil, fmt.Errorf("clear write deadline: %w", err)
	}

	return &PeerConnection{
		Peer:         peer,
		Conn:         conn,
		SharedSecret: sharedSecret,
		LastActivity: time.Now(),
		transport:    t,
	}, nil
}
// performHandshake initiates handshake with a peer (client side).
//
// It sends this node's identity, including its public key, as an unencrypted
// MsgHandshake, then waits for the peer's unencrypted MsgHandshakeAck. If the
// peer accepts, it overwrites pc.Peer's ID, PublicKey, Name and Role with the
// identity the peer reported. It then updates that peer in the registry,
// adding it if the update fails.
//
// The write and the read share one handshake deadline, so an unresponsive
// peer cannot block the caller forever. The deadlines are cleared on success.
// It returns an error for I/O failures, a non-ack reply, or a rejected
// handshake (carrying the peer's reason).
func (t *Transport) performHandshake(pc *PeerConnection) error {
	const handshakeTimeout = 10 * time.Second

	identity := t.node.GetIdentity()
	msg, err := NewMessage(MsgHandshake, identity.ID, pc.Peer.ID, HandshakePayload{
		Identity: *identity,
		Version:  "1.0",
	})
	if err != nil {
		return fmt.Errorf("build handshake: %w", err)
	}
	// First message is unencrypted (peer needs our public key).
	data, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("encode handshake: %w", err)
	}

	deadline := time.Now().Add(handshakeTimeout)
	if err := pc.Conn.SetWriteDeadline(deadline); err != nil {
		return fmt.Errorf("set handshake write deadline: %w", err)
	}
	if err := pc.Conn.WriteMessage(websocket.TextMessage, data); err != nil {
		return fmt.Errorf("send handshake: %w", err)
	}

	if err := pc.Conn.SetReadDeadline(deadline); err != nil {
		return fmt.Errorf("set handshake read deadline: %w", err)
	}
	_, ackData, err := pc.Conn.ReadMessage()
	if err != nil {
		return fmt.Errorf("read handshake ack: %w", err)
	}
	var ackMsg Message
	if err := json.Unmarshal(ackData, &ackMsg); err != nil {
		return fmt.Errorf("decode handshake ack: %w", err)
	}
	if ackMsg.Type != MsgHandshakeAck {
		return fmt.Errorf("expected handshake_ack, got %s", ackMsg.Type)
	}
	var ackPayload HandshakeAckPayload
	if err := ackMsg.ParsePayload(&ackPayload); err != nil {
		return fmt.Errorf("parse handshake ack payload: %w", err)
	}
	if !ackPayload.Accepted {
		return fmt.Errorf("handshake rejected: %s", ackPayload.Reason)
	}

	// Clear the handshake deadlines; readLoop and PeerConnection.Send set
	// their own deadline before every read/write.
	if err := pc.Conn.SetReadDeadline(time.Time{}); err != nil {
		return fmt.Errorf("clear read deadline: %w", err)
	}
	if err := pc.Conn.SetWriteDeadline(time.Time{}); err != nil {
		return fmt.Errorf("clear write deadline: %w", err)
	}

	// Update peer with the received identity info.
	pc.Peer.ID = ackPayload.Identity.ID
	pc.Peer.PublicKey = ackPayload.Identity.PublicKey
	pc.Peer.Name = ackPayload.Identity.Name
	pc.Peer.Role = ackPayload.Identity.Role

	// Update the peer in registry with the real identity; if the update fails
	// (presumably no peer under that ID), add it as new.
	if err := t.registry.UpdatePeer(pc.Peer); err != nil {
		t.registry.AddPeer(pc.Peer)
	}
	return nil
}
// readLoop reads messages from a peer connection until it fails or the
// transport is stopped, then unregisters and closes the connection.
//
// Before each read it sets a read deadline of PingInterval+PongTimeout, so a
// silent peer ends the loop. Each frame refreshes LastActivity and is
// SMSG-decrypted with the connection's shared secret. Frames that fail to
// decrypt are logged and skipped. Decrypted messages are passed synchronously
// to the transport's handler, when one is set.
func (t *Transport) readLoop(pc *PeerConnection) {
	defer t.wg.Done()
	defer t.removeConnection(pc)

	idleWindow := t.config.PingInterval + t.config.PongTimeout
	for t.ctx.Err() == nil {
		if err := pc.Conn.SetReadDeadline(time.Now().Add(idleWindow)); err != nil {
			log.Printf("[readLoop] SetReadDeadline error for %s: %v", pc.Peer.ID, err)
			return
		}

		_, frame, err := pc.Conn.ReadMessage()
		if err != nil {
			log.Printf("[readLoop] Read error from %s: %v", pc.Peer.ID, err)
			return
		}
		pc.LastActivity = time.Now()

		msg, err := t.decryptMessage(frame, pc.SharedSecret)
		if err != nil {
			log.Printf("[readLoop] Decrypt error from %s: %v (data len: %d)", pc.Peer.ID, err, len(frame))
			continue
		}
		log.Printf("[readLoop] Received %s from %s (reply to: %s)", msg.Type, pc.Peer.ID, msg.ReplyTo)

		if handle := t.handler; handle != nil {
			handle(pc, msg)
		}
	}
}
// keepalive sends periodic pings.
//
// On every PingInterval tick it:
//  1. exits quietly if pc is no longer the registered connection for its
//     peer ID, because readLoop or Stop removed it or a newer connection
//     replaced it. Pinging or removing it again would be pointless, and could
//     disturb the replacement.
//  2. removes the connection if nothing was received for
//     PingInterval+PongTimeout.
//  3. sends an encrypted MsgPing and removes the connection if the send fails.
//
// It also exits when the transport context is cancelled. A non-positive
// PingInterval disables keepalive, since time.NewTicker would panic on it.
func (t *Transport) keepalive(pc *PeerConnection) {
	defer t.wg.Done()

	if t.config.PingInterval <= 0 {
		return
	}
	staleAfter := t.config.PingInterval + t.config.PongTimeout

	ticker := time.NewTicker(t.config.PingInterval)
	defer ticker.Stop()

	for {
		select {
		case <-t.ctx.Done():
			return
		case <-ticker.C:
		}

		if t.GetConnection(pc.Peer.ID) != pc {
			return
		}

		// NOTE(review): LastActivity is written by readLoop without
		// synchronization; this read races with it.
		if time.Since(pc.LastActivity) > staleAfter {
			t.removeConnection(pc)
			return
		}

		identity := t.node.GetIdentity()
		pingMsg, err := NewMessage(MsgPing, identity.ID, pc.Peer.ID, PingPayload{
			SentAt: time.Now().UnixMilli(),
		})
		if err != nil {
			log.Printf("[keepalive] Failed to build ping for %s: %v", pc.Peer.ID, err)
			continue
		}
		if err := pc.Send(pingMsg); err != nil {
			t.removeConnection(pc)
			return
		}
	}
}
// removeConnection removes and cleans up a connection.
//
// The map entry for pc.Peer.ID is deleted, and the registry told the peer is
// disconnected, only if that entry is still pc. readLoop and keepalive can
// both call this for the same connection. A newer connection for the same peer
// may also have replaced pc in t.conns. An unconditional delete would drop
// that live replacement and mark a connected peer as offline. The socket is
// always closed, best-effort, because Stop or the other goroutine may already
// have closed it.
func (t *Transport) removeConnection(pc *PeerConnection) {
	id := pc.Peer.ID

	t.mu.Lock()
	owned := t.conns[id] == pc
	if owned {
		delete(t.conns, id)
	}
	t.mu.Unlock()

	if owned {
		t.registry.SetConnected(id, false)
	}
	_ = pc.Close()
}
// Send sends an encrypted message over the connection.
//
// The message is JSON-encoded and SMSG-encrypted with the connection's shared
// secret, then written as one binary WebSocket frame with a 10-second write
// deadline. Encryption touches no connection state, so it runs before writeMu
// is taken. The lock (WebSocket connections allow one concurrent writer) is
// held only for the deadline and the write, so concurrent senders no longer
// serialize on encryption.
func (pc *PeerConnection) Send(msg *Message) error {
	data, err := pc.transport.encryptMessage(msg, pc.SharedSecret)
	if err != nil {
		return err
	}

	pc.writeMu.Lock()
	defer pc.writeMu.Unlock()

	// Set write deadline to prevent blocking forever on a stalled peer.
	if err := pc.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil {
		return fmt.Errorf("failed to set write deadline: %w", err)
	}
	return pc.Conn.WriteMessage(websocket.BinaryMessage, data)
}
// Close closes the underlying WebSocket connection and reports the close
// error. It does not unregister the connection from its Transport.
func (pc *PeerConnection) Close() error {
	conn := pc.Conn
	return conn.Close()
}
// encryptMessage encrypts a message using SMSG with the shared secret.
//
// msg is JSON-encoded and wrapped in an SMSG message whose body is that JSON
// text. The body is encrypted with the standard base64 encoding of
// sharedSecret as the password. On any failure it returns (nil, err).
func (t *Transport) encryptMessage(msg *Message, sharedSecret []byte) ([]byte, error) {
	plaintext, err := json.Marshal(msg)
	if err != nil {
		return nil, err
	}

	envelope := smsg.NewMessage(string(plaintext))
	key := base64.StdEncoding.EncodeToString(sharedSecret)

	ciphertext, err := smsg.Encrypt(envelope, key)
	if err != nil {
		return nil, err
	}
	return ciphertext, nil
}
// decryptMessage decrypts a message using SMSG with the shared secret.
//
// It is the inverse of encryptMessage. data is SMSG-decrypted with the
// base64-encoded secret as the password, and the resulting body is decoded as
// a JSON Message. On any failure it returns (nil, err).
func (t *Transport) decryptMessage(data []byte, sharedSecret []byte) (*Message, error) {
	key := base64.StdEncoding.EncodeToString(sharedSecret)

	envelope, err := smsg.Decrypt(data, key)
	if err != nil {
		return nil, err
	}

	decoded := new(Message)
	if err := json.Unmarshal([]byte(envelope.Body), decoded); err != nil {
		return nil, err
	}
	return decoded, nil
}
// ConnectedPeers returns the number of connected peers, counted as the
// connections currently registered with the transport (inbound and outbound).
func (t *Transport) ConnectedPeers() int {
	t.mu.RLock()
	count := len(t.conns)
	t.mu.RUnlock()
	return count
}