diff --git a/node/transport.go b/node/transport.go index 6004aed..08eee04 100644 --- a/node/transport.go +++ b/node/transport.go @@ -969,12 +969,7 @@ func (t *Transport) keepalive(pc *PeerConnection) { } func (t *Transport) removeConnection(pc *PeerConnection) { - t.mutex.Lock() - delete(t.connections, pc.Peer.ID) - t.mutex.Unlock() - - t.peerRegistry.SetConnected(pc.Peer.ID, false) - pc.Close() + _ = pc.Close() } // err := peerConnection.Send(message) @@ -1008,6 +1003,9 @@ func (pc *PeerConnection) sendLocked(msg *Message) error { func (pc *PeerConnection) Close() error { var err error pc.closeOnce.Do(func() { + if pc.transport != nil { + pc.transport.detachConnection(pc) + } connection := pc.webSocketConnection() if connection == nil { return @@ -1034,35 +1032,40 @@ const ( // err := peerConnection.GracefulClose("server shutdown", DisconnectShutdown) func (pc *PeerConnection) GracefulClose(reason string, code int) error { - var err error - pc.closeOnce.Do(func() { - pc.writeMutex.Lock() - defer pc.writeMutex.Unlock() - - connection := pc.webSocketConnection() - if connection == nil { - return - } - - // Try to send disconnect message (best effort). - if pc.transport != nil && pc.SharedSecret != nil { - identity := pc.transport.nodeManager.GetIdentity() - if identity != nil { - payload := DisconnectPayload{ - Reason: reason, - Code: code, - } - msg, msgErr := NewMessage(MessageDisconnect, identity.ID, pc.Peer.ID, payload) - if msgErr == nil { - _ = pc.sendLocked(msg) - } + pc.writeMutex.Lock() + connection := pc.webSocketConnection() + if connection != nil && pc.transport != nil && pc.SharedSecret != nil { + identity := pc.transport.nodeManager.GetIdentity() + if identity != nil { + payload := DisconnectPayload{ + Reason: reason, + Code: code, + } + msg, msgErr := NewMessage(MessageDisconnect, identity.ID, pc.Peer.ID, payload) + if msgErr == nil { + _ = pc.sendLocked(msg) } } + } + pc.writeMutex.Unlock() - // Close the underlying connection - err = connection.Close() - }) - return err + return pc.Close() +} + +func (t *Transport) detachConnection(pc *PeerConnection) { + if pc == nil || pc.Peer == nil { + return + } + + t.mutex.Lock() + current, exists := t.connections[pc.Peer.ID] + if exists && current == pc { + delete(t.connections, pc.Peer.ID) + t.mutex.Unlock() + t.peerRegistry.SetConnected(pc.Peer.ID, false) + return + } + t.mutex.Unlock() } func (t *Transport) encryptMessage(msg *Message, sharedSecret []byte) ([]byte, error) { diff --git a/node/transport_test.go b/node/transport_test.go index 5d00ab7..378f8b8 100644 --- a/node/transport_test.go +++ b/node/transport_test.go @@ -619,6 +619,37 @@ func TestTransport_GracefulClose_Ugly(t *testing.T) { } } +func TestTransport_PeerConnectionClose_ReleasesState_Good(t *testing.T) { + tp := setupTestTransportPair(t) + + pc := tp.connectClient(t) + clientID := tp.ClientNode.GetIdentity().ID + if tp.Client.GetConnection(tp.ServerNode.GetIdentity().ID) == nil { + t.Fatal("client should have an active connection before close") + } + + if err := pc.Close(); err != nil { + t.Fatalf("close peer connection: %v", err) + } + + deadline := time.After(2 * time.Second) + for { + select { + case <-deadline: + t.Fatal("connection state was not released after close") + default: + if tp.Client.GetConnection(tp.ServerNode.GetIdentity().ID) == nil && tp.Client.ConnectedPeerCount() == 0 { + peer := tp.ClientReg.GetPeer(clientID) + if peer != nil && peer.Connected { + t.Fatal("registry should show peer as disconnected after close") + } + return + } + time.Sleep(20 * time.Millisecond) + } + } +} + func TestTransport_ConcurrentSends_Ugly(t *testing.T) { tp := setupTestTransportPair(t)