diff --git a/modules.go b/modules.go index 9337d4d..8a2d580 100644 --- a/modules.go +++ b/modules.go @@ -2115,6 +2115,7 @@ func (e *Executor) moduleIncludeVars(args map[string]any) (*TaskResult, error) { dir := getStringArg(args, "dir", "") name := getStringArg(args, "name", "") filesMatching := getStringArg(args, "files_matching", "") + extensions := normalizeIncludeVarsExtensions(normalizeStringList(args["extensions"])) hashBehaviour := lower(getStringArg(args, "hash_behaviour", "replace")) depth := getIntArg(args, "depth", 0) @@ -2149,7 +2150,7 @@ func (e *Executor) moduleIncludeVars(args map[string]any) (*TaskResult, error) { if dir != "" { dir = e.resolveLocalPath(dir) - files, err := collectIncludeVarsFiles(dir, depth, filesMatching) + files, err := collectIncludeVarsFiles(dir, depth, filesMatching, extensions) if err != nil { return nil, err } @@ -2176,7 +2177,31 @@ func (e *Executor) moduleIncludeVars(args map[string]any) (*TaskResult, error) { return &TaskResult{Changed: true, Msg: msg}, nil } -func collectIncludeVarsFiles(dir string, depth int, filesMatching string) ([]string, error) { +func normalizeIncludeVarsExtensions(values []string) []string { + if len(values) == 0 { + return []string{".json", ".yml", ".yaml"} + } + + extensions := make([]string, 0, len(values)) + seen := make(map[string]bool, len(values)) + for _, value := range values { + ext := lower(corexTrimSpace(value)) + if ext == "" { + continue + } + if !corexHasPrefix(ext, ".") { + ext = "." + ext + } + if seen[ext] { + continue + } + seen[ext] = true + extensions = append(extensions, ext) + } + return extensions +} + +func collectIncludeVarsFiles(dir string, depth int, filesMatching string, extensions []string) ([]string, error) { info, err := os.Stat(dir) if err != nil { return nil, coreerr.E("Executor.moduleIncludeVars", "read vars dir", err) @@ -2199,6 +2224,10 @@ func collectIncludeVarsFiles(dir string, depth int, filesMatching string) ([]str } var files []string + allowed := make(map[string]bool, len(extensions)) + for _, ext := range extensions { + allowed[ext] = true + } stack := []dirEntry{{path: dir, depth: 0}} for len(stack) > 0 { current := stack[len(stack)-1] @@ -2221,12 +2250,13 @@ func collectIncludeVarsFiles(dir string, depth int, filesMatching string) ([]str } ext := lower(filepath.Ext(entry.Name())) - if ext == ".yml" || ext == ".yaml" { - if matcher != nil && !matcher.MatchString(entry.Name()) { - continue - } - files = append(files, fullPath) + if !allowed[ext] { + continue } + if matcher != nil && !matcher.MatchString(entry.Name()) { + continue + } + files = append(files, fullPath) } } diff --git a/modules_adv_test.go b/modules_adv_test.go index bea3b3a..d27e5e6 100644 --- a/modules_adv_test.go +++ b/modules_adv_test.go @@ -908,6 +908,45 @@ func TestModulesAdv_ModuleIncludeVars_Good_LoadSingleFile(t *testing.T) { assert.Equal(t, true, nested["enabled"]) } +func TestModulesAdv_ModuleIncludeVars_Good_LoadJSONFileByDefault(t *testing.T) { + dir := t.TempDir() + varsPath := joinPath(dir, "vars.json") + require.NoError(t, writeTestFile(varsPath, []byte(`{"app_name":"demo","app_port":8080}`), 0644)) + + e := NewExecutor("/tmp") + + result, err := e.moduleIncludeVars(map[string]any{ + "file": varsPath, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.Equal(t, "demo", e.vars["app_name"]) + assert.Equal(t, 8080, e.vars["app_port"]) +} + +func TestModulesAdv_ModuleIncludeVars_Good_CustomExtensionsFilter(t *testing.T) { + dir := t.TempDir() + require.NoError(t, writeTestFile(joinPath(dir, "01-ignored.yml"), []byte("ignored_value: false\n"), 0644)) + require.NoError(t, writeTestFile(joinPath(dir, "02-selected.vars"), []byte("selected_value: included\n"), 0644)) + + e := NewExecutor("/tmp") + + result, err := e.moduleIncludeVars(map[string]any{ + "dir": dir, + "extensions": []any{"vars"}, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.Equal(t, "included", e.vars["selected_value"]) + _, hasIgnored := e.vars["ignored_value"] + assert.False(t, hasIgnored) + assert.Contains(t, result.Msg, joinPath(dir, "02-selected.vars")) + assert.NotContains(t, result.Msg, joinPath(dir, "01-ignored.yml")) +} + func TestModulesAdv_ModuleIncludeVars_Good_LoadDirectoryWithMerge(t *testing.T) { dir := t.TempDir() require.NoError(t, writeTestFile(joinPath(dir, "01-base.yml"), []byte("app_name: demo\nnested:\n a: 1\n"), 0644))