414 lines
13 KiB
Go
414 lines
13 KiB
Go
package display
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"forge.lthn.ai/core/gui/pkg/window"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type testWindowListener struct {
|
|
handler func(window.WindowEvent)
|
|
}
|
|
|
|
func (w *testWindowListener) Name() string { return "listener" }
|
|
func (w *testWindowListener) Title() string { return "listener" }
|
|
func (w *testWindowListener) Position() (int, int) { return 0, 0 }
|
|
func (w *testWindowListener) Size() (int, int) { return 0, 0 }
|
|
func (w *testWindowListener) IsMaximised() bool { return false }
|
|
func (w *testWindowListener) IsFocused() bool { return false }
|
|
func (w *testWindowListener) IsVisible() bool { return false }
|
|
func (w *testWindowListener) IsFullscreen() bool { return false }
|
|
func (w *testWindowListener) IsMinimised() bool { return false }
|
|
func (w *testWindowListener) GetBounds() (int, int, int, int) { return 0, 0, 0, 0 }
|
|
func (w *testWindowListener) GetZoom() float64 { return 1 }
|
|
func (w *testWindowListener) SetTitle(string) {}
|
|
func (w *testWindowListener) SetPosition(int, int) {}
|
|
func (w *testWindowListener) SetSize(int, int) {}
|
|
func (w *testWindowListener) SetBackgroundColour(uint8, uint8, uint8, uint8) {
|
|
}
|
|
func (w *testWindowListener) SetVisibility(bool) {}
|
|
func (w *testWindowListener) SetAlwaysOnTop(bool) {}
|
|
func (w *testWindowListener) SetBounds(int, int, int, int) {}
|
|
func (w *testWindowListener) SetURL(string) {}
|
|
func (w *testWindowListener) SetHTML(string) {}
|
|
func (w *testWindowListener) SetZoom(float64) {}
|
|
func (w *testWindowListener) SetContentProtection(bool) {}
|
|
func (w *testWindowListener) Maximise() {}
|
|
func (w *testWindowListener) Restore() {}
|
|
func (w *testWindowListener) Minimise() {}
|
|
func (w *testWindowListener) Focus() {}
|
|
func (w *testWindowListener) Close() {}
|
|
func (w *testWindowListener) Show() {}
|
|
func (w *testWindowListener) Hide() {}
|
|
func (w *testWindowListener) Fullscreen() {}
|
|
func (w *testWindowListener) UnFullscreen() {}
|
|
func (w *testWindowListener) ToggleFullscreen() {}
|
|
func (w *testWindowListener) ToggleMaximise() {}
|
|
func (w *testWindowListener) ExecJS(string) {}
|
|
func (w *testWindowListener) Flash(bool) {}
|
|
func (w *testWindowListener) Print() error { return nil }
|
|
func (w *testWindowListener) OnWindowEvent(handler func(window.WindowEvent)) {
|
|
w.handler = handler
|
|
}
|
|
func (w *testWindowListener) OnFileDrop(func([]string, string)) {}
|
|
|
|
func (w *testWindowListener) emit(event window.WindowEvent) {
|
|
if w.handler != nil {
|
|
w.handler(event)
|
|
}
|
|
}
|
|
|
|
// dialWSEventManager serves em.HandleWebSocket from a throwaway httptest
// server and opens a websocket client connection to it. The server is shut
// down through t.Cleanup. The returned func closes the client connection
// and then em.
func dialWSEventManager(t *testing.T, em *WSEventManager) (*websocket.Conn, func()) {
	t.Helper()

	srv := httptest.NewServer(http.HandlerFunc(em.HandleWebSocket))
	t.Cleanup(srv.Close)

	// httptest serves plain http://, so swapping the leading scheme gives ws://.
	endpoint := strings.Replace(srv.URL, "http", "ws", 1)
	conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
	require.NoError(t, err)

	closeAll := func() {
		// Best-effort: the test body may already have closed conn itself.
		_ = conn.Close()
		em.Close()
	}
	return conn, closeAll
}
|
|
|
|
// readJSONMessage reads the next websocket message from conn and decodes it
// as a JSON object, failing the test on a read or decode error.
//
// The read is bounded by a deadline so a server that never replies fails the
// test promptly instead of hanging until the global `go test` timeout. The
// deadline is cleared again before returning, so callers that later read conn
// directly keep the connection's original (unbounded) behaviour.
func readJSONMessage(t *testing.T, conn *websocket.Conn) map[string]any {
	t.Helper()

	require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
	// Best-effort reset; after a failed read the connection is unusable anyway.
	defer func() { _ = conn.SetReadDeadline(time.Time{}) }()

	_, data, err := conn.ReadMessage()
	require.NoError(t, err)

	var payload map[string]any
	// Include the raw frame so a malformed reply is diagnosable from test output.
	require.NoError(t, json.Unmarshal(data, &payload), "raw message: %s", data)
	return payload
}
|
|
|
|
// writeJSONMessage JSON-encodes payload and sends it over conn, failing the
// test immediately if the write returns an error.
func writeJSONMessage(t *testing.T, conn *websocket.Conn, payload map[string]any) {
	t.Helper()
	err := conn.WriteJSON(payload)
	require.NoError(t, err)
}
|
|
|
|
// TestWSEventManager_HandleWebSocket_Good walks one client through the full
// protocol: subscribe, receive an emitted event, list, unsubscribe, list
// again, then disconnect.
func TestWSEventManager_HandleWebSocket_Good(t *testing.T) {
	em := NewWSEventManager()
	conn, cleanup := dialWSEventManager(t, em)
	defer cleanup()

	// Subscribe to every event type via the "*" wildcard.
	writeJSONMessage(t, conn, map[string]any{
		"action":     "subscribe",
		"id":         "sub-1",
		"eventTypes": []string{"*"},
	})
	subscribeAck := readJSONMessage(t, conn)
	assert.Equal(t, "subscribed", subscribeAck["type"])
	assert.Equal(t, "sub-1", subscribeAck["id"])

	info := em.Info()
	assert.Equal(t, 1, info.ConnectedClients)
	assert.Equal(t, 1, info.SubscriptionCount)

	// An emitted event should be forwarded to the wildcard subscriber.
	em.Emit(Event{Type: EventCustomEvent, Window: "main", Data: map[string]any{"hello": "world"}})
	eventPayload := readJSONMessage(t, conn)
	assert.Equal(t, string(EventCustomEvent), eventPayload["type"])
	assert.Equal(t, "main", eventPayload["window"])

	// Use comma-ok assertions below: a bare .([]any) would panic (aborting the
	// whole test binary) if "subscriptions" were missing or JSON null.
	writeJSONMessage(t, conn, map[string]any{"action": "list"})
	listPayload := readJSONMessage(t, conn)
	assert.Equal(t, "subscriptions", listPayload["type"])
	subs, ok := listPayload["subscriptions"].([]any)
	require.True(t, ok, "subscriptions: want JSON array, got %T", listPayload["subscriptions"])
	assert.Len(t, subs, 1)

	writeJSONMessage(t, conn, map[string]any{"action": "unsubscribe", "id": "sub-1"})
	unsubscribeAck := readJSONMessage(t, conn)
	assert.Equal(t, "unsubscribed", unsubscribeAck["type"])

	writeJSONMessage(t, conn, map[string]any{"action": "list"})
	emptyList := readJSONMessage(t, conn)
	assert.Equal(t, "subscriptions", emptyList["type"])
	remaining, ok := emptyList["subscriptions"].([]any)
	require.True(t, ok, "subscriptions: want JSON array, got %T", emptyList["subscriptions"])
	assert.Len(t, remaining, 0)

	// Closing the client side should eventually deregister it on the server.
	require.NoError(t, conn.Close())
	require.Eventually(t, func() bool {
		return em.ConnectedClients() == 0
	}, 2*time.Second, 20*time.Millisecond)
}
|
|
|
|
// TestWSEventManager_HandleWebSocket_RejectsRemoteOrigin checks that an
// untrusted Origin alone is enough for the handler to answer 403.
func TestWSEventManager_HandleWebSocket_RejectsRemoteOrigin(t *testing.T) {
	em := NewWSEventManager()
	defer em.Close()

	// httptest.NewRequest defaults RemoteAddr to 192.0.2.1, which would be
	// refused as a remote client regardless of Origin. Use a loopback client
	// and a real upgrade handshake so the untrusted Origin header is the only
	// remaining reason to reject the request.
	req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/events", nil)
	req.RemoteAddr = "127.0.0.1:12345"
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "websocket")
	req.Header.Set("Sec-WebSocket-Version", "13")
	req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
	req.Header.Set("Origin", "https://evil.example")
	recorder := httptest.NewRecorder()

	em.HandleWebSocket(recorder, req)

	assert.Equal(t, http.StatusForbidden, recorder.Code)
}
|
|
|
|
// TestWSEventManager_HandleWebSocket_RejectsLoopbackSpoofedOrigin sends a
// request that targets a loopback host with a file:// Origin but arrives from
// a non-loopback client address; the handler must answer 403.
func TestWSEventManager_HandleWebSocket_RejectsLoopbackSpoofedOrigin(t *testing.T) {
	em := NewWSEventManager()
	rec := httptest.NewRecorder()

	spoofed := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/events", nil)
	spoofed.Header.Set("Origin", "file://malicious")
	spoofed.RemoteAddr = "203.0.113.10:12345"

	em.HandleWebSocket(rec, spoofed)

	assert.Equal(t, http.StatusForbidden, rec.Code)
}
|
|
|
|
// TestEvents_trustedWebSocketOrigin_Good lists /events requests from loopback
// clients that must be accepted.
func TestEvents_trustedWebSocketOrigin_Good(t *testing.T) {
	// eventsRequest builds a GET /events request from clientAddr, adding an
	// Origin header only when origin is non-empty.
	eventsRequest := func(clientAddr, origin string) *http.Request {
		r := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/events", nil)
		r.RemoteAddr = clientAddr
		if origin != "" {
			r.Header.Set("Origin", origin)
		}
		return r
	}

	cases := []struct {
		name string
		req  *http.Request
		want bool
	}{
		{name: "localhost without origin", req: eventsRequest("127.0.0.1:12345", ""), want: true},
		{name: "local origin", req: eventsRequest("127.0.0.1:12345", "http://localhost:8080"), want: true},
		{name: "file origin", req: eventsRequest("[::1]:12345", "file://local"), want: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := trustedWebSocketOrigin(tc.req)
			assert.Equal(t, tc.want, got)
		})
	}
}
|
|
|
|
// TestEvents_trustedWebSocketOrigin_Bad lists requests that must be rejected:
// no request at all, the wrong path, a remote client, and a remote Origin.
func TestEvents_trustedWebSocketOrigin_Bad(t *testing.T) {
	// newRequest builds a GET request for target from clientAddr, adding an
	// Origin header only when origin is non-empty.
	newRequest := func(target, clientAddr, origin string) *http.Request {
		r := httptest.NewRequest(http.MethodGet, target, nil)
		r.RemoteAddr = clientAddr
		if origin != "" {
			r.Header.Set("Origin", origin)
		}
		return r
	}

	cases := []struct {
		name string
		req  *http.Request
	}{
		{name: "nil request", req: nil},
		{name: "wrong path", req: newRequest("http://127.0.0.1/other", "127.0.0.1:12345", "")},
		{name: "remote client", req: newRequest("http://127.0.0.1/events", "203.0.113.10:2222", "")},
		{name: "remote origin", req: newRequest("http://127.0.0.1/events", "127.0.0.1:12345", "https://evil.example")},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.False(t, trustedWebSocketOrigin(tc.req), "request should be rejected")
		})
	}
}
|
|
|
|
func TestEvents_trustedWebSocketOrigin_Ugly(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/events", nil)
|
|
req.RemoteAddr = "127.0.0.1:12345"
|
|
req.Header.Set("Origin", "://bad")
|
|
|
|
assert.False(t, trustedWebSocketOrigin(req))
|
|
assert.False(t, trustedWebSocketOrigin(&http.Request{}))
|
|
}
|
|
|
|
// TestEvents_trustedWebSocketHost_Good accepts loopback hosts, both bare and
// with a port (including a bracketed IPv6 literal).
func TestEvents_trustedWebSocketHost_Good(t *testing.T) {
	for _, host := range []string{"localhost", "127.0.0.1:443", "[::1]:80"} {
		assert.True(t, trustedWebSocketHost(host), "host %q", host)
	}
}
|
|
|
|
// TestEvents_trustedWebSocketHost_Bad rejects an empty host and a public name.
func TestEvents_trustedWebSocketHost_Bad(t *testing.T) {
	for _, host := range []string{"", "example.com"} {
		assert.False(t, trustedWebSocketHost(host), "host %q", host)
	}
}
|
|
|
|
func TestEvents_trustedWebSocketHost_Ugly(t *testing.T) {
|
|
assert.False(t, trustedWebSocketHost("not a host"))
|
|
}
|
|
|
|
// TestEvents_isLoopbackHost_Good recognises the loopback name and the IPv4
// and IPv6 loopback literals.
func TestEvents_isLoopbackHost_Good(t *testing.T) {
	for _, host := range []string{"localhost", "127.0.0.1", "::1"} {
		assert.True(t, isLoopbackHost(host), "host %q", host)
	}
}
|
|
|
|
// TestEvents_isLoopbackHost_Bad rejects an empty host and a public name.
func TestEvents_isLoopbackHost_Bad(t *testing.T) {
	for _, host := range []string{"", "example.com"} {
		assert.False(t, isLoopbackHost(host), "host %q", host)
	}
}
|
|
|
|
// TestEvents_isLoopbackHost_Ugly rejects a non-loopback IPv4 literal (taken
// from the RFC 5737 documentation range).
func TestEvents_isLoopbackHost_Ugly(t *testing.T) {
	remote := "203.0.113.10"
	assert.False(t, isLoopbackHost(remote))
}
|
|
|
|
// TestWSEventManager_HandleWebSocket_ClosesOnMalformedMessage sends truncated
// JSON and expects an error payload followed by the server closing the socket.
func TestWSEventManager_HandleWebSocket_ClosesOnMalformedMessage(t *testing.T) {
	em := NewWSEventManager()
	conn, cleanup := dialWSEventManager(t, em)
	defer cleanup()

	require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"action":`)))

	payload := readJSONMessage(t, conn)
	assert.Equal(t, "invalid websocket message", payload["error"])
	// JSON numbers decode into float64 in a map[string]any.
	assert.Equal(t, float64(websocket.ClosePolicyViolation), payload["status"])

	// The server should now drop the connection. Bound the wait so a server
	// that keeps the socket open fails promptly instead of hanging the test,
	// and do not accept our own deadline expiring as proof of closure.
	require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
	_, _, err := conn.ReadMessage()
	require.Error(t, err)
	timeoutErr, isTimeout := err.(interface{ Timeout() bool })
	assert.False(t, isTimeout && timeoutErr.Timeout(), "server did not close the connection: %v", err)
}
|
|
|
|
// TestWSEventManager_HandleWebSocket_ClosesOnUnknownAction sends a
// well-formed message with an unrecognised action and expects an error
// payload followed by the server closing the socket.
func TestWSEventManager_HandleWebSocket_ClosesOnUnknownAction(t *testing.T) {
	em := NewWSEventManager()
	conn, cleanup := dialWSEventManager(t, em)
	defer cleanup()

	require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"action":"bogus"}`)))

	payload := readJSONMessage(t, conn)
	assert.Equal(t, "unknown websocket action", payload["error"])
	// JSON numbers decode into float64 in a map[string]any.
	assert.Equal(t, float64(websocket.ClosePolicyViolation), payload["status"])

	// The server should now drop the connection. Bound the wait so a server
	// that keeps the socket open fails promptly instead of hanging the test,
	// and do not accept our own deadline expiring as proof of closure.
	require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
	_, _, err := conn.ReadMessage()
	require.Error(t, err)
	timeoutErr, isTimeout := err.(interface{ Timeout() bool })
	assert.False(t, isTimeout && timeoutErr.Timeout(), "server did not close the connection: %v", err)
}
|
|
|
|
// TestWSEventManager_Emit_Ugly emits twice into a one-slot buffer that
// nothing drains. Emit must return rather than block, and the buffer must be
// left holding exactly one event.
func TestWSEventManager_Emit_Ugly(t *testing.T) {
	em := &WSEventManager{
		clients:     map[*websocket.Conn]*clientState{},
		eventBuffer: make(chan Event, 1),
	}

	for i := 0; i < 2; i++ {
		em.Emit(Event{Type: EventTrayClick})
	}

	assert.Len(t, em.eventBuffer, 1)
	snapshot := em.Info()
	assert.Equal(t, 0, snapshot.ConnectedClients)
	assert.Equal(t, 1, snapshot.BufferLength)
}
|
|
|
|
// TestWSEventManager_EmitWindowEvent_Good checks that EmitWindowEvent queues
// a single Event carrying the given type, window name and data.
func TestWSEventManager_EmitWindowEvent_Good(t *testing.T) {
	em := &WSEventManager{
		clients:     map[*websocket.Conn]*clientState{},
		eventBuffer: make(chan Event, 2),
	}
	coords := map[string]any{"x": 10, "y": 20}

	em.EmitWindowEvent(EventWindowMove, "editor", coords)
	require.Len(t, em.eventBuffer, 1)

	got := <-em.eventBuffer
	assert.Equal(t, EventWindowMove, got.Type)
	assert.Equal(t, "editor", got.Window)
	assert.Equal(t, 10, got.Data["x"])
}
|
|
|
|
// TestWSEventManager_ClientSubscribed_Good covers an explicitly listed event
// type (EventWindowFocus in sub-1) and one that can only match via the "*"
// wildcard in sub-2 (EventCustomEvent).
func TestWSEventManager_ClientSubscribed_Good(t *testing.T) {
	em := &WSEventManager{}
	state := &clientState{
		subscriptions: map[string]*Subscription{
			"sub-1": {ID: "sub-1", EventTypes: []EventType{EventWindowFocus}},
			"sub-2": {ID: "sub-2", EventTypes: []EventType{"*"}},
		},
	}

	for _, eventType := range []EventType{EventWindowFocus, EventCustomEvent} {
		assert.True(t, em.clientSubscribed(state, eventType), "event type %s", eventType)
	}
}
|
|
|
|
// TestWSEventManager_ClientSubscribed_Bad checks that a client with no
// subscriptions matches no event type.
func TestWSEventManager_ClientSubscribed_Bad(t *testing.T) {
	manager := &WSEventManager{}
	noSubs := &clientState{subscriptions: map[string]*Subscription{}}

	subscribed := manager.clientSubscribed(noSubs, EventWindowClose)
	assert.False(t, subscribed)
}
|
|
|
|
// TestWSEventManager_ConnectedClients_Good registers one client entry and
// expects a count of one. The entry is keyed by a nil *websocket.Conn because
// only the number of entries matters here.
func TestWSEventManager_ConnectedClients_Good(t *testing.T) {
	em := &WSEventManager{
		clients: map[*websocket.Conn]*clientState{
			nil: {subscriptions: map[string]*Subscription{}},
		},
	}

	assert.Equal(t, 1, em.ConnectedClients())
}
|
|
|
|
// TestWSEventManager_AttachWindowListeners_Good fires a "focus" window event
// and expects it to be queued as a "window.focus" Event carrying the window
// name and the original data.
func TestWSEventManager_AttachWindowListeners_Good(t *testing.T) {
	em := &WSEventManager{
		clients:     map[*websocket.Conn]*clientState{},
		eventBuffer: make(chan Event, 1),
	}
	fake := &testWindowListener{}
	em.AttachWindowListeners(fake)

	fake.emit(window.WindowEvent{
		Type: "focus",
		Name: "main",
		Data: map[string]any{"reason": "user"},
	})

	require.Len(t, em.eventBuffer, 1)
	got := <-em.eventBuffer
	assert.Equal(t, EventType("window.focus"), got.Type)
	assert.Equal(t, "main", got.Window)
	assert.Equal(t, "user", got.Data["reason"])
}
|
|
|
|
// TestWSEventManager_AttachWindowListeners_Bad attaches a nil window. The
// call must not panic and must leave the event buffer empty.
func TestWSEventManager_AttachWindowListeners_Bad(t *testing.T) {
	em := new(WSEventManager)
	em.AttachWindowListeners(nil)
	assert.Empty(t, em.eventBuffer)
}
|
|
|
|
// TestWSEventManager_CloseIsIdempotent checks that closing twice, and then
// emitting after shutdown, does not panic.
func TestWSEventManager_CloseIsIdempotent(t *testing.T) {
	em := NewWSEventManager()

	closeTwiceThenEmit := func() {
		for i := 0; i < 2; i++ {
			em.Close()
		}
		em.Emit(Event{Type: EventCustomEvent})
	}
	assert.NotPanics(t, closeTwiceThenEmit)
}
|