From 69e1fbe7e3652b771237d39a61d3a3b767e2a56c Mon Sep 17 00:00:00 2001 From: Snider Date: Fri, 17 Apr 2026 19:52:26 +0100 Subject: [PATCH] Validate websocket layout numeric fields --- .core/TODO.md | 1 - pkg/display/display.go | 114 ++++++++++++++-- pkg/display/display_test.go | 264 ++++++++++++++++++++++++++++++++++++ 3 files changed, 370 insertions(+), 9 deletions(-) diff --git a/.core/TODO.md b/.core/TODO.md index 874ce2e8..e69de29b 100644 --- a/.core/TODO.md +++ b/.core/TODO.md @@ -1 +0,0 @@ -- @security pkg/display/display.go:954 — WebSocket layout commands still coerce missing or malformed numeric fields to zero instead of rejecting the request. diff --git a/pkg/display/display.go b/pkg/display/display.go index 10782301..cf3b5eb1 100644 --- a/pkg/display/display.go +++ b/pkg/display/display.go @@ -3,6 +3,7 @@ package display import ( "context" "errors" + "math" "net/url" "runtime" "sync" @@ -514,6 +515,85 @@ func wsRequire(data map[string]any, key string) (string, error) { return requireStringField(data, key) } +func requireFloatField(data map[string]any, key string) (float64, error) { + value, ok := data[key] + if !ok || value == nil { + return 0, coreerr.E("display.handleWSMessage", "missing required field \""+key+"\"", nil) + } + + switch v := value.(type) { + case float64: + return v, nil + case float32: + return float64(v), nil + case int: + return float64(v), nil + case int8: + return float64(v), nil + case int16: + return float64(v), nil + case int32: + return float64(v), nil + case int64: + return float64(v), nil + case uint: + return float64(v), nil + case uint8: + return float64(v), nil + case uint16: + return float64(v), nil + case uint32: + return float64(v), nil + case uint64: + return float64(v), nil + default: + return 0, coreerr.E("display.handleWSMessage", "invalid required field \""+key+"\"", nil) + } +} + +func requireIntField(data map[string]any, key string) (int, error) { + value, ok := data[key] + if !ok || value == nil { + return 0, coreerr.E("display.handleWSMessage", "missing required field \""+key+"\"", nil) + } + + switch v := value.(type) { + case int: + return v, nil + case int8: + return int(v), nil + case int16: + return int(v), nil + case int32: + return int(v), nil + case int64: + return int(v), nil + case uint: + return int(v), nil + case uint8: + return int(v), nil + case uint16: + return int(v), nil + case uint32: + return int(v), nil + case uint64: + return int(v), nil + case float64: + if math.Trunc(v) != v { + return 0, coreerr.E("display.handleWSMessage", "invalid required field \""+key+"\"", nil) + } + return int(v), nil + case float32: + f := float64(v) + if math.Trunc(f) != f { + return 0, coreerr.E("display.handleWSMessage", "invalid required field \""+key+"\"", nil) + } + return int(f), nil + default: + return 0, coreerr.E("display.handleWSMessage", "invalid required field \""+key+"\"", nil) + } +} + func optionsFromMap(data map[string]any) core.Options { items := make([]core.Option, 0, len(data)) for key, value := range data { @@ -926,7 +1006,10 @@ func (s *Service) handleWSMessage(msg WSMessage) core.Result { } editor, _ := msg.Data["editor"].(string) side, _ := msg.Data["side"].(string) - ratio, _ := msg.Data["ratio"].(float64) + ratio, e := requireFloatField(msg.Data, "ratio") + if e != nil { + return core.Result{Value: e, OK: false} + } return c.Action("window.layoutBesideEditor").Run(ctx, core.NewOptions( core.Option{Key: "task", Value: window.TaskLayoutBesideEditor{ Name: name, Editor: editor, Side: side, Ratio: ratio, @@ -934,20 +1017,32 @@ func (s *Service) handleWSMessage(msg WSMessage) core.Result { )) case "layout:suggest": screenID, _ := msg.Data["screen_id"].(string) - windowCount, _ := msg.Data["window_count"].(float64) + windowCount, e := requireIntField(msg.Data, "window_count") + if e != nil { + return core.Result{Value: e, OK: false} + } return c.Action("window.layoutSuggest").Run(ctx, core.NewOptions( core.Option{Key: "task", Value: window.TaskLayoutSuggest{ - ScreenID: screenID, WindowCount: int(windowCount), + ScreenID: screenID, WindowCount: windowCount, }}, )) case "screen:find-space": screenID, _ := msg.Data["screen_id"].(string) - width, _ := msg.Data["width"].(float64) - height, _ := msg.Data["height"].(float64) - padding, _ := msg.Data["padding"].(float64) + width, e := requireIntField(msg.Data, "width") + if e != nil { + return core.Result{Value: e, OK: false} + } + height, e := requireIntField(msg.Data, "height") + if e != nil { + return core.Result{Value: e, OK: false} + } + padding, e := requireIntField(msg.Data, "padding") + if e != nil { + return core.Result{Value: e, OK: false} + } return c.Action("window.findSpace").Run(ctx, core.NewOptions( core.Option{Key: "task", Value: window.TaskScreenFindSpace{ - ScreenID: screenID, Width: int(width), Height: int(height), Padding: int(padding), + ScreenID: screenID, Width: width, Height: height, Padding: padding, }}, )) case "window:arrange-pair": @@ -960,7 +1055,10 @@ func (s *Service) handleWSMessage(msg WSMessage) core.Result { return core.Result{Value: e, OK: false} } screenID, _ := msg.Data["screen_id"].(string) - ratio, _ := msg.Data["ratio"].(float64) + ratio, e := requireFloatField(msg.Data, "ratio") + if e != nil { + return core.Result{Value: e, OK: false} + } return c.Action("window.arrangePair").Run(ctx, core.NewOptions( core.Option{Key: "task", Value: window.TaskWindowArrangePair{ Primary: primary, Secondary: secondary, ScreenID: screenID, Ratio: ratio, diff --git a/pkg/display/display_test.go b/pkg/display/display_test.go index 5d703c61..9c0c1898 100644 --- a/pkg/display/display_test.go +++ b/pkg/display/display_test.go @@ -676,6 +676,270 @@ func TestDisplay_handleWSMessage_Ugly(t *testing.T) { assert.Contains(t, result.Value.(error).Error(), "missing required field \"opacity\"") } +func TestDisplay_handleWSMessage_LayoutCommands_Good(t *testing.T) { + cases := []struct { + name string + action string + msg WSMessage + check func(*testing.T, core.Options) + }{ + { + name: "LayoutBesideEditor", + action: "window.layoutBesideEditor", + msg: WSMessage{ + Action: "layout:beside-editor", + Data: map[string]any{ + "name": "preview", + "editor": "code", + "side": "right", + "ratio": 0.62, + }, + }, + check: func(t *testing.T, opts core.Options) { + t.Helper() + task := opts.Get("task").Value.(window.TaskLayoutBesideEditor) + assert.Equal(t, "preview", task.Name) + assert.Equal(t, "code", task.Editor) + assert.Equal(t, "right", task.Side) + assert.InDelta(t, 0.62, task.Ratio, 0.0001) + }, + }, + { + name: "LayoutSuggest", + action: "window.layoutSuggest", + msg: WSMessage{ + Action: "layout:suggest", + Data: map[string]any{ + "screen_id": "screen-1", + "window_count": 3, + }, + }, + check: func(t *testing.T, opts core.Options) { + t.Helper() + task := opts.Get("task").Value.(window.TaskLayoutSuggest) + assert.Equal(t, "screen-1", task.ScreenID) + assert.Equal(t, 3, task.WindowCount) + }, + }, + { + name: "FindScreenSpace", + action: "window.findSpace", + msg: WSMessage{ + Action: "screen:find-space", + Data: map[string]any{ + "screen_id": "screen-1", + "width": 800, + "height": 600, + "padding": 24, + }, + }, + check: func(t *testing.T, opts core.Options) { + t.Helper() + task := opts.Get("task").Value.(window.TaskScreenFindSpace) + assert.Equal(t, "screen-1", task.ScreenID) + assert.Equal(t, 800, task.Width) + assert.Equal(t, 600, task.Height) + assert.Equal(t, 24, task.Padding) + }, + }, + { + name: "ArrangeWindowPair", + action: "window.arrangePair", + msg: WSMessage{ + Action: "window:arrange-pair", + Data: map[string]any{ + "primary": "editor", + "secondary": "preview", + "screen_id": "screen-1", + "ratio": 0.55, + }, + }, + check: func(t *testing.T, opts core.Options) { + t.Helper() + task := opts.Get("task").Value.(window.TaskWindowArrangePair) + assert.Equal(t, "editor", task.Primary) + assert.Equal(t, "preview", task.Secondary) + assert.Equal(t, "screen-1", task.ScreenID) + assert.InDelta(t, 0.55, task.Ratio, 0.0001) + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + svc, c := newTestDisplayAPIService(t) + called := false + c.Action(tc.action, func(_ context.Context, opts core.Options) core.Result { + called = true + tc.check(t, opts) + return core.Result{OK: true} + }) + + result := svc.handleWSMessage(tc.msg) + require.True(t, result.OK) + assert.True(t, called) + }) + } +} + +func TestDisplay_handleWSMessage_LayoutCommands_Bad(t *testing.T) { + cases := []struct { + name string + action string + msg WSMessage + field string + }{ + { + name: "LayoutBesideEditor", + action: "window.layoutBesideEditor", + msg: WSMessage{ + Action: "layout:beside-editor", + Data: map[string]any{ + "name": "preview", + "editor": "code", + "side": "right", + }, + }, + field: "ratio", + }, + { + name: "LayoutSuggest", + action: "window.layoutSuggest", + msg: WSMessage{ + Action: "layout:suggest", + Data: map[string]any{ + "screen_id": "screen-1", + }, + }, + field: "window_count", + }, + { + name: "FindScreenSpace", + action: "window.findSpace", + msg: WSMessage{ + Action: "screen:find-space", + Data: map[string]any{ + "screen_id": "screen-1", + "width": 800, + "height": 600, + }, + }, + field: "padding", + }, + { + name: "ArrangeWindowPair", + action: "window.arrangePair", + msg: WSMessage{ + Action: "window:arrange-pair", + Data: map[string]any{ + "primary": "editor", + "secondary": "preview", + "screen_id": "screen-1", + }, + }, + field: "ratio", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + svc, c := newTestDisplayAPIService(t) + called := false + c.Action(tc.action, func(_ context.Context, _ core.Options) core.Result { + called = true + return core.Result{OK: true} + }) + + result := svc.handleWSMessage(tc.msg) + + require.False(t, result.OK) + assert.False(t, called) + assert.Contains(t, result.Value.(error).Error(), "missing required field \""+tc.field+"\"") + }) + } +} + +func TestDisplay_handleWSMessage_LayoutCommands_Ugly(t *testing.T) { + cases := []struct { + name string + action string + msg WSMessage + field string + }{ + { + name: "LayoutBesideEditor", + action: "window.layoutBesideEditor", + msg: WSMessage{ + Action: "layout:beside-editor", + Data: map[string]any{ + "name": "preview", + "editor": "code", + "side": "right", + "ratio": "0.62", + }, + }, + field: "ratio", + }, + { + name: "LayoutSuggest", + action: "window.layoutSuggest", + msg: WSMessage{ + Action: "layout:suggest", + Data: map[string]any{ + "screen_id": "screen-1", + "window_count": 2.5, + }, + }, + field: "window_count", + }, + { + name: "FindScreenSpace", + action: "window.findSpace", + msg: WSMessage{ + Action: "screen:find-space", + Data: map[string]any{ + "screen_id": "screen-1", + "width": "800", + "height": 600, + "padding": 24, + }, + }, + field: "width", + }, + { + name: "ArrangeWindowPair", + action: "window.arrangePair", + msg: WSMessage{ + Action: "window:arrange-pair", + Data: map[string]any{ + "primary": "editor", + "secondary": "preview", + "screen_id": "screen-1", + "ratio": true, + }, + }, + field: "ratio", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + svc, c := newTestDisplayAPIService(t) + called := false + c.Action(tc.action, func(_ context.Context, _ core.Options) core.Result { + called = true + return core.Result{OK: true} + }) + + result := svc.handleWSMessage(tc.msg) + + require.False(t, result.OK) + assert.False(t, called) + assert.Contains(t, result.Value.(error).Error(), "invalid required field \""+tc.field+"\"") + }) + } +} + func TestDisplay_handleTrayAction_Good(t *testing.T) { platform := window.NewMockPlatform() c := core.New(