diff --git a/go.sum b/go.sum index 389ca89..6e9b7ec 100644 --- a/go.sum +++ b/go.sum @@ -4,25 +4,40 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs= +github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ws.go b/ws.go index 3d8ff94..199b12c 100644 --- a/ws.go +++ b/ws.go @@ -62,7 +62,10 @@ import ( "context" "encoding/json" "fmt" + "iter" + "maps" "net/http" + "slices" "sync" "time" @@ -349,10 +352,7 @@ func (h *Hub) SendToChannel(channel string, msg Message) error { } // Copy client references under lock to avoid races during iteration - targets := make([]*Client, 0, len(clients)) - for client := range clients { - targets = append(targets, client) - } + targets := slices.Collect(maps.Keys(clients)) h.mu.RUnlock() for _, client := range targets { @@ -429,6 +429,20 @@ func (h *Hub) ChannelSubscriberCount(channel string) int { return 0 } +// AllClients returns an iterator for all connected clients. +func (h *Hub) AllClients() iter.Seq[*Client] { + h.mu.RLock() + defer h.mu.RUnlock() + return slices.Values(slices.Collect(maps.Keys(h.clients))) +} + +// AllChannels returns an iterator for all active channels. +func (h *Hub) AllChannels() iter.Seq[string] { + h.mu.RLock() + defer h.mu.RUnlock() + return slices.Values(slices.Collect(maps.Keys(h.channels))) +} + // HubStats contains hub statistics. type HubStats struct { Clients int `json:"clients"` @@ -586,11 +600,14 @@ func (c *Client) Subscriptions() []string { c.mu.RLock() defer c.mu.RUnlock() - result := make([]string, 0, len(c.subscriptions)) - for channel := range c.subscriptions { - result = append(result, channel) - } - return result + return slices.Collect(maps.Keys(c.subscriptions)) +} + +// AllSubscriptions returns an iterator for the client's current subscriptions. +func (c *Client) AllSubscriptions() iter.Seq[string] { + c.mu.RLock() + defer c.mu.RUnlock() + return slices.Values(slices.Collect(maps.Keys(c.subscriptions))) } // Close closes the client connection. @@ -796,7 +813,7 @@ func (rc *ReconnectingClient) setState(state ConnectionState) { func (rc *ReconnectingClient) calculateBackoff(attempt int) time.Duration { backoff := rc.config.InitialBackoff - for i := 1; i < attempt; i++ { + for range attempt - 1 { backoff = time.Duration(float64(backoff) * rc.config.BackoffMultiplier) if backoff > rc.config.MaxBackoff { backoff = rc.config.MaxBackoff diff --git a/ws_test.go b/ws_test.go index 88522fd..e803494 100644 --- a/ws_test.go +++ b/ws_test.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/http/httptest" + "slices" "strings" "sync" "testing" @@ -442,6 +443,52 @@ func TestClient_Subscriptions(t *testing.T) { }) } +func TestClient_AllSubscriptions(t *testing.T) { + t.Run("returns iterator over subscriptions", func(t *testing.T) { + client := &Client{subscriptions: make(map[string]bool)} + client.subscriptions["sub1"] = true + client.subscriptions["sub2"] = true + + subs := slices.Collect(client.AllSubscriptions()) + assert.Len(t, subs, 2) + assert.Contains(t, subs, "sub1") + assert.Contains(t, subs, "sub2") + }) +} + +func TestHub_AllClients(t *testing.T) { + t.Run("returns iterator over all clients", func(t *testing.T) { + hub := NewHub() + client1 := &Client{subscriptions: make(map[string]bool)} + client2 := &Client{subscriptions: make(map[string]bool)} + + hub.mu.Lock() + hub.clients[client1] = true + hub.clients[client2] = true + hub.mu.Unlock() + + clients := slices.Collect(hub.AllClients()) + assert.Len(t, clients, 2) + assert.Contains(t, clients, client1) + assert.Contains(t, clients, client2) + }) +} + +func TestHub_AllChannels(t *testing.T) { + t.Run("returns iterator over all active channels", func(t *testing.T) { + hub := NewHub() + hub.mu.Lock() + hub.channels["ch1"] = make(map[*Client]bool) + hub.channels["ch2"] = make(map[*Client]bool) + hub.mu.Unlock() + + channels := slices.Collect(hub.AllChannels()) + assert.Len(t, channels, 2) + assert.Contains(t, channels, "ch1") + assert.Contains(t, channels, "ch2") + }) +} + func TestMessage_JSON(t *testing.T) { t.Run("marshals correctly", func(t *testing.T) { msg := Message{ @@ -1390,7 +1437,7 @@ func BenchmarkBroadcast(b *testing.B) { msg := Message{Type: TypeEvent, Data: "benchmark"} b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _ = hub.Broadcast(msg) } } @@ -1416,7 +1463,7 @@ func BenchmarkSendToChannel(b *testing.B) { msg := Message{Type: TypeEvent, Data: "benchmark"} b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _ = hub.SendToChannel("bench-channel", msg) } }