From 473fda5894ac4d7a28b8c3bfeef31042daf0000c Mon Sep 17 00:00:00 2001 From: Snider Date: Mon, 9 Mar 2026 08:29:54 +0000 Subject: [PATCH] security: use html.EscapeString for XSS protection in rendering Replace manual string-replace escaping with stdlib html.EscapeString. Escape tag names and attribute keys in elNode and Layout rendering. Improve Attr() to traverse wrapper nodes (If, Unless, Entitled). Co-Authored-By: Claude Opus 4.6 --- edge_test.go | 2 +- layout.go | 6 +++--- node.go | 34 ++++++++++++++++------------------ node_test.go | 2 +- 4 files changed, 21 insertions(+), 23 deletions(-) diff --git a/edge_test.go b/edge_test.go index 854415e..c1f5e83 100644 --- a/edge_test.go +++ b/edge_test.go @@ -447,7 +447,7 @@ func TestEscapeAttr_AllSpecialChars(t *testing.T) { if strings.Contains(got, `"&<>"'"`) { t.Error("attribute value with special chars must be fully escaped") } - if !strings.Contains(got, "&<>"'") { + if !strings.Contains(got, "&<>"'") { t.Errorf("expected all special chars escaped in attribute, got: %s", got) } } diff --git a/layout.go b/layout.go index 00890d1..bc2775c 100644 --- a/layout.go +++ b/layout.go @@ -89,11 +89,11 @@ func (l *Layout) Render(ctx *Context) string { bid := l.blockID(slot) b.WriteByte('<') - b.WriteString(meta.tag) + b.WriteString(escapeHTML(meta.tag)) b.WriteString(` role="`) - b.WriteString(meta.role) + b.WriteString(escapeAttr(meta.role)) b.WriteString(`" data-block="`) - b.WriteString(bid) + b.WriteString(escapeAttr(bid)) b.WriteString(`">`) for _, child := range children { diff --git a/node.go b/node.go index 71823c3..9416e0e 100644 --- a/node.go +++ b/node.go @@ -1,6 +1,7 @@ package html import ( + "html" "iter" "maps" "slices" @@ -33,12 +34,7 @@ var voidElements = map[string]bool{ // escapeAttr escapes a string for use in an HTML attribute value. func escapeAttr(s string) string { - s = strings.ReplaceAll(s, "&", "&") - s = strings.ReplaceAll(s, "\"", """) - s = strings.ReplaceAll(s, "'", "'") - s = strings.ReplaceAll(s, "<", "<") - s = strings.ReplaceAll(s, ">", ">") - return s + return html.EscapeString(s) } // --- rawNode --- @@ -74,10 +70,17 @@ func El(tag string, children ...Node) Node { } // Attr sets an attribute on an El node. Returns the node for chaining. -// If the node is not an *elNode, returns it unchanged. +// It recursively traverses through wrappers like If, Unless, and Entitled. func Attr(n Node, key, value string) Node { - if el, ok := n.(*elNode); ok { - el.attrs[key] = value + switch t := n.(type) { + case *elNode: + t.attrs[key] = value + case *ifNode: + Attr(t.node, key, value) + case *unlessNode: + Attr(t.node, key, value) + case *entitledNode: + Attr(t.node, key, value) } return n } @@ -86,14 +89,14 @@ func (n *elNode) Render(ctx *Context) string { var b strings.Builder b.WriteByte('<') - b.WriteString(n.tag) + b.WriteString(escapeHTML(n.tag)) // Sort attribute keys for deterministic output. keys := slices.Collect(maps.Keys(n.attrs)) slices.Sort(keys) for _, key := range keys { b.WriteByte(' ') - b.WriteString(key) + b.WriteString(escapeHTML(key)) b.WriteString(`="`) b.WriteString(escapeAttr(n.attrs[key])) b.WriteByte('"') @@ -110,7 +113,7 @@ func (n *elNode) Render(ctx *Context) string { } b.WriteString("') return b.String() @@ -120,12 +123,7 @@ func (n *elNode) Render(ctx *Context) string { // escapeHTML escapes a string for safe inclusion in HTML text content. func escapeHTML(s string) string { - s = strings.ReplaceAll(s, "&", "&") - s = strings.ReplaceAll(s, "<", "<") - s = strings.ReplaceAll(s, ">", ">") - s = strings.ReplaceAll(s, "\"", """) - s = strings.ReplaceAll(s, "'", "'") - return s + return html.EscapeString(s) } // --- textNode --- diff --git a/node_test.go b/node_test.go index 77fd26d..1b6cd2b 100644 --- a/node_test.go +++ b/node_test.go @@ -169,7 +169,7 @@ func TestElNode_AttrEscaping(t *testing.T) { ctx := NewContext() node := Attr(El("img"), "alt", `he said "hello"`) got := node.Render(ctx) - if !strings.Contains(got, `alt="he said "hello""`) { + if !strings.Contains(got, `alt="he said "hello""`) { t.Errorf("Attr should escape attribute values, got %q", got) } }