diff --git a/pkg/display/display.go b/pkg/display/display.go index 1ee1bb9e..2cc0d97a 100644 --- a/pkg/display/display.go +++ b/pkg/display/display.go @@ -754,7 +754,7 @@ func (s *Service) handleWSMessage(msg WSMessage) core.Result { } return c.QUERY(webview.QueryTitle{Window: w}) default: - return core.Result{} + return core.Result{Value: coreerr.E("display.handleWSMessage", "unknown websocket action: "+msg.Action, nil), OK: false} } } diff --git a/pkg/display/events.go b/pkg/display/events.go index ed688a04..29483aa5 100644 --- a/pkg/display/events.go +++ b/pkg/display/events.go @@ -86,6 +86,7 @@ type WSEventManager struct { // clientState tracks a client's subscriptions. type clientState struct { subscriptions map[string]*Subscription + writeMu sync.Mutex mu sync.RWMutex } @@ -124,10 +125,13 @@ func trustedWebSocketOrigin(r *http.Request) bool { if !trustedWebSocketHost(r.Host) { return false } + if !trustedWSRequestOrigin(r.RemoteAddr) { + return false + } origin := strings.TrimSpace(r.Header.Get("Origin")) if origin == "" || strings.EqualFold(origin, "null") { - return trustedWSRequestOrigin(r.RemoteAddr) + return true } parsed, err := url.Parse(origin) @@ -221,10 +225,10 @@ func (em *WSEventManager) clientSubscribed(state *clientState, eventType EventTy // sendEvent sends an event to a specific client. func (em *WSEventManager) sendEvent(conn *websocket.Conn, event Event) { em.mu.RLock() - _, exists := em.clients[conn] + state, exists := em.clients[conn] em.mu.RUnlock() - if !exists { + if !exists || state == nil { return } @@ -234,8 +238,11 @@ func (em *WSEventManager) sendEvent(conn *websocket.Conn, event Event) { } data, _ := marshalResult.Value.([]byte) + state.writeMu.Lock() conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - if err := conn.WriteMessage(websocket.TextMessage, data); err != nil { + err := conn.WriteMessage(websocket.TextMessage, data) + state.writeMu.Unlock() + if err != nil { em.removeClient(conn) } } @@ -258,6 +265,8 @@ func (em *WSEventManager) HandleWebSocket(w http.ResponseWriter, r *http.Request } em.mu.Unlock() + conn.SetReadLimit(64 * 1024) + // Handle incoming messages go em.handleMessages(conn) } @@ -279,9 +288,11 @@ func (em *WSEventManager) handleMessages(conn *websocket.Conn) { } if unmarshalResult := core.JSONUnmarshal(message, &msg); !unmarshalResult.OK { - continue + em.closeWithPolicyViolation(conn, "invalid websocket message") + return } + handled := true switch msg.Action { case "subscribe": em.subscribe(conn, msg.ID, msg.EventTypes) @@ -289,10 +300,28 @@ func (em *WSEventManager) handleMessages(conn *websocket.Conn) { em.unsubscribe(conn, msg.ID) case "list": em.listSubscriptions(conn) + default: + handled = false + } + if !handled { + em.closeWithPolicyViolation(conn, "unknown websocket action") + return } } } +func (em *WSEventManager) closeWithPolicyViolation(conn *websocket.Conn, reason string) { + em.mu.RLock() + state, exists := em.clients[conn] + em.mu.RUnlock() + if !exists || state == nil { + return + } + state.writeMu.Lock() + defer state.writeMu.Unlock() + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, reason), time.Now().Add(2*time.Second)) +} + // subscribe adds a subscription for a client. func (em *WSEventManager) subscribe(conn *websocket.Conn, id string, eventTypes []EventType) { em.mu.RLock() @@ -326,7 +355,7 @@ func (em *WSEventManager) subscribe(conn *websocket.Conn, id string, eventTypes } if marshalResult := core.JSONMarshal(response); marshalResult.OK { responseData, _ := marshalResult.Value.([]byte) - conn.WriteMessage(websocket.TextMessage, responseData) + em.writeClientMessage(state, conn, responseData) } } @@ -351,7 +380,7 @@ func (em *WSEventManager) unsubscribe(conn *websocket.Conn, id string) { } if marshalResult := core.JSONMarshal(response); marshalResult.OK { responseData, _ := marshalResult.Value.([]byte) - conn.WriteMessage(websocket.TextMessage, responseData) + em.writeClientMessage(state, conn, responseData) } } @@ -378,10 +407,17 @@ func (em *WSEventManager) listSubscriptions(conn *websocket.Conn) { } if marshalResult := core.JSONMarshal(response); marshalResult.OK { responseData, _ := marshalResult.Value.([]byte) - conn.WriteMessage(websocket.TextMessage, responseData) + em.writeClientMessage(state, conn, responseData) } } +func (em *WSEventManager) writeClientMessage(state *clientState, conn *websocket.Conn, data []byte) { + state.writeMu.Lock() + defer state.writeMu.Unlock() + conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + _ = conn.WriteMessage(websocket.TextMessage, data) +} + // removeClient removes a client and its subscriptions. func (em *WSEventManager) removeClient(conn *websocket.Conn) { em.mu.Lock() diff --git a/pkg/display/events_test.go b/pkg/display/events_test.go index 52cf55c9..e4dfe6d2 100644 --- a/pkg/display/events_test.go +++ b/pkg/display/events_test.go @@ -153,6 +153,41 @@ func TestWSEventManager_HandleWebSocket_RejectsRemoteOrigin(t *testing.T) { assert.Equal(t, http.StatusForbidden, recorder.Code) } +func TestWSEventManager_HandleWebSocket_RejectsLoopbackSpoofedOrigin(t *testing.T) { + em := NewWSEventManager() + + req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/events", nil) + req.RemoteAddr = "203.0.113.10:12345" + req.Header.Set("Origin", "file://malicious") + recorder := httptest.NewRecorder() + + em.HandleWebSocket(recorder, req) + + assert.Equal(t, http.StatusForbidden, recorder.Code) +} + +func TestWSEventManager_HandleWebSocket_ClosesOnMalformedMessage(t *testing.T) { + em := NewWSEventManager() + conn, cleanup := dialWSEventManager(t, em) + defer cleanup() + + require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"action":`))) + + _, _, err := conn.ReadMessage() + require.Error(t, err) +} + +func TestWSEventManager_HandleWebSocket_ClosesOnUnknownAction(t *testing.T) { + em := NewWSEventManager() + conn, cleanup := dialWSEventManager(t, em) + defer cleanup() + + require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"action":"bogus"}`))) + + _, _, err := conn.ReadMessage() + require.Error(t, err) +} + func TestWSEventManager_Emit_Ugly(t *testing.T) { em := &WSEventManager{ clients: map[*websocket.Conn]*clientState{}, diff --git a/pkg/display/preload.go b/pkg/display/preload.go index db9d5e99..d2ef8abf 100644 --- a/pkg/display/preload.go +++ b/pkg/display/preload.go @@ -1,6 +1,8 @@ package display import ( + "net" + "net/url" "strings" core "dappco.re/go/core" @@ -39,12 +41,47 @@ func (s *Service) BuildPreloadScript(pageURL string) (string, error) { s.injectCoreMLShim(), s.buildHLCRFComponents(pageURL), } - if appPreloads, err := s.injectAppPreloads(pageURL); err == nil && strings.TrimSpace(appPreloads) != "" { + if appPreloads, err := s.injectAppPreloads(pageURL); err != nil { + if !strings.Contains(err.Error(), "view manifest not found") { + return "", err + } + } else if strings.TrimSpace(appPreloads) != "" { parts = append(parts, appPreloads) } return strings.Join(parts, "\n"), nil } +func validatedLocalMLAPIURL(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "http://localhost:8090" + } + parsed, err := url.Parse(trimmed) + if err != nil { + return "http://localhost:8090" + } + switch strings.ToLower(parsed.Scheme) { + case "http", "https": + default: + return "http://localhost:8090" + } + host := strings.TrimSpace(parsed.Host) + if host == "" { + return "http://localhost:8090" + } + name := host + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + name = parsedHost + } + name = strings.Trim(strings.ToLower(name), "[]") + switch name { + case "localhost", "127.0.0.1", "::1": + return strings.TrimRight(parsed.String(), "/") + default: + return "http://localhost:8090" + } +} + func (s *Service) injectStoragePolyfills(pageOrigin string, bootstrap map[string]map[string]string) string { return `(function() { const __corePageURL = ` + core.JSONMarshalString(pageOrigin) + `; @@ -569,7 +606,7 @@ func (s *Service) injectBackgroundServiceShims() string { func (s *Service) injectCoreMLShim() string { return `(function() { - const __coreMLApiURL = ` + core.JSONMarshalString(strings.TrimRight(core.Env("CORE_ML_API_URL"), "/")) + ` || "http://localhost:8090"; + const __coreMLApiURL = ` + core.JSONMarshalString(validatedLocalMLAPIURL(core.Env("CORE_ML_API_URL"))) + ` || "http://localhost:8090"; globalThis.core = globalThis.core || {}; globalThis.core.ml = globalThis.core.ml || { async generate(input) { diff --git a/pkg/marketplace/marketplace.go b/pkg/marketplace/marketplace.go index 5502c390..cb60ad79 100644 --- a/pkg/marketplace/marketplace.go +++ b/pkg/marketplace/marketplace.go @@ -14,6 +14,7 @@ import ( "os/exec" "path/filepath" "strings" + "unicode" "gopkg.in/yaml.v3" ) @@ -40,6 +41,8 @@ type Installer struct { InstallDir string } +const maxManifestBytes = 1 << 20 + func (i Installer) FetchManifest(ctx context.Context, manifestURL string) (Manifest, error) { client := i.HTTPClient if client == nil { @@ -57,10 +60,13 @@ func (i Installer) FetchManifest(ctx context.Context, manifestURL string) (Manif if resp.StatusCode >= http.StatusBadRequest { return Manifest{}, fmt.Errorf("manifest fetch failed: %s", resp.Status) } - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxManifestBytes+1)) if err != nil { return Manifest{}, err } + if len(body) > maxManifestBytes { + return Manifest{}, fmt.Errorf("manifest fetch failed: manifest exceeds %d bytes", maxManifestBytes) + } var manifest Manifest if err := yaml.Unmarshal(body, &manifest); err != nil { return Manifest{}, err @@ -72,8 +78,11 @@ func (i Installer) FetchManifest(ctx context.Context, manifestURL string) (Manif } func VerifyManifest(manifest Manifest) error { + if strings.ToLower(strings.TrimSpace(manifest.Signature.Algorithm)) != "ed25519" { + return errors.New("manifest signature algorithm must be ed25519") + } if manifest.Signature.Value == "" || manifest.Signature.PublicKey == "" { - return nil + return errors.New("manifest signature is required") } payload := manifest.Name + "\n" + manifest.Version + "\n" + manifest.Repository + "\n" + manifest.Ref signature, err := base64.StdEncoding.DecodeString(manifest.Signature.Value) @@ -84,6 +93,12 @@ func VerifyManifest(manifest Manifest) error { if err != nil { return err } + if len(signature) != ed25519.SignatureSize { + return errors.New("manifest signature has invalid size") + } + if len(publicKey) != ed25519.PublicKeySize { + return errors.New("manifest public key has invalid size") + } if !ed25519.Verify(ed25519.PublicKey(publicKey), []byte(payload), signature) { return errors.New("manifest signature verification failed") } @@ -91,16 +106,34 @@ func VerifyManifest(manifest Manifest) error { } func (i Installer) Install(ctx context.Context, manifest Manifest) (string, error) { + if strings.TrimSpace(i.InstallDir) == "" { + return "", errors.New("install dir is required") + } if err := VerifyManifest(manifest); err != nil { return "", err } - if strings.TrimSpace(i.InstallDir) == "" { - return "", errors.New("install dir is required") + if err := validateManifestName(manifest.Name); err != nil { + return "", err } if err := os.MkdirAll(i.InstallDir, 0o755); err != nil { return "", err } targetDir := filepath.Join(i.InstallDir, safeName(manifest.Name)) + rootAbs, err := filepath.Abs(i.InstallDir) + if err != nil { + return "", err + } + targetAbs, err := filepath.Abs(targetDir) + if err != nil { + return "", err + } + rel, err := filepath.Rel(rootAbs, targetAbs) + if err != nil { + return "", err + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", errors.New("install path escapes install dir") + } _ = os.RemoveAll(targetDir) args := []string{"clone", "--depth", "1"} if manifest.Ref != "" { @@ -118,6 +151,20 @@ func (i Installer) Install(ctx context.Context, manifest Manifest) (string, erro return targetDir, nil } +func validateManifestName(value string) error { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return errors.New("manifest name is required") + } + if strings.ContainsAny(trimmed, `/\`) { + return errors.New("manifest name must not contain path separators") + } + if strings.Contains(trimmed, "..") { + return errors.New("manifest name must not contain path traversal segments") + } + return nil +} + func DigestManifest(manifest Manifest) string { hash := sha256.Sum256([]byte(manifest.Name + ":" + manifest.Version + ":" + manifest.Repository + ":" + manifest.Ref)) return hex.EncodeToString(hash[:]) @@ -125,9 +172,29 @@ func DigestManifest(manifest Manifest) string { func safeName(value string) string { value = strings.TrimSpace(strings.ToLower(value)) - value = strings.ReplaceAll(value, " ", "-") if value == "" { return "module" } - return value + var builder strings.Builder + lastDash := false + for _, r := range value { + switch { + case unicode.IsLetter(r), unicode.IsDigit(r): + builder.WriteRune(r) + lastDash = false + case r == '-' || r == '_' || r == '.': + builder.WriteRune(r) + lastDash = false + default: + if !lastDash { + builder.WriteRune('-') + lastDash = true + } + } + } + cleaned := strings.Trim(builder.String(), "-._") + if cleaned == "" { + return "module" + } + return cleaned } diff --git a/pkg/marketplace/marketplace_test.go b/pkg/marketplace/marketplace_test.go index 7c8ffae5..5d046afa 100644 --- a/pkg/marketplace/marketplace_test.go +++ b/pkg/marketplace/marketplace_test.go @@ -112,6 +112,17 @@ func TestMarketplace_VerifyManifest_Ugly(t *testing.T) { require.Error(t, VerifyManifest(manifest)) } +func TestMarketplace_VerifyManifest_RequiresSignature(t *testing.T) { + manifest := Manifest{ + Name: "core-ui", + Version: "1.2.3", + Repository: "https://example.com/core-ui.git", + Ref: "main", + } + + require.Error(t, VerifyManifest(manifest)) +} + func TestMarketplace_Install_Good(t *testing.T) { scriptDir := t.TempDir() logFile := filepath.Join(scriptDir, "git.log") @@ -125,12 +136,12 @@ func TestMarketplace_Install_Good(t *testing.T) { InstallDir: targetRoot, } - targetDir, err := installer.Install(context.Background(), Manifest{ + targetDir, err := installer.Install(context.Background(), signedManifest(t, Manifest{ Name: "Core UI", Version: "1.2.3", Repository: "https://example.com/core-ui.git", Ref: "main", - }) + })) require.NoError(t, err) assert.Equal(t, filepath.Join(targetRoot, "core-ui"), targetDir) _, err = os.Stat(targetDir) @@ -150,6 +161,18 @@ func TestMarketplace_Install_Bad(t *testing.T) { assert.Contains(t, err.Error(), "install dir is required") } +func TestMarketplace_Install_RejectsTraversalName(t *testing.T) { + installer := Installer{InstallDir: t.TempDir()} + _, err := installer.Install(context.Background(), signedManifest(t, Manifest{ + Name: "../../escape", + Version: "1.2.3", + Repository: "https://example.com/core-ui.git", + Ref: "main", + })) + require.Error(t, err) + assert.Contains(t, err.Error(), "path separators") +} + func TestMarketplace_Install_Ugly(t *testing.T) { scriptDir := t.TempDir() scriptPath := filepath.Join(scriptDir, "git") @@ -160,10 +183,10 @@ func TestMarketplace_Install_Ugly(t *testing.T) { InstallDir: t.TempDir(), } - _, err := installer.Install(context.Background(), Manifest{ + _, err := installer.Install(context.Background(), signedManifest(t, Manifest{ Name: "core-ui", Repository: "https://example.com/core-ui.git", - }) + })) require.Error(t, err) assert.Contains(t, err.Error(), "git clone failed") }