Harden GUI security boundaries
Some checks are pending
Security Scan / security (push) Waiting to run
Test / test (push) Waiting to run

This commit is contained in:
Snider 2026-04-15 19:20:58 +01:00
parent 65ccf50c2b
commit 723116acb7
6 changed files with 219 additions and 21 deletions

View file

@ -754,7 +754,7 @@ func (s *Service) handleWSMessage(msg WSMessage) core.Result {
}
return c.QUERY(webview.QueryTitle{Window: w})
default:
return core.Result{}
return core.Result{Value: coreerr.E("display.handleWSMessage", "unknown websocket action: "+msg.Action, nil), OK: false}
}
}

View file

@ -86,6 +86,7 @@ type WSEventManager struct {
// clientState tracks a client's subscriptions.
type clientState struct {
subscriptions map[string]*Subscription
writeMu sync.Mutex
mu sync.RWMutex
}
@ -124,10 +125,13 @@ func trustedWebSocketOrigin(r *http.Request) bool {
if !trustedWebSocketHost(r.Host) {
return false
}
if !trustedWSRequestOrigin(r.RemoteAddr) {
return false
}
origin := strings.TrimSpace(r.Header.Get("Origin"))
if origin == "" || strings.EqualFold(origin, "null") {
return trustedWSRequestOrigin(r.RemoteAddr)
return true
}
parsed, err := url.Parse(origin)
@ -221,10 +225,10 @@ func (em *WSEventManager) clientSubscribed(state *clientState, eventType EventTy
// sendEvent sends an event to a specific client.
func (em *WSEventManager) sendEvent(conn *websocket.Conn, event Event) {
em.mu.RLock()
_, exists := em.clients[conn]
state, exists := em.clients[conn]
em.mu.RUnlock()
if !exists {
if !exists || state == nil {
return
}
@ -234,8 +238,11 @@ func (em *WSEventManager) sendEvent(conn *websocket.Conn, event Event) {
}
data, _ := marshalResult.Value.([]byte)
state.writeMu.Lock()
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
err := conn.WriteMessage(websocket.TextMessage, data)
state.writeMu.Unlock()
if err != nil {
em.removeClient(conn)
}
}
@ -258,6 +265,8 @@ func (em *WSEventManager) HandleWebSocket(w http.ResponseWriter, r *http.Request
}
em.mu.Unlock()
conn.SetReadLimit(64 * 1024)
// Handle incoming messages
go em.handleMessages(conn)
}
@ -279,9 +288,11 @@ func (em *WSEventManager) handleMessages(conn *websocket.Conn) {
}
if unmarshalResult := core.JSONUnmarshal(message, &msg); !unmarshalResult.OK {
continue
em.closeWithPolicyViolation(conn, "invalid websocket message")
return
}
handled := true
switch msg.Action {
case "subscribe":
em.subscribe(conn, msg.ID, msg.EventTypes)
@ -289,10 +300,28 @@ func (em *WSEventManager) handleMessages(conn *websocket.Conn) {
em.unsubscribe(conn, msg.ID)
case "list":
em.listSubscriptions(conn)
default:
handled = false
}
if !handled {
em.closeWithPolicyViolation(conn, "unknown websocket action")
return
}
}
}
func (em *WSEventManager) closeWithPolicyViolation(conn *websocket.Conn, reason string) {
em.mu.RLock()
state, exists := em.clients[conn]
em.mu.RUnlock()
if !exists || state == nil {
return
}
state.writeMu.Lock()
defer state.writeMu.Unlock()
_ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, reason), time.Now().Add(2*time.Second))
}
// subscribe adds a subscription for a client.
func (em *WSEventManager) subscribe(conn *websocket.Conn, id string, eventTypes []EventType) {
em.mu.RLock()
@ -326,7 +355,7 @@ func (em *WSEventManager) subscribe(conn *websocket.Conn, id string, eventTypes
}
if marshalResult := core.JSONMarshal(response); marshalResult.OK {
responseData, _ := marshalResult.Value.([]byte)
conn.WriteMessage(websocket.TextMessage, responseData)
em.writeClientMessage(state, conn, responseData)
}
}
@ -351,7 +380,7 @@ func (em *WSEventManager) unsubscribe(conn *websocket.Conn, id string) {
}
if marshalResult := core.JSONMarshal(response); marshalResult.OK {
responseData, _ := marshalResult.Value.([]byte)
conn.WriteMessage(websocket.TextMessage, responseData)
em.writeClientMessage(state, conn, responseData)
}
}
@ -378,10 +407,17 @@ func (em *WSEventManager) listSubscriptions(conn *websocket.Conn) {
}
if marshalResult := core.JSONMarshal(response); marshalResult.OK {
responseData, _ := marshalResult.Value.([]byte)
conn.WriteMessage(websocket.TextMessage, responseData)
em.writeClientMessage(state, conn, responseData)
}
}
func (em *WSEventManager) writeClientMessage(state *clientState, conn *websocket.Conn, data []byte) {
state.writeMu.Lock()
defer state.writeMu.Unlock()
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
_ = conn.WriteMessage(websocket.TextMessage, data)
}
// removeClient removes a client and its subscriptions.
func (em *WSEventManager) removeClient(conn *websocket.Conn) {
em.mu.Lock()

View file

@ -153,6 +153,41 @@ func TestWSEventManager_HandleWebSocket_RejectsRemoteOrigin(t *testing.T) {
assert.Equal(t, http.StatusForbidden, recorder.Code)
}
func TestWSEventManager_HandleWebSocket_RejectsLoopbackSpoofedOrigin(t *testing.T) {
em := NewWSEventManager()
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/events", nil)
req.RemoteAddr = "203.0.113.10:12345"
req.Header.Set("Origin", "file://malicious")
recorder := httptest.NewRecorder()
em.HandleWebSocket(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
}
func TestWSEventManager_HandleWebSocket_ClosesOnMalformedMessage(t *testing.T) {
em := NewWSEventManager()
conn, cleanup := dialWSEventManager(t, em)
defer cleanup()
require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"action":`)))
_, _, err := conn.ReadMessage()
require.Error(t, err)
}
func TestWSEventManager_HandleWebSocket_ClosesOnUnknownAction(t *testing.T) {
em := NewWSEventManager()
conn, cleanup := dialWSEventManager(t, em)
defer cleanup()
require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"action":"bogus"}`)))
_, _, err := conn.ReadMessage()
require.Error(t, err)
}
func TestWSEventManager_Emit_Ugly(t *testing.T) {
em := &WSEventManager{
clients: map[*websocket.Conn]*clientState{},

View file

@ -1,6 +1,8 @@
package display
import (
"net"
"net/url"
"strings"
core "dappco.re/go/core"
@ -39,12 +41,47 @@ func (s *Service) BuildPreloadScript(pageURL string) (string, error) {
s.injectCoreMLShim(),
s.buildHLCRFComponents(pageURL),
}
if appPreloads, err := s.injectAppPreloads(pageURL); err == nil && strings.TrimSpace(appPreloads) != "" {
if appPreloads, err := s.injectAppPreloads(pageURL); err != nil {
if !strings.Contains(err.Error(), "view manifest not found") {
return "", err
}
} else if strings.TrimSpace(appPreloads) != "" {
parts = append(parts, appPreloads)
}
return strings.Join(parts, "\n"), nil
}
func validatedLocalMLAPIURL(raw string) string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "http://localhost:8090"
}
parsed, err := url.Parse(trimmed)
if err != nil {
return "http://localhost:8090"
}
switch strings.ToLower(parsed.Scheme) {
case "http", "https":
default:
return "http://localhost:8090"
}
host := strings.TrimSpace(parsed.Host)
if host == "" {
return "http://localhost:8090"
}
name := host
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
name = parsedHost
}
name = strings.Trim(strings.ToLower(name), "[]")
switch name {
case "localhost", "127.0.0.1", "::1":
return strings.TrimRight(parsed.String(), "/")
default:
return "http://localhost:8090"
}
}
func (s *Service) injectStoragePolyfills(pageOrigin string, bootstrap map[string]map[string]string) string {
return `(function() {
const __corePageURL = ` + core.JSONMarshalString(pageOrigin) + `;
@ -569,7 +606,7 @@ func (s *Service) injectBackgroundServiceShims() string {
func (s *Service) injectCoreMLShim() string {
return `(function() {
const __coreMLApiURL = ` + core.JSONMarshalString(strings.TrimRight(core.Env("CORE_ML_API_URL"), "/")) + ` || "http://localhost:8090";
const __coreMLApiURL = ` + core.JSONMarshalString(validatedLocalMLAPIURL(core.Env("CORE_ML_API_URL"))) + ` || "http://localhost:8090";
globalThis.core = globalThis.core || {};
globalThis.core.ml = globalThis.core.ml || {
async generate(input) {

View file

@ -14,6 +14,7 @@ import (
"os/exec"
"path/filepath"
"strings"
"unicode"
"gopkg.in/yaml.v3"
)
@ -40,6 +41,8 @@ type Installer struct {
InstallDir string
}
const maxManifestBytes = 1 << 20
func (i Installer) FetchManifest(ctx context.Context, manifestURL string) (Manifest, error) {
client := i.HTTPClient
if client == nil {
@ -57,10 +60,13 @@ func (i Installer) FetchManifest(ctx context.Context, manifestURL string) (Manif
if resp.StatusCode >= http.StatusBadRequest {
return Manifest{}, fmt.Errorf("manifest fetch failed: %s", resp.Status)
}
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(io.LimitReader(resp.Body, maxManifestBytes+1))
if err != nil {
return Manifest{}, err
}
if len(body) > maxManifestBytes {
return Manifest{}, fmt.Errorf("manifest fetch failed: manifest exceeds %d bytes", maxManifestBytes)
}
var manifest Manifest
if err := yaml.Unmarshal(body, &manifest); err != nil {
return Manifest{}, err
@ -72,8 +78,11 @@ func (i Installer) FetchManifest(ctx context.Context, manifestURL string) (Manif
}
func VerifyManifest(manifest Manifest) error {
if strings.ToLower(strings.TrimSpace(manifest.Signature.Algorithm)) != "ed25519" {
return errors.New("manifest signature algorithm must be ed25519")
}
if manifest.Signature.Value == "" || manifest.Signature.PublicKey == "" {
return nil
return errors.New("manifest signature is required")
}
payload := manifest.Name + "\n" + manifest.Version + "\n" + manifest.Repository + "\n" + manifest.Ref
signature, err := base64.StdEncoding.DecodeString(manifest.Signature.Value)
@ -84,6 +93,12 @@ func VerifyManifest(manifest Manifest) error {
if err != nil {
return err
}
if len(signature) != ed25519.SignatureSize {
return errors.New("manifest signature has invalid size")
}
if len(publicKey) != ed25519.PublicKeySize {
return errors.New("manifest public key has invalid size")
}
if !ed25519.Verify(ed25519.PublicKey(publicKey), []byte(payload), signature) {
return errors.New("manifest signature verification failed")
}
@ -91,16 +106,34 @@ func VerifyManifest(manifest Manifest) error {
}
func (i Installer) Install(ctx context.Context, manifest Manifest) (string, error) {
if strings.TrimSpace(i.InstallDir) == "" {
return "", errors.New("install dir is required")
}
if err := VerifyManifest(manifest); err != nil {
return "", err
}
if strings.TrimSpace(i.InstallDir) == "" {
return "", errors.New("install dir is required")
if err := validateManifestName(manifest.Name); err != nil {
return "", err
}
if err := os.MkdirAll(i.InstallDir, 0o755); err != nil {
return "", err
}
targetDir := filepath.Join(i.InstallDir, safeName(manifest.Name))
rootAbs, err := filepath.Abs(i.InstallDir)
if err != nil {
return "", err
}
targetAbs, err := filepath.Abs(targetDir)
if err != nil {
return "", err
}
rel, err := filepath.Rel(rootAbs, targetAbs)
if err != nil {
return "", err
}
if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
return "", errors.New("install path escapes install dir")
}
_ = os.RemoveAll(targetDir)
args := []string{"clone", "--depth", "1"}
if manifest.Ref != "" {
@ -118,6 +151,20 @@ func (i Installer) Install(ctx context.Context, manifest Manifest) (string, erro
return targetDir, nil
}
func validateManifestName(value string) error {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return errors.New("manifest name is required")
}
if strings.ContainsAny(trimmed, `/\`) {
return errors.New("manifest name must not contain path separators")
}
if strings.Contains(trimmed, "..") {
return errors.New("manifest name must not contain path traversal segments")
}
return nil
}
func DigestManifest(manifest Manifest) string {
hash := sha256.Sum256([]byte(manifest.Name + ":" + manifest.Version + ":" + manifest.Repository + ":" + manifest.Ref))
return hex.EncodeToString(hash[:])
@ -125,9 +172,29 @@ func DigestManifest(manifest Manifest) string {
func safeName(value string) string {
value = strings.TrimSpace(strings.ToLower(value))
value = strings.ReplaceAll(value, " ", "-")
if value == "" {
return "module"
}
return value
var builder strings.Builder
lastDash := false
for _, r := range value {
switch {
case unicode.IsLetter(r), unicode.IsDigit(r):
builder.WriteRune(r)
lastDash = false
case r == '-' || r == '_' || r == '.':
builder.WriteRune(r)
lastDash = false
default:
if !lastDash {
builder.WriteRune('-')
lastDash = true
}
}
}
cleaned := strings.Trim(builder.String(), "-._")
if cleaned == "" {
return "module"
}
return cleaned
}

View file

@ -112,6 +112,17 @@ func TestMarketplace_VerifyManifest_Ugly(t *testing.T) {
require.Error(t, VerifyManifest(manifest))
}
func TestMarketplace_VerifyManifest_RequiresSignature(t *testing.T) {
manifest := Manifest{
Name: "core-ui",
Version: "1.2.3",
Repository: "https://example.com/core-ui.git",
Ref: "main",
}
require.Error(t, VerifyManifest(manifest))
}
func TestMarketplace_Install_Good(t *testing.T) {
scriptDir := t.TempDir()
logFile := filepath.Join(scriptDir, "git.log")
@ -125,12 +136,12 @@ func TestMarketplace_Install_Good(t *testing.T) {
InstallDir: targetRoot,
}
targetDir, err := installer.Install(context.Background(), Manifest{
targetDir, err := installer.Install(context.Background(), signedManifest(t, Manifest{
Name: "Core UI",
Version: "1.2.3",
Repository: "https://example.com/core-ui.git",
Ref: "main",
})
}))
require.NoError(t, err)
assert.Equal(t, filepath.Join(targetRoot, "core-ui"), targetDir)
_, err = os.Stat(targetDir)
@ -150,6 +161,18 @@ func TestMarketplace_Install_Bad(t *testing.T) {
assert.Contains(t, err.Error(), "install dir is required")
}
func TestMarketplace_Install_RejectsTraversalName(t *testing.T) {
installer := Installer{InstallDir: t.TempDir()}
_, err := installer.Install(context.Background(), signedManifest(t, Manifest{
Name: "../../escape",
Version: "1.2.3",
Repository: "https://example.com/core-ui.git",
Ref: "main",
}))
require.Error(t, err)
assert.Contains(t, err.Error(), "path separators")
}
func TestMarketplace_Install_Ugly(t *testing.T) {
scriptDir := t.TempDir()
scriptPath := filepath.Join(scriptDir, "git")
@ -160,10 +183,10 @@ func TestMarketplace_Install_Ugly(t *testing.T) {
InstallDir: t.TempDir(),
}
_, err := installer.Install(context.Background(), Manifest{
_, err := installer.Install(context.Background(), signedManifest(t, Manifest{
Name: "core-ui",
Repository: "https://example.com/core-ui.git",
})
}))
require.Error(t, err)
assert.Contains(t, err.Error(), "git clone failed")
}