From 9d3ce2df2a3ca5868ee9a20ca3fb30c19a828d21 Mon Sep 17 00:00:00 2001 From: Snider Date: Wed, 15 Apr 2026 19:38:32 +0100 Subject: [PATCH] Harden preload bridge and storage bounds --- pkg/display/display.go | 7 ++++- pkg/display/display_test.go | 15 +++++++++ pkg/display/preload.go | 61 +++++++++++++++++++++++++++++++++---- pkg/display/preload_test.go | 39 +++++++++++++++++++++++- pkg/display/storage.go | 32 +++++++++++++++++-- pkg/display/storage_test.go | 9 ++++++ 6 files changed, 153 insertions(+), 10 deletions(-) diff --git a/pkg/display/display.go b/pkg/display/display.go index 53384a86..92d98a0a 100644 --- a/pkg/display/display.go +++ b/pkg/display/display.go @@ -121,7 +121,12 @@ func (s *Service) OnStartup(_ context.Context) core.Result { bucket := opts.String("bucket") key := opts.String("key") value := opts.String("value") - s.storage.Set(origin, bucket, key, value) + if s.storage == nil { + return core.Result{Value: coreerr.E("display.storage.set", "storage registry unavailable", nil), OK: false} + } + if !s.storage.Set(origin, bucket, key, value) { + return core.Result{Value: coreerr.E("display.storage.set", "invalid storage entry", nil), OK: false} + } return core.Result{Value: map[string]string{"origin": origin, "bucket": bucket, "key": key}, OK: true} }) s.Core().Action("display.storage.search", func(_ context.Context, opts core.Options) core.Result { diff --git a/pkg/display/display_test.go b/pkg/display/display_test.go index 229ade5e..a253c13b 100644 --- a/pkg/display/display_test.go +++ b/pkg/display/display_test.go @@ -3,6 +3,7 @@ package display import ( "context" "os" + "strings" "testing" core "dappco.re/go/core" @@ -111,6 +112,20 @@ func TestConfigTask_Good(t *testing.T) { assert.Equal(t, 800, cfg["default_width"]) } +func TestStorageTask_Bad(t *testing.T) { + _, c := newTestDisplayService(t) + + r := c.Action("display.storage.set").Run(context.Background(), core.NewOptions( + core.Option{Key: "origin", Value: "core://settings"}, + core.Option{Key: "bucket", Value: "localStorage"}, + core.Option{Key: "key", Value: strings.Repeat("k", maxStorageKeyBytes+1)}, + core.Option{Key: "value", Value: "dark"}, + )) + + require.False(t, r.OK) + assert.Contains(t, r.Value.(error).Error(), "invalid storage entry") +} + func TestResolveScheme_StoreRoute_Good(t *testing.T) { svc, _ := newTestDisplayService(t) diff --git a/pkg/display/preload.go b/pkg/display/preload.go index 3ca214ac..83d8b1c8 100644 --- a/pkg/display/preload.go +++ b/pkg/display/preload.go @@ -30,17 +30,22 @@ func (s *Service) InjectPreload(webview PreloadTarget, origin string) error { // before page code runs. // Use: script, _ := display.BuildPreloadScript("https://example.com") func (s *Service) BuildPreloadScript(pageURL string) (string, error) { + trustedOrigin := trustedPreloadOrigin(pageURL) storageBootstrap := map[string]map[string]string{} if s.storage != nil { storageBootstrap = s.storage.Snapshot(pageURL) } parts := []string{ - s.injectStoragePolyfills(pageURL, storageBootstrap), - s.injectBackgroundServiceShims(), - s.injectElectronShim(), - s.injectCoreMLShim(), + s.injectStoragePolyfills(pageURL, storageBootstrap, trustedOrigin), + s.injectCoreMLShim(trustedOrigin), s.buildHLCRFComponents(pageURL), } + if trustedOrigin { + parts = append(parts, + s.injectBackgroundServiceShims(), + s.injectElectronShim(), + ) + } if appPreloads, err := s.injectAppPreloads(pageURL); err != nil { if !strings.Contains(err.Error(), "view manifest not found") { return "", err @@ -51,6 +56,39 @@ func (s *Service) BuildPreloadScript(pageURL string) (string, error) { return strings.Join(parts, "\n"), nil } +func trustedPreloadOrigin(pageURL string) bool { + trimmed := strings.TrimSpace(pageURL) + if trimmed == "" { + return false + } + parsed, err := url.Parse(trimmed) + if err != nil { + return false + } + switch strings.ToLower(parsed.Scheme) { + case "core", "file", "wails", "app": + return true + case "http", "https": + host := strings.TrimSpace(parsed.Host) + if host == "" { + return false + } + name := host + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + name = parsedHost + } + name = strings.Trim(strings.ToLower(name), "[]") + switch name { + case "localhost", "127.0.0.1", "::1": + return true + default: + return false + } + default: + return false + } +} + func validatedLocalMLAPIURL(raw string) string { trimmed := strings.TrimSpace(raw) if trimmed == "" { @@ -82,15 +120,19 @@ func validatedLocalMLAPIURL(raw string) string { } } -func (s *Service) injectStoragePolyfills(pageOrigin string, bootstrap map[string]map[string]string) string { +func (s *Service) injectStoragePolyfills(pageOrigin string, bootstrap map[string]map[string]string, trustedOrigin bool) string { return `(function() { const __corePageURL = ` + core.JSONMarshalString(pageOrigin) + `; const __coreOrigin = ` + core.JSONMarshalString(storageOriginForPageURL(pageOrigin)) + ` || __corePageURL; + const __coreCanInvoke = ` + core.JSONMarshalString(trustedOrigin) + `; const __coreBootstrapStorage = ` + core.JSONMarshalString(bootstrap) + `; const __coreScopes = globalThis.__coreStorageScopes || (globalThis.__coreStorageScopes = {}); const __scope = __coreScopes[__coreOrigin] || (__coreScopes[__coreOrigin] = { localStorage: {}, sessionStorage: {}, cookies: {}, indexedDB: {}, caches: {}, buckets: {}, opfs: {} }); const __coreBridge = globalThis.__coreBridge || (globalThis.__coreBridge = { invoke(route, payload) { + if (!__coreCanInvoke) { + return Promise.reject(new Error("Core bridge unavailable for this origin")); + } if (typeof globalThis.__CORE_GUI_INVOKE__ === 'function') { return Promise.resolve(globalThis.__CORE_GUI_INVOKE__(route, payload)); } @@ -172,6 +214,9 @@ func (s *Service) injectStoragePolyfills(pageOrigin string, bootstrap map[string hydrateBucket(bucketName, __scope[bucketName], bucket); }); const persist = (bucket, key, value) => { + if (!__coreCanInvoke) { + return; + } if (bucket === "sessionStorage") { return; } @@ -800,9 +845,10 @@ func (s *Service) injectBackgroundServiceShims() string { })();` } -func (s *Service) injectCoreMLShim() string { +func (s *Service) injectCoreMLShim(trustedOrigin bool) string { return `(function() { const __coreMLApiURL = ` + core.JSONMarshalString(validatedLocalMLAPIURL(core.Env("CORE_ML_API_URL"))) + ` || "http://localhost:8090"; + const __coreCanInvoke = ` + core.JSONMarshalString(trustedOrigin) + `; globalThis.core = globalThis.core || {}; globalThis.core.ml = globalThis.core.ml || { async generate(input) { @@ -854,6 +900,9 @@ func (s *Service) injectCoreMLShim() string { }); }, async state() { + if (!__coreCanInvoke) { + return { available: false, models: [] }; + } return invokeBridge('display.models.state', {}).then((value) => value); }, async models() { diff --git a/pkg/display/preload_test.go b/pkg/display/preload_test.go index 39f573b6..1f1681a3 100644 --- a/pkg/display/preload_test.go +++ b/pkg/display/preload_test.go @@ -33,10 +33,36 @@ func TestDisplay_Good_WindowOpenIncludesPreload(t *testing.T) { require.True(t, result.OK) require.Len(t, platform.Windows, 1) assert.NotEmpty(t, platform.Windows[0].ExecJSCalls()) - assert.Contains(t, platform.Windows[0].ExecJSCalls()[0], "globalThis.electron") assert.Contains(t, platform.Windows[0].ExecJSCalls()[0], "globalThis.core.ml") assert.Contains(t, platform.Windows[0].ExecJSCalls()[0], "globalThis.core.storage.cookies") assert.Contains(t, platform.Windows[0].ExecJSCalls()[0], "Document.prototype, 'cookie'") + assert.NotContains(t, platform.Windows[0].ExecJSCalls()[0], "globalThis.electron") + assert.NotContains(t, platform.Windows[0].ExecJSCalls()[0], "core.background.serviceWorker.register") +} + +func TestDisplay_Good_WindowOpenTrustedOriginIncludesPrivilegedBridge(t *testing.T) { + platform := window.NewMockPlatform() + c := core.New( + core.WithService(Register(nil)), + core.WithService(window.Register(platform)), + core.WithServiceLock(), + ) + require.True(t, c.ServiceStartup(context.Background(), nil).OK) + + result := c.Action("window.open").Run(context.Background(), core.NewOptions( + core.Option{Key: "task", Value: window.TaskOpenWindow{ + Options: []window.WindowOption{ + window.WithName("preload"), + window.WithURL("http://localhost:3000"), + }, + }}, + )) + require.True(t, result.OK) + require.Len(t, platform.Windows, 1) + script := platform.Windows[0].ExecJSCalls()[0] + assert.Contains(t, script, "globalThis.electron") + assert.Contains(t, script, "core.background.serviceWorker.register") + assert.Contains(t, script, "globalThis.core.ml") } func TestDisplay_Good_CoreSchemeRoutesThroughBackend(t *testing.T) { @@ -77,3 +103,14 @@ func TestPreload_ValidatedLocalMLAPIURL_Ugly(t *testing.T) { assert.Equal(t, "http://localhost:8090", validatedLocalMLAPIURL("")) assert.Equal(t, "http://localhost:8090", validatedLocalMLAPIURL("not a url")) } + +func TestPreload_TrustedPreloadOrigin_Good(t *testing.T) { + assert.True(t, trustedPreloadOrigin("core://store")) + assert.True(t, trustedPreloadOrigin("http://localhost:3000")) + assert.True(t, trustedPreloadOrigin("https://127.0.0.1:8443")) +} + +func TestPreload_TrustedPreloadOrigin_Bad(t *testing.T) { + assert.False(t, trustedPreloadOrigin("https://example.com")) + assert.False(t, trustedPreloadOrigin("http://10.0.0.1:3000")) +} diff --git a/pkg/display/storage.go b/pkg/display/storage.go index 738e4986..0ba3f004 100644 --- a/pkg/display/storage.go +++ b/pkg/display/storage.go @@ -14,6 +14,14 @@ import ( gostore "dappco.re/go/store" ) +const ( + maxStorageOriginBytes = 512 + maxStorageBucketBytes = 128 + maxStorageKeyBytes = 1024 + maxStorageValueBytes = 1 << 20 + maxStorageSearchResults = 200 +) + type StorageEntry struct { Origin string `json:"origin"` Bucket string `json:"bucket"` @@ -149,7 +157,16 @@ func (r *StorageRegistry) loadPersistedEntries() { } } -func (r *StorageRegistry) Set(origin, bucket, key, value string) { +func (r *StorageRegistry) Set(origin, bucket, key, value string) bool { + if !validStorageField(origin, maxStorageOriginBytes) || + !validStorageField(bucket, maxStorageBucketBytes) || + !validStorageField(key, maxStorageKeyBytes) || + (len(value) > maxStorageValueBytes) { + return false + } + origin = strings.TrimSpace(origin) + bucket = strings.TrimSpace(bucket) + key = strings.TrimSpace(key) r.mu.Lock() defer r.mu.Unlock() entry := StorageEntry{ @@ -162,8 +179,11 @@ func (r *StorageRegistry) Set(origin, bucket, key, value string) { composite := makeStorageEntryKey(origin, bucket, key) r.entries[composite] = entry if r.store != nil { - _ = r.store.Set("storage", storageCompositeKey(origin, bucket, key), core.JSONMarshalString(entry)) + if err := r.store.Set("storage", storageCompositeKey(origin, bucket, key), core.JSONMarshalString(entry)); err != nil { + return false + } } + return true } func (r *StorageRegistry) Get(origin, bucket, key string) (StorageEntry, bool) { @@ -206,6 +226,9 @@ func (r *StorageRegistry) Search(query string) []StorageEntry { strings.Contains(strings.ToLower(entry.Key), needle) || strings.Contains(strings.ToLower(entry.Value), needle) { results = append(results, entry) + if len(results) >= maxStorageSearchResults { + break + } } } sort.Slice(results, func(i, j int) bool { @@ -214,6 +237,11 @@ func (r *StorageRegistry) Search(query string) []StorageEntry { return results } +func validStorageField(value string, limit int) bool { + trimmed := strings.TrimSpace(value) + return trimmed != "" && len(trimmed) <= limit +} + func (r *StorageRegistry) Snapshot(pageURL string) map[string]map[string]string { r.mu.RLock() defer r.mu.RUnlock() diff --git a/pkg/display/storage_test.go b/pkg/display/storage_test.go index d6d2aae6..6df22a22 100644 --- a/pkg/display/storage_test.go +++ b/pkg/display/storage_test.go @@ -108,6 +108,15 @@ func TestStorageRegistry_Snapshot_Good(t *testing.T) { assert.False(t, otherOriginPresent) } +func TestStorageRegistry_Set_Bad(t *testing.T) { + r := NewStorageRegistry() + + assert.False(t, r.Set("", "localStorage", "theme", "dark")) + assert.False(t, r.Set("core://settings", "", "theme", "dark")) + assert.False(t, r.Set("core://settings", "localStorage", "", "dark")) + assert.False(t, r.Set("core://settings", "localStorage", "theme", strings.Repeat("x", maxStorageValueBytes+1))) +} + func TestStorage_StorageOriginForPageURL_Good(t *testing.T) { assert.Equal(t, "https://app.example.com", storageOriginForPageURL("https://app.example.com/path?q=1")) assert.Equal(t, "core://settings", storageOriginForPageURL("core://settings/view"))