diff --git a/modules.go b/modules.go index 0d6f6be..1c349cb 100644 --- a/modules.go +++ b/modules.go @@ -2018,6 +2018,7 @@ func (e *Executor) moduleSetup(ctx context.Context, host string, client sshFacts } factMap := factsToMap(facts) + factMap = applyGatherSubsetFilter(factMap, normalizeStringList(args["gather_subset"])) filteredFactMap := filterFactsMap(factMap, normalizeStringList(args["filter"])) filteredFacts := factsFromMap(filteredFactMap) @@ -2032,6 +2033,155 @@ func (e *Executor) moduleSetup(ctx context.Context, host string, client sshFacts }, nil } +func applyGatherSubsetFilter(facts map[string]any, subsets []string) map[string]any { + if len(facts) == 0 || len(subsets) == 0 { + return facts + } + + normalized := make([]string, 0, len(subsets)) + for _, subset := range subsets { + if trimmed := lower(corexTrimSpace(subset)); trimmed != "" { + normalized = append(normalized, trimmed) + } + } + if len(normalized) == 0 { + return facts + } + + includeAll := false + excludeAll := false + excludeMin := false + positives := make([]string, 0, len(normalized)) + exclusions := make([]string, 0, len(normalized)) + for _, subset := range normalized { + if corexHasPrefix(subset, "!") { + name := corexTrimPrefix(subset, "!") + if name != "" { + exclusions = append(exclusions, name) + } + switch name { + case "all": + excludeAll = true + case "min": + excludeMin = true + } + continue + } + + positives = append(positives, subset) + switch subset { + case "all": + includeAll = true + case "min": + // handled below + } + } + + if includeAll && !excludeAll { + return facts + } + + selected := make(map[string]bool) + if len(positives) == 0 { + if !excludeAll { + for key := range facts { + selected[key] = true + } + } else if !excludeMin { + addSubsetKeys(selected, "min") + } + } else { + if !excludeMin { + addSubsetKeys(selected, "min") + } + } + + for _, subset := range positives { + addSubsetKeys(selected, subset) + } + for _, subset := range exclusions { + removeSubsetKeys(selected, subset) + } + + if len(selected) == 0 { + return map[string]any{} + } + + filtered := make(map[string]any) + for key, value := range facts { + if selected[key] { + filtered[key] = value + } + } + + return filtered +} + +func addSubsetKeys(selected map[string]bool, subset string) { + for _, key := range gatherSubsetKeys(subset) { + selected[key] = true + } +} + +func removeSubsetKeys(selected map[string]bool, subset string) { + if subset == "all" { + return + } + for _, key := range gatherSubsetKeys(subset) { + delete(selected, key) + } + delete(selected, subset) +} + +func gatherSubsetKeys(subset string) []string { + switch subset { + case "all": + return []string{ + "ansible_hostname", + "ansible_fqdn", + "ansible_os_family", + "ansible_distribution", + "ansible_distribution_version", + "ansible_architecture", + "ansible_kernel", + "ansible_memtotal_mb", + "ansible_processor_vcpus", + "ansible_default_ipv4_address", + } + case "min": + return []string{ + "ansible_hostname", + "ansible_fqdn", + "ansible_os_family", + "ansible_distribution", + "ansible_distribution_version", + "ansible_architecture", + "ansible_kernel", + } + case "hardware": + return []string{ + "ansible_architecture", + "ansible_kernel", + "ansible_memtotal_mb", + "ansible_processor_vcpus", + } + case "network": + return []string{ + "ansible_default_ipv4_address", + } + case "distribution": + return []string{ + "ansible_os_family", + "ansible_distribution", + "ansible_distribution_version", + } + case "virtual": + return nil + default: + return nil + } +} + func (e *Executor) collectFacts(ctx context.Context, client sshFactsRunner) (*Facts, error) { facts := &Facts{} diff --git a/modules_infra_test.go b/modules_infra_test.go index 13968a3..21fd61e 100644 --- a/modules_infra_test.go +++ b/modules_infra_test.go @@ -1038,6 +1038,44 @@ func TestModulesInfra_ModuleSetup_Good_FilteredFacts(t *testing.T) { assert.Equal(t, "debian", e.facts["host1"].Distribution) } +func TestModulesInfra_ModuleSetup_Good_GatherSubset(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + mock.expectCommand(`hostname -f`, "web1.example.com\n", "", 0) + mock.expectCommand(`hostname -s`, "web1\n", "", 0) + mock.expectCommand(`cat /etc/os-release`, "ID=debian\nVERSION_ID=12\n", "", 0) + mock.expectCommand(`uname -m`, "x86_64\n", "", 0) + mock.expectCommand(`uname -r`, "6.1.0\n", "", 0) + mock.expectCommand(`nproc`, "8\n", "", 0) + mock.expectCommand(`free -m`, "16384\n", "", 0) + mock.expectCommand(`hostname -I`, "10.0.0.11\n", "", 0) + + task := &Task{ + Module: "setup", + Args: map[string]any{ + "gather_subset": "!all,!min,network", + }, + } + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Data) + + facts, ok := result.Data["ansible_facts"].(map[string]any) + require.True(t, ok) + assert.Len(t, facts, 1) + assert.Equal(t, "10.0.0.11", facts["ansible_default_ipv4_address"]) + assert.NotContains(t, facts, "ansible_hostname") + assert.NotContains(t, facts, "ansible_distribution") + + require.NotNil(t, e.facts["host1"]) + assert.Equal(t, "", e.facts["host1"].Hostname) + assert.Equal(t, "10.0.0.11", e.facts["host1"].IPv4) + assert.Equal(t, "", e.templateString("{{ ansible_hostname }}", "host1", nil)) + assert.Equal(t, "10.0.0.11", e.templateString("{{ ansible_default_ipv4_address }}", "host1", nil)) +} + func TestModulesInfra_ModuleArchive_Good_CreateZipArchive(t *testing.T) { e, mock := newTestExecutorWithMock("host1")