diff --git a/node.go b/node.go index 37172f4..2ef708e 100644 --- a/node.go +++ b/node.go @@ -4,6 +4,7 @@ import ( "html" "iter" "maps" + "reflect" "slices" "strconv" ) @@ -98,24 +99,42 @@ func El(tag string, children ...Node) Node { // It recursively traverses through wrappers like If, Unless, Entitled, Each, // EachSeq, Switch, Layout, and Responsive when present. func Attr(n Node, key, value string) Node { - if n == nil { - return n + if isNilNode(n) { + return nil } switch t := n.(type) { case *elNode: + if t == nil { + return nil + } t.attrs[key] = value case *ifNode: + if t == nil { + return nil + } Attr(t.node, key, value) case *unlessNode: + if t == nil { + return nil + } Attr(t.node, key, value) case *entitledNode: + if t == nil { + return nil + } Attr(t.node, key, value) case *switchNode: + if t == nil { + return nil + } for _, child := range t.cases { Attr(child, key, value) } case *Layout: + if t == nil { + return nil + } if t.slots != nil { for slot, children := range t.slots { for i := range children { @@ -134,6 +153,20 @@ func Attr(n Node, key, value string) Node { return n } +func isNilNode(n Node) bool { + if n == nil { + return true + } + + v := reflect.ValueOf(n) + switch v.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: + return v.IsNil() + default: + return false + } +} + // AriaLabel sets an aria-label attribute on an element node. // Usage example: AriaLabel(El("button", Text("save")), "Save changes") func AriaLabel(n Node, label string) Node { diff --git a/node_test.go b/node_test.go index 9bebd18..050ba58 100644 --- a/node_test.go +++ b/node_test.go @@ -456,6 +456,29 @@ func TestAttr_NonElement_Ugly(t *testing.T) { } } +func TestAttr_TypedNilWrappers_Ugly(t *testing.T) { + tests := []struct { + name string + node Node + }{ + {name: "layout", node: (*Layout)(nil)}, + {name: "responsive", node: (*Responsive)(nil)}, + {name: "if", node: (*ifNode)(nil)}, + {name: "unless", node: (*unlessNode)(nil)}, + {name: "entitled", node: (*entitledNode)(nil)}, + {name: "switch", node: (*switchNode)(nil)}, + {name: "each", node: (*eachNode[string])(nil)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Attr(tt.node, "data-test", "x"); got != nil { + t.Fatalf("Attr on typed nil %s should return nil, got %#v", tt.name, got) + } + }) + } +} + func TestUnlessNode_True_Good(t *testing.T) { ctx := NewContext() node := Unless(func(*Context) bool { return true }, Raw("hidden"))