862 lines
23 KiB
Go
862 lines
23 KiB
Go
package node
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
core "dappco.re/go/core"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// --- Test Helpers ---
|
|
|
|
func newTestNodeManager(t *testing.T, name string, role NodeRole) *NodeManager {
|
|
t.Helper()
|
|
dir := t.TempDir()
|
|
nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir))
|
|
if err != nil {
|
|
t.Fatalf("create node manager %q: %v", name, err)
|
|
}
|
|
if err := nm.GenerateIdentity(name, role); err != nil {
|
|
t.Fatalf("generate identity %q: %v", name, err)
|
|
}
|
|
return nm
|
|
}
|
|
|
|
func newTestPeerRegistry(t *testing.T) *PeerRegistry {
|
|
t.Helper()
|
|
dir := t.TempDir()
|
|
reg, err := NewPeerRegistryFromPath(testJoinPath(dir, "peers.json"))
|
|
if err != nil {
|
|
t.Fatalf("create registry: %v", err)
|
|
}
|
|
t.Cleanup(func() { reg.Close() })
|
|
return reg
|
|
}
|
|
|
|
// testTransportPair holds everything needed for transport pair tests.
|
|
type testTransportPair struct {
|
|
Server *Transport
|
|
Client *Transport
|
|
ServerNode *NodeManager
|
|
ClientNode *NodeManager
|
|
ServerReg *PeerRegistry
|
|
ClientReg *PeerRegistry
|
|
HTTPServer *httptest.Server
|
|
ServerAddr string // "127.0.0.1:PORT"
|
|
}
|
|
|
|
// setupTestTransportPair creates a server transport (backed by httptest) and a
|
|
// client transport, both with generated identities and open-auth registries.
|
|
func setupTestTransportPair(t *testing.T) *testTransportPair {
|
|
return setupTestTransportPairWithConfig(t, DefaultTransportConfig(), DefaultTransportConfig())
|
|
}
|
|
|
|
// setupTestTransportPairWithConfig allows custom configs for server and client.
|
|
func setupTestTransportPairWithConfig(t *testing.T, serverCfg, clientCfg TransportConfig) *testTransportPair {
|
|
t.Helper()
|
|
|
|
serverNM := newTestNodeManager(t, "server", RoleWorker)
|
|
clientNM := newTestNodeManager(t, "client", RoleController)
|
|
serverReg := newTestPeerRegistry(t)
|
|
clientReg := newTestPeerRegistry(t)
|
|
|
|
serverTransport := NewTransport(serverNM, serverReg, serverCfg)
|
|
clientTransport := NewTransport(clientNM, clientReg, clientCfg)
|
|
|
|
// Use httptest.Server with the transport's WebSocket handler
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc(serverCfg.WebSocketPath, serverTransport.handleWebSocketUpgrade)
|
|
ts := httptest.NewServer(mux)
|
|
|
|
u, _ := url.Parse(ts.URL)
|
|
|
|
tp := &testTransportPair{
|
|
Server: serverTransport,
|
|
Client: clientTransport,
|
|
ServerNode: serverNM,
|
|
ClientNode: clientNM,
|
|
ServerReg: serverReg,
|
|
ClientReg: clientReg,
|
|
HTTPServer: ts,
|
|
ServerAddr: u.Host,
|
|
}
|
|
|
|
t.Cleanup(func() {
|
|
clientTransport.Stop()
|
|
serverTransport.Stop()
|
|
ts.Close()
|
|
})
|
|
|
|
return tp
|
|
}
|
|
|
|
// connectClient establishes a connection from client to server transport.
|
|
func (tp *testTransportPair) connectClient(t *testing.T) *PeerConnection {
|
|
t.Helper()
|
|
|
|
peer := &Peer{
|
|
ID: tp.ServerNode.GetIdentity().ID,
|
|
Name: "server",
|
|
Address: tp.ServerAddr,
|
|
Role: RoleWorker,
|
|
}
|
|
tp.ClientReg.AddPeer(peer)
|
|
|
|
pc, err := tp.Client.Connect(peer)
|
|
if err != nil {
|
|
t.Fatalf("client connect failed: %v", err)
|
|
}
|
|
return pc
|
|
}
|
|
|
|
// --- Unit Tests for Sub-Components ---
|
|
|
|
func TestTransport_MessageDeduplicator_Good(t *testing.T) {
|
|
t.Run("MarkAndCheck", func(t *testing.T) {
|
|
d := NewMessageDeduplicator(5 * time.Minute)
|
|
|
|
if d.IsDuplicate("msg-1") {
|
|
t.Error("should not be duplicate before marking")
|
|
}
|
|
|
|
d.Mark("msg-1")
|
|
|
|
if !d.IsDuplicate("msg-1") {
|
|
t.Error("should be duplicate after marking")
|
|
}
|
|
|
|
if d.IsDuplicate("msg-2") {
|
|
t.Error("different ID should not be duplicate")
|
|
}
|
|
})
|
|
|
|
t.Run("Cleanup", func(t *testing.T) {
|
|
d := NewMessageDeduplicator(50 * time.Millisecond)
|
|
d.Mark("msg-1")
|
|
|
|
if !d.IsDuplicate("msg-1") {
|
|
t.Error("should be duplicate immediately after marking")
|
|
}
|
|
|
|
time.Sleep(60 * time.Millisecond)
|
|
d.Cleanup()
|
|
|
|
if d.IsDuplicate("msg-1") {
|
|
t.Error("should not be duplicate after TTL + cleanup")
|
|
}
|
|
})
|
|
|
|
t.Run("ExpiredEntriesDoNotLinger", func(t *testing.T) {
|
|
d := NewMessageDeduplicator(50 * time.Millisecond)
|
|
d.Mark("msg-1")
|
|
|
|
time.Sleep(75 * time.Millisecond)
|
|
|
|
if d.IsDuplicate("msg-1") {
|
|
t.Error("should not be duplicate after TTL even before cleanup runs")
|
|
}
|
|
})
|
|
|
|
t.Run("ConcurrentAccess", func(t *testing.T) {
|
|
d := NewMessageDeduplicator(5 * time.Minute)
|
|
var wg sync.WaitGroup
|
|
for i := range 100 {
|
|
wg.Add(1)
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
msgID := "msg-" + time.Now().String()
|
|
d.Mark(msgID)
|
|
d.IsDuplicate(msgID)
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
})
|
|
}
|
|
|
|
func TestTransport_PeerRateLimiter_Good(t *testing.T) {
|
|
t.Run("AllowUpToBurst", func(t *testing.T) {
|
|
rl := NewPeerRateLimiter(10, 5)
|
|
|
|
for i := range 10 {
|
|
if !rl.Allow() {
|
|
t.Errorf("should allow message %d (within burst)", i)
|
|
}
|
|
}
|
|
|
|
if rl.Allow() {
|
|
t.Error("should reject message after burst exhausted")
|
|
}
|
|
})
|
|
|
|
t.Run("RefillAfterTime", func(t *testing.T) {
|
|
rl := NewPeerRateLimiter(5, 10) // 5 burst, 10/sec refill
|
|
|
|
// Exhaust all tokens
|
|
for range 5 {
|
|
rl.Allow()
|
|
}
|
|
|
|
if rl.Allow() {
|
|
t.Error("should reject after exhaustion")
|
|
}
|
|
|
|
// Wait for refill
|
|
time.Sleep(1100 * time.Millisecond)
|
|
|
|
if !rl.Allow() {
|
|
t.Error("should allow after refill")
|
|
}
|
|
})
|
|
}
|
|
|
|
// --- Transport Integration Tests ---
|
|
|
|
func TestTransport_FullHandshake_Good(t *testing.T) {
|
|
tp := setupTestTransportPair(t)
|
|
pc := tp.connectClient(t)
|
|
|
|
// Shared secret must be derived
|
|
if len(pc.SharedSecret) == 0 {
|
|
t.Error("shared secret should be derived after handshake")
|
|
}
|
|
|
|
// Allow server goroutines to register the connection
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if tp.Server.ConnectedPeerCount() != 1 {
|
|
t.Errorf("server connected peers: got %d, want 1", tp.Server.ConnectedPeerCount())
|
|
}
|
|
if tp.Client.ConnectedPeerCount() != 1 {
|
|
t.Errorf("client connected peers: got %d, want 1", tp.Client.ConnectedPeerCount())
|
|
}
|
|
|
|
// Verify peer identity was exchanged correctly
|
|
serverID := tp.ServerNode.GetIdentity().ID
|
|
serverConn := tp.Client.GetConnection(serverID)
|
|
if serverConn == nil {
|
|
t.Fatal("client should have connection to server by server ID")
|
|
}
|
|
if serverConn.Peer.Name != "server" {
|
|
t.Errorf("peer name: got %q, want %q", serverConn.Peer.Name, "server")
|
|
}
|
|
}
|
|
|
|
func TestTransport_ConnectSendsAgentUserAgent_Good(t *testing.T) {
|
|
serverNM := newTestNodeManager(t, "ua-server", RoleWorker)
|
|
clientNM := newTestNodeManager(t, "ua-client", RoleController)
|
|
serverReg := newTestPeerRegistry(t)
|
|
clientReg := newTestPeerRegistry(t)
|
|
|
|
serverCfg := DefaultTransportConfig()
|
|
clientCfg := DefaultTransportConfig()
|
|
serverTransport := NewTransport(serverNM, serverReg, serverCfg)
|
|
clientTransport := NewTransport(clientNM, clientReg, clientCfg)
|
|
|
|
var capturedUserAgent atomic.Value
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc(serverCfg.WebSocketPath, func(w http.ResponseWriter, r *http.Request) {
|
|
capturedUserAgent.Store(r.Header.Get("User-Agent"))
|
|
serverTransport.handleWebSocketUpgrade(w, r)
|
|
})
|
|
|
|
ts := httptest.NewServer(mux)
|
|
t.Cleanup(func() {
|
|
clientTransport.Stop()
|
|
serverTransport.Stop()
|
|
ts.Close()
|
|
})
|
|
|
|
u, _ := url.Parse(ts.URL)
|
|
serverAddr := u.Host
|
|
|
|
peer := &Peer{
|
|
ID: serverNM.GetIdentity().ID,
|
|
Name: "server",
|
|
Address: serverAddr,
|
|
Role: RoleWorker,
|
|
}
|
|
clientReg.AddPeer(peer)
|
|
|
|
pc, err := clientTransport.Connect(peer)
|
|
if err != nil {
|
|
t.Fatalf("client connect failed: %v", err)
|
|
}
|
|
|
|
ua, ok := capturedUserAgent.Load().(string)
|
|
if !ok || ua == "" {
|
|
t.Fatal("expected user-agent to be captured during websocket upgrade")
|
|
}
|
|
if !strings.HasPrefix(ua, agentUserAgentPrefix) {
|
|
t.Fatalf("user-agent prefix: got %q, want prefix %q", ua, agentUserAgentPrefix)
|
|
}
|
|
if !strings.Contains(ua, "id="+clientNM.GetIdentity().ID) {
|
|
t.Fatalf("user-agent should include client identity, got %q", ua)
|
|
}
|
|
if pc.UserAgent != ua {
|
|
t.Fatalf("client connection user-agent: got %q, want %q", pc.UserAgent, ua)
|
|
}
|
|
|
|
serverConn := serverTransport.GetConnection(clientNM.GetIdentity().ID)
|
|
if serverConn == nil {
|
|
t.Fatal("server should retain the accepted connection")
|
|
}
|
|
if serverConn.UserAgent != ua {
|
|
t.Fatalf("server connection user-agent: got %q, want %q", serverConn.UserAgent, ua)
|
|
}
|
|
}
|
|
|
|
func TestTransport_HandshakeRejectWrongVersion_Bad(t *testing.T) {
|
|
tp := setupTestTransportPair(t)
|
|
|
|
// Dial raw WebSocket and send handshake with unsupported version
|
|
wsURL := "ws://" + tp.ServerAddr + "/ws"
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("raw dial: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
clientIdentity := tp.ClientNode.GetIdentity()
|
|
payload := HandshakePayload{
|
|
Identity: *clientIdentity,
|
|
Version: "99.99", // Unsupported
|
|
}
|
|
msg, _ := NewMessage(MessageHandshake, clientIdentity.ID, "", payload)
|
|
data, _ := MarshalJSON(msg)
|
|
|
|
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
|
t.Fatalf("write handshake: %v", err)
|
|
}
|
|
|
|
_, respData, err := conn.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("read response: %v", err)
|
|
}
|
|
|
|
var resp Message
|
|
testJSONUnmarshal(t, respData, &resp)
|
|
|
|
var ack HandshakeAckPayload
|
|
resp.ParsePayload(&ack)
|
|
|
|
if ack.Accepted {
|
|
t.Error("should reject incompatible protocol version")
|
|
}
|
|
if !core.Contains(ack.Reason, "incompatible protocol version") {
|
|
t.Errorf("expected version rejection reason, got: %s", ack.Reason)
|
|
}
|
|
}
|
|
|
|
func TestTransport_HandshakeRejectAllowlist_Bad(t *testing.T) {
|
|
tp := setupTestTransportPair(t)
|
|
|
|
// Switch server to allowlist mode WITHOUT adding client's key
|
|
tp.ServerReg.SetAuthMode(PeerAuthAllowlist)
|
|
|
|
peer := &Peer{
|
|
ID: tp.ServerNode.GetIdentity().ID,
|
|
Name: "server",
|
|
Address: tp.ServerAddr,
|
|
Role: RoleWorker,
|
|
}
|
|
tp.ClientReg.AddPeer(peer)
|
|
|
|
_, err := tp.Client.Connect(peer)
|
|
if err == nil {
|
|
t.Fatal("should reject peer not in allowlist")
|
|
}
|
|
if !core.Contains(err.Error(), "rejected") {
|
|
t.Errorf("expected rejection error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestTransport_EncryptedMessageRoundTrip_Ugly(t *testing.T) {
|
|
tp := setupTestTransportPair(t)
|
|
|
|
received := make(chan *Message, 1)
|
|
tp.Server.OnMessage(func(conn *PeerConnection, msg *Message) {
|
|
received <- msg
|
|
})
|
|
|
|
pc := tp.connectClient(t)
|
|
|
|
// Send an encrypted message from client to server
|
|
clientID := tp.ClientNode.GetIdentity().ID
|
|
serverID := tp.ServerNode.GetIdentity().ID
|
|
sentMsg, _ := NewMessage(MessagePing, clientID, serverID, PingPayload{
|
|
SentAt: time.Now().UnixMilli(),
|
|
})
|
|
|
|
if err := pc.Send(sentMsg); err != nil {
|
|
t.Fatalf("send: %v", err)
|
|
}
|
|
|
|
select {
|
|
case msg := <-received:
|
|
if msg.Type != MessagePing {
|
|
t.Errorf("type: got %s, want %s", msg.Type, MessagePing)
|
|
}
|
|
if msg.ID != sentMsg.ID {
|
|
t.Error("message ID mismatch after encrypt/decrypt round-trip")
|
|
}
|
|
if msg.From != clientID {
|
|
t.Errorf("from: got %s, want %s", msg.From, clientID)
|
|
}
|
|
|
|
var payload PingPayload
|
|
msg.ParsePayload(&payload)
|
|
if payload.SentAt == 0 {
|
|
t.Error("payload should have SentAt timestamp")
|
|
}
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("timeout waiting for message")
|
|
}
|
|
}
|
|
|
|
func TestTransport_MessageDedup_Good(t *testing.T) {
|
|
tp := setupTestTransportPair(t)
|
|
|
|
var count atomic.Int32
|
|
tp.Server.OnMessage(func(conn *PeerConnection, msg *Message) {
|
|
count.Add(1)
|
|
})
|
|
|
|
pc := tp.connectClient(t)
|
|
|
|
clientID := tp.ClientNode.GetIdentity().ID
|
|
serverID := tp.ServerNode.GetIdentity().ID
|
|
msg, _ := NewMessage(MessagePing, clientID, serverID, PingPayload{SentAt: time.Now().UnixMilli()})
|
|
|
|
// Send the same message twice
|
|
if err := pc.Send(msg); err != nil {
|
|
t.Fatalf("first send: %v", err)
|
|
}
|
|
time.Sleep(100 * time.Millisecond) // Ensure first is processed and marked
|
|
|
|
if err := pc.Send(msg); err != nil {
|
|
t.Fatalf("second send: %v", err)
|
|
}
|
|
time.Sleep(100 * time.Millisecond) // Allow time for second to be processed (or dropped)
|
|
|
|
if got := count.Load(); got != 1 {
|
|
t.Errorf("expected 1 message delivered (dedup), got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestTransport_RateLimiting_Good(t *testing.T) {
|
|
tp := setupTestTransportPair(t)
|
|
|
|
var count atomic.Int32
|
|
tp.Server.OnMessage(func(conn *PeerConnection, msg *Message) {
|
|
count.Add(1)
|
|
})
|
|
|
|
pc := tp.connectClient(t)
|
|
|
|
clientID := tp.ClientNode.GetIdentity().ID
|
|
serverID := tp.ServerNode.GetIdentity().ID
|
|
|
|
// Send 150 messages rapidly (rate limiter burst = 100)
|
|
for range 150 {
|
|
msg, _ := NewMessage(MessagePing, clientID, serverID, PingPayload{SentAt: time.Now().UnixMilli()})
|
|
pc.Send(msg)
|
|
}
|
|
|
|
time.Sleep(1 * time.Second) // Allow processing
|
|
|
|
received := int(count.Load())
|
|
t.Logf("rate limiting: %d/150 messages delivered", received)
|
|
|
|
if received >= 150 {
|
|
t.Error("rate limiting should have dropped some messages")
|
|
}
|
|
if received < 50 {
|
|
t.Errorf("too few messages received (%d), rate limiter may be too aggressive", received)
|
|
}
|
|
}
|
|
|
|
func TestTransport_MaxConnectionsEnforcement_Good(t *testing.T) {
|
|
// Server with MaxConnections=1
|
|
serverNM := newTestNodeManager(t, "maxconns-server", RoleWorker)
|
|
serverReg := newTestPeerRegistry(t)
|
|
|
|
serverCfg := DefaultTransportConfig()
|
|
serverCfg.MaxConnections = 1
|
|
serverTransport := NewTransport(serverNM, serverReg, serverCfg)
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc(serverCfg.WebSocketPath, serverTransport.handleWebSocketUpgrade)
|
|
ts := httptest.NewServer(mux)
|
|
t.Cleanup(func() {
|
|
serverTransport.Stop()
|
|
ts.Close()
|
|
})
|
|
|
|
u, _ := url.Parse(ts.URL)
|
|
serverAddr := u.Host
|
|
|
|
// First client connects successfully
|
|
client1NM := newTestNodeManager(t, "client1", RoleController)
|
|
client1Reg := newTestPeerRegistry(t)
|
|
client1Transport := NewTransport(client1NM, client1Reg, DefaultTransportConfig())
|
|
t.Cleanup(func() { client1Transport.Stop() })
|
|
|
|
peer1 := &Peer{ID: serverNM.GetIdentity().ID, Name: "server", Address: serverAddr, Role: RoleWorker}
|
|
client1Reg.AddPeer(peer1)
|
|
|
|
_, err := client1Transport.Connect(peer1)
|
|
if err != nil {
|
|
t.Fatalf("first connection should succeed: %v", err)
|
|
}
|
|
|
|
// Allow server to register the connection
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Second client should be rejected (MaxConnections=1 reached)
|
|
client2NM := newTestNodeManager(t, "client2", RoleController)
|
|
client2Reg := newTestPeerRegistry(t)
|
|
client2Transport := NewTransport(client2NM, client2Reg, DefaultTransportConfig())
|
|
t.Cleanup(func() { client2Transport.Stop() })
|
|
|
|
peer2 := &Peer{ID: serverNM.GetIdentity().ID, Name: "server", Address: serverAddr, Role: RoleWorker}
|
|
client2Reg.AddPeer(peer2)
|
|
|
|
_, err = client2Transport.Connect(peer2)
|
|
if err == nil {
|
|
t.Fatal("second connection should be rejected when MaxConnections=1")
|
|
}
|
|
}
|
|
|
|
func TestTransport_KeepaliveTimeout_Bad(t *testing.T) {
|
|
// Use short keepalive settings so the test is fast
|
|
serverCfg := DefaultTransportConfig()
|
|
serverCfg.PingInterval = 100 * time.Millisecond
|
|
serverCfg.PongTimeout = 100 * time.Millisecond
|
|
|
|
clientCfg := DefaultTransportConfig()
|
|
clientCfg.PingInterval = 100 * time.Millisecond
|
|
clientCfg.PongTimeout = 100 * time.Millisecond
|
|
|
|
tp := setupTestTransportPairWithConfig(t, serverCfg, clientCfg)
|
|
tp.connectClient(t)
|
|
|
|
// Verify connection is established
|
|
time.Sleep(50 * time.Millisecond)
|
|
if tp.Server.ConnectedPeerCount() != 1 {
|
|
t.Fatalf("server should have 1 peer initially, got %d", tp.Server.ConnectedPeerCount())
|
|
}
|
|
|
|
// Close the underlying WebSocket on the client side to simulate network failure.
|
|
// The server's readLoop will detect the broken connection and clean up.
|
|
clientID := tp.ClientNode.GetIdentity().ID
|
|
serverPeerID := tp.ServerNode.GetIdentity().ID
|
|
clientConn := tp.Client.GetConnection(serverPeerID)
|
|
if clientConn == nil {
|
|
t.Fatal("client should have connection to server")
|
|
}
|
|
clientConn.WebSocketConnection.Close()
|
|
|
|
// Wait for server to detect and clean up
|
|
deadline := time.After(2 * time.Second)
|
|
for {
|
|
select {
|
|
case <-deadline:
|
|
t.Fatalf("server did not clean up connection: still has %d peers", tp.Server.ConnectedPeerCount())
|
|
default:
|
|
if tp.Server.ConnectedPeerCount() == 0 {
|
|
// Verify registry updated
|
|
peer := tp.ServerReg.GetPeer(clientID)
|
|
if peer != nil && peer.Connected {
|
|
t.Error("registry should show peer as disconnected")
|
|
}
|
|
return
|
|
}
|
|
time.Sleep(50 * time.Millisecond)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestTransport_GracefulClose_Ugly(t *testing.T) {
|
|
tp := setupTestTransportPair(t)
|
|
|
|
received := make(chan *Message, 10)
|
|
tp.Server.OnMessage(func(conn *PeerConnection, msg *Message) {
|
|
received <- msg
|
|
})
|
|
|
|
pc := tp.connectClient(t)
|
|
|
|
// Allow connection to fully establish
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Graceful close should send a MessageDisconnect before closing
|
|
pc.GracefulClose("test shutdown", DisconnectNormal)
|
|
|
|
// Check if disconnect message was received
|
|
select {
|
|
case msg := <-received:
|
|
if msg.Type != MessageDisconnect {
|
|
t.Errorf("expected disconnect message, got %s", msg.Type)
|
|
}
|
|
var payload DisconnectPayload
|
|
msg.ParsePayload(&payload)
|
|
if payload.Reason != "test shutdown" {
|
|
t.Errorf("disconnect reason: got %q, want %q", payload.Reason, "test shutdown")
|
|
}
|
|
if payload.Code != DisconnectNormal {
|
|
t.Errorf("disconnect code: got %d, want %d", payload.Code, DisconnectNormal)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Error("timeout waiting for disconnect message")
|
|
}
|
|
}
|
|
|
|
func TestTransport_PeerConnectionClose_ReleasesState_Good(t *testing.T) {
|
|
tp := setupTestTransportPair(t)
|
|
|
|
pc := tp.connectClient(t)
|
|
clientID := tp.ClientNode.GetIdentity().ID
|
|
if tp.Client.GetConnection(tp.ServerNode.GetIdentity().ID) == nil {
|
|
t.Fatal("client should have an active connection before close")
|
|
}
|
|
|
|
if err := pc.Close(); err != nil {
|
|
t.Fatalf("close peer connection: %v", err)
|
|
}
|
|
|
|
deadline := time.After(2 * time.Second)
|
|
for {
|
|
select {
|
|
case <-deadline:
|
|
t.Fatal("connection state was not released after close")
|
|
default:
|
|
if tp.Client.GetConnection(tp.ServerNode.GetIdentity().ID) == nil && tp.Client.ConnectedPeerCount() == 0 {
|
|
peer := tp.ClientReg.GetPeer(clientID)
|
|
if peer != nil && peer.Connected {
|
|
t.Fatal("registry should show peer as disconnected after close")
|
|
}
|
|
return
|
|
}
|
|
time.Sleep(20 * time.Millisecond)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestTransport_ConcurrentSends_Ugly(t *testing.T) {
|
|
tp := setupTestTransportPair(t)
|
|
|
|
var count atomic.Int32
|
|
tp.Server.OnMessage(func(conn *PeerConnection, msg *Message) {
|
|
count.Add(1)
|
|
})
|
|
|
|
pc := tp.connectClient(t)
|
|
|
|
clientID := tp.ClientNode.GetIdentity().ID
|
|
serverID := tp.ServerNode.GetIdentity().ID
|
|
|
|
// Spawn 10 goroutines each sending 5 messages concurrently
|
|
const goroutines = 10
|
|
const msgsPerGoroutine = 5
|
|
var wg sync.WaitGroup
|
|
|
|
for range goroutines {
|
|
wg.Go(func() {
|
|
for range msgsPerGoroutine {
|
|
msg, _ := NewMessage(MessagePing, clientID, serverID, PingPayload{SentAt: time.Now().UnixMilli()})
|
|
pc.Send(msg)
|
|
}
|
|
})
|
|
}
|
|
|
|
wg.Wait()
|
|
time.Sleep(1 * time.Second) // Allow delivery
|
|
|
|
got := int(count.Load())
|
|
// All messages should be delivered (unique IDs, within rate limit burst of 100)
|
|
expected := goroutines * msgsPerGoroutine
|
|
if got != expected {
|
|
t.Errorf("concurrent sends: got %d/%d messages delivered", got, expected)
|
|
}
|
|
}
|
|
|
|
// --- Additional coverage tests ---
|
|
|
|
func TestTransport_Broadcast_Good(t *testing.T) {
|
|
// Set up a controller with two worker peers connected.
|
|
controllerNM := newTestNodeManager(t, "broadcast-controller", RoleController)
|
|
controllerReg := newTestPeerRegistry(t)
|
|
controllerTransport := NewTransport(controllerNM, controllerReg, DefaultTransportConfig())
|
|
t.Cleanup(func() { controllerTransport.Stop() })
|
|
|
|
const numWorkers = 2
|
|
var receiveCounters [numWorkers]*atomic.Int32
|
|
|
|
for i := range numWorkers {
|
|
receiveCounters[i] = &atomic.Int32{}
|
|
counter := receiveCounters[i]
|
|
|
|
nm, addr, srv := makeWorkerServer(t)
|
|
srv.OnMessage(func(conn *PeerConnection, msg *Message) {
|
|
counter.Add(1)
|
|
})
|
|
|
|
wID := nm.GetIdentity().ID
|
|
peer := &Peer{
|
|
ID: wID,
|
|
Name: "worker",
|
|
Address: addr,
|
|
Role: RoleWorker,
|
|
}
|
|
controllerReg.AddPeer(peer)
|
|
|
|
_, err := controllerTransport.Connect(peer)
|
|
if err != nil {
|
|
t.Fatalf("failed to connect to worker %d: %v", i, err)
|
|
}
|
|
}
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// Broadcast a message from the controller
|
|
controllerID := controllerNM.GetIdentity().ID
|
|
msg, _ := NewMessage(MessagePing, controllerID, "", PingPayload{
|
|
SentAt: time.Now().UnixMilli(),
|
|
})
|
|
|
|
err := controllerTransport.Broadcast(msg)
|
|
if err != nil {
|
|
t.Fatalf("Broadcast failed: %v", err)
|
|
}
|
|
|
|
time.Sleep(500 * time.Millisecond)
|
|
|
|
// Both workers should have received the broadcast
|
|
for i, counter := range receiveCounters {
|
|
if counter.Load() != 1 {
|
|
t.Errorf("worker %d received %d messages, expected 1", i, counter.Load())
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestTransport_BroadcastExcludesSender_Good(t *testing.T) {
|
|
// Verify that Broadcast excludes the sender.
|
|
tp := setupTestTransportPair(t)
|
|
|
|
serverReceived := &atomic.Int32{}
|
|
tp.Server.OnMessage(func(conn *PeerConnection, msg *Message) {
|
|
serverReceived.Add(1)
|
|
})
|
|
|
|
tp.connectClient(t)
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Broadcast from the server side with From = server ID.
|
|
// The server has a connection to the client, but msg.From matches the client's
|
|
// connection peer ID check, not the server's own ID. Let's verify sender exclusion
|
|
// by broadcasting from the server with its own ID.
|
|
serverID := tp.ServerNode.GetIdentity().ID
|
|
msg, _ := NewMessage(MessagePing, serverID, "", PingPayload{SentAt: time.Now().UnixMilli()})
|
|
|
|
// This broadcasts from server to all connected peers (the client).
|
|
// The server itself won't receive it back because it's not connected to itself.
|
|
err := tp.Server.Broadcast(msg)
|
|
if err != nil {
|
|
t.Fatalf("Broadcast failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestTransport_NewTransport_DefaultMaxMessageSize_Good(t *testing.T) {
|
|
nm := newTestNodeManager(t, "defaults", RoleWorker)
|
|
reg := newTestPeerRegistry(t)
|
|
cfg := TransportConfig{
|
|
MaxMessageSize: 0, // should use default
|
|
}
|
|
tr := NewTransport(nm, reg, cfg)
|
|
|
|
if tr == nil {
|
|
t.Fatal("NewTransport returned nil")
|
|
}
|
|
if tr.config.MaxMessageSize != 0 {
|
|
t.Errorf("config should preserve 0 value, got %d", tr.config.MaxMessageSize)
|
|
}
|
|
// The actual default is applied at usage time (readLoop, handleWebSocketUpgrade)
|
|
}
|
|
|
|
func TestTransport_ConnectedPeerCount_Good(t *testing.T) {
|
|
tp := setupTestTransportPair(t)
|
|
|
|
if tp.Server.ConnectedPeerCount() != 0 {
|
|
t.Errorf("expected 0 connected peers initially, got %d", tp.Server.ConnectedPeerCount())
|
|
}
|
|
|
|
tp.connectClient(t)
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if tp.Server.ConnectedPeerCount() != 1 {
|
|
t.Errorf("expected 1 connected peer after connect, got %d", tp.Server.ConnectedPeerCount())
|
|
}
|
|
}
|
|
|
|
func TestTransport_StartAndStop_Good(t *testing.T) {
|
|
nm := newTestNodeManager(t, "start-test", RoleWorker)
|
|
reg := newTestPeerRegistry(t)
|
|
cfg := DefaultTransportConfig()
|
|
cfg.ListenAddress = ":0" // Let OS pick a free port
|
|
|
|
tr := NewTransport(nm, reg, cfg)
|
|
|
|
err := tr.Start()
|
|
if err != nil {
|
|
t.Fatalf("Start failed: %v", err)
|
|
}
|
|
|
|
// Small wait for server goroutine to start
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
err = tr.Stop()
|
|
if err != nil {
|
|
t.Fatalf("Stop failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestTransport_CheckOrigin_Good(t *testing.T) {
|
|
nm := newTestNodeManager(t, "origin-test", RoleWorker)
|
|
reg := newTestPeerRegistry(t)
|
|
cfg := DefaultTransportConfig()
|
|
tr := NewTransport(nm, reg, cfg)
|
|
|
|
tests := []struct {
|
|
name string
|
|
origin string
|
|
allowed bool
|
|
}{
|
|
{"no origin", "", true},
|
|
{"localhost", "http://localhost:8080", true},
|
|
{"127.0.0.1", "http://127.0.0.1:8080", true},
|
|
{"ipv6 loopback", "http://[::1]:8080", true},
|
|
{"remote host", "http://evil.example.com", false},
|
|
{"invalid origin", "://not-a-url", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
r := &http.Request{Header: http.Header{}}
|
|
if tt.origin != "" {
|
|
r.Header.Set("Origin", tt.origin)
|
|
}
|
|
result := tr.upgrader.CheckOrigin(r)
|
|
if result != tt.allowed {
|
|
t.Errorf("CheckOrigin(%q) = %v, want %v", tt.origin, result, tt.allowed)
|
|
}
|
|
})
|
|
}
|
|
}
|