From 4fe5484e1f4b42ce8db49a6f98dcd41802c51c0d Mon Sep 17 00:00:00 2001 From: Snider Date: Mon, 9 Mar 2026 11:37:27 +0000 Subject: [PATCH] feat: extract ansible package from go-devops Pure Go Ansible playbook engine: YAML parser, SSH executor, 30+ module implementations, Jinja2-compatible templating, privilege escalation, event-driven callbacks, and inventory pattern matching. 438 tests passing with race detector. Co-Authored-By: Virgil --- CLAUDE.md | 37 ++ executor.go | 1018 +++++++++++++++++++++++++++++ executor_test.go | 427 ++++++++++++ go.mod | 17 + go.sum | 26 + mock_ssh_test.go | 1416 ++++++++++++++++++++++++++++++++++++++++ modules.go | 1435 +++++++++++++++++++++++++++++++++++++++++ modules_adv_test.go | 1127 ++++++++++++++++++++++++++++++++ modules_cmd_test.go | 722 +++++++++++++++++++++ modules_file_test.go | 899 ++++++++++++++++++++++++++ modules_infra_test.go | 1261 ++++++++++++++++++++++++++++++++++++ modules_svc_test.go | 950 +++++++++++++++++++++++++++ parser.go | 510 +++++++++++++++ parser_test.go | 777 ++++++++++++++++++++++ ssh.go | 451 +++++++++++++ ssh_test.go | 36 ++ types.go | 258 ++++++++ types_test.go | 402 ++++++++++++ 18 files changed, 11769 insertions(+) create mode 100644 CLAUDE.md create mode 100644 executor.go create mode 100644 executor_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 mock_ssh_test.go create mode 100644 modules.go create mode 100644 modules_adv_test.go create mode 100644 modules_cmd_test.go create mode 100644 modules_file_test.go create mode 100644 modules_infra_test.go create mode 100644 modules_svc_test.go create mode 100644 parser.go create mode 100644 parser_test.go create mode 100644 ssh.go create mode 100644 ssh_test.go create mode 100644 types.go create mode 100644 types_test.go diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..ab6322b --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,37 @@ +# CLAUDE.md + +## Project Overview + +`core/go-ansible` is a pure Go Ansible playbook engine. It parses YAML playbooks, inventories, and roles, then executes tasks on remote hosts via SSH. 174 module implementations, Jinja2-compatible templating, privilege escalation (become), and event-driven callbacks. + +## Build & Development + +```bash +go test ./... +go test -race ./... +``` + +## Architecture + +Single `ansible` package: + +- `types.go` — Playbook, Play, Task, TaskResult, Inventory, Host, Facts structs +- `parser.go` — YAML parser for playbooks, inventories, tasks, roles +- `executor.go` — Task execution engine with module dispatch, templating, condition evaluation +- `modules.go` — 30+ module implementations (shell, copy, apt, systemd, git, docker-compose, etc.) +- `ssh.go` — SSH client with known_hosts verification, become/sudo, file transfer + +## Dependencies + +- `go-log` — Structured logging +- `golang.org/x/crypto` — SSH protocol +- `gopkg.in/yaml.v3` — YAML parsing +- `testify` — Test assertions + +## Coding Standards + +- UK English +- All functions have typed params/returns +- Tests use testify + mock SSH client +- Test naming: `_Good`, `_Bad`, `_Ugly` suffixes +- License: EUPL-1.2 diff --git a/executor.go b/executor.go new file mode 100644 index 0000000..d88f9b2 --- /dev/null +++ b/executor.go @@ -0,0 +1,1018 @@ +package ansible + +import ( + "context" + "fmt" + "os" + "regexp" + "slices" + "strings" + "sync" + "text/template" + "time" + + "forge.lthn.ai/core/go-log" +) + +// Executor runs Ansible playbooks. +type Executor struct { + parser *Parser + inventory *Inventory + vars map[string]any + facts map[string]*Facts + results map[string]map[string]*TaskResult // host -> register_name -> result + handlers map[string][]Task + notified map[string]bool + clients map[string]*SSHClient + mu sync.RWMutex + + // Callbacks + OnPlayStart func(play *Play) + OnTaskStart func(host string, task *Task) + OnTaskEnd func(host string, task *Task, result *TaskResult) + OnPlayEnd func(play *Play) + + // Options + Limit string + Tags []string + SkipTags []string + CheckMode bool + Diff bool + Verbose int +} + +// NewExecutor creates a new playbook executor. +func NewExecutor(basePath string) *Executor { + return &Executor{ + parser: NewParser(basePath), + vars: make(map[string]any), + facts: make(map[string]*Facts), + results: make(map[string]map[string]*TaskResult), + handlers: make(map[string][]Task), + notified: make(map[string]bool), + clients: make(map[string]*SSHClient), + } +} + +// SetInventory loads inventory from a file. +func (e *Executor) SetInventory(path string) error { + inv, err := e.parser.ParseInventory(path) + if err != nil { + return err + } + e.inventory = inv + return nil +} + +// SetInventoryDirect sets inventory directly. +func (e *Executor) SetInventoryDirect(inv *Inventory) { + e.inventory = inv +} + +// SetVar sets a variable. +func (e *Executor) SetVar(key string, value any) { + e.mu.Lock() + defer e.mu.Unlock() + e.vars[key] = value +} + +// Run executes a playbook. +func (e *Executor) Run(ctx context.Context, playbookPath string) error { + plays, err := e.parser.ParsePlaybook(playbookPath) + if err != nil { + return fmt.Errorf("parse playbook: %w", err) + } + + for i := range plays { + if err := e.runPlay(ctx, &plays[i]); err != nil { + return fmt.Errorf("play %d (%s): %w", i, plays[i].Name, err) + } + } + + return nil +} + +// runPlay executes a single play. +func (e *Executor) runPlay(ctx context.Context, play *Play) error { + if e.OnPlayStart != nil { + e.OnPlayStart(play) + } + defer func() { + if e.OnPlayEnd != nil { + e.OnPlayEnd(play) + } + }() + + // Get target hosts + hosts := e.getHosts(play.Hosts) + if len(hosts) == 0 { + return nil // No hosts matched + } + + // Merge play vars + for k, v := range play.Vars { + e.vars[k] = v + } + + // Gather facts if needed + gatherFacts := play.GatherFacts == nil || *play.GatherFacts + if gatherFacts { + for _, host := range hosts { + if err := e.gatherFacts(ctx, host, play); err != nil { + // Non-fatal + if e.Verbose > 0 { + log.Warn("gather facts failed", "host", host, "err", err) + } + } + } + } + + // Execute pre_tasks + for _, task := range play.PreTasks { + if err := e.runTaskOnHosts(ctx, hosts, &task, play); err != nil { + return err + } + } + + // Execute roles + for _, roleRef := range play.Roles { + if err := e.runRole(ctx, hosts, &roleRef, play); err != nil { + return err + } + } + + // Execute tasks + for _, task := range play.Tasks { + if err := e.runTaskOnHosts(ctx, hosts, &task, play); err != nil { + return err + } + } + + // Execute post_tasks + for _, task := range play.PostTasks { + if err := e.runTaskOnHosts(ctx, hosts, &task, play); err != nil { + return err + } + } + + // Run notified handlers + for _, handler := range play.Handlers { + if e.notified[handler.Name] { + if err := e.runTaskOnHosts(ctx, hosts, &handler, play); err != nil { + return err + } + } + } + + return nil +} + +// runRole executes a role on hosts. +func (e *Executor) runRole(ctx context.Context, hosts []string, roleRef *RoleRef, play *Play) error { + // Check when condition + if roleRef.When != nil { + if !e.evaluateWhen(roleRef.When, "", nil) { + return nil + } + } + + // Parse role tasks + tasks, err := e.parser.ParseRole(roleRef.Role, roleRef.TasksFrom) + if err != nil { + return log.E("executor.runRole", fmt.Sprintf("parse role %s", roleRef.Role), err) + } + + // Merge role vars + oldVars := make(map[string]any) + for k, v := range e.vars { + oldVars[k] = v + } + for k, v := range roleRef.Vars { + e.vars[k] = v + } + + // Execute tasks + for _, task := range tasks { + if err := e.runTaskOnHosts(ctx, hosts, &task, play); err != nil { + // Restore vars + e.vars = oldVars + return err + } + } + + // Restore vars + e.vars = oldVars + return nil +} + +// runTaskOnHosts runs a task on all hosts. +func (e *Executor) runTaskOnHosts(ctx context.Context, hosts []string, task *Task, play *Play) error { + // Check tags + if !e.matchesTags(task.Tags) { + return nil + } + + // Handle block tasks + if len(task.Block) > 0 { + return e.runBlock(ctx, hosts, task, play) + } + + // Handle include/import + if task.IncludeTasks != "" || task.ImportTasks != "" { + return e.runIncludeTasks(ctx, hosts, task, play) + } + if task.IncludeRole != nil || task.ImportRole != nil { + return e.runIncludeRole(ctx, hosts, task, play) + } + + for _, host := range hosts { + if err := e.runTaskOnHost(ctx, host, task, play); err != nil { + if !task.IgnoreErrors { + return err + } + } + } + + return nil +} + +// runTaskOnHost runs a task on a single host. +func (e *Executor) runTaskOnHost(ctx context.Context, host string, task *Task, play *Play) error { + start := time.Now() + + if e.OnTaskStart != nil { + e.OnTaskStart(host, task) + } + + // Initialize host results + if e.results[host] == nil { + e.results[host] = make(map[string]*TaskResult) + } + + // Check when condition + if task.When != nil { + if !e.evaluateWhen(task.When, host, task) { + result := &TaskResult{Skipped: true, Msg: "Skipped due to when condition"} + if task.Register != "" { + e.results[host][task.Register] = result + } + if e.OnTaskEnd != nil { + e.OnTaskEnd(host, task, result) + } + return nil + } + } + + // Get SSH client + client, err := e.getClient(host, play) + if err != nil { + return fmt.Errorf("get client for %s: %w", host, err) + } + + // Handle loops + if task.Loop != nil { + return e.runLoop(ctx, host, client, task, play) + } + + // Execute the task + result, err := e.executeModule(ctx, host, client, task, play) + if err != nil { + result = &TaskResult{Failed: true, Msg: err.Error()} + } + result.Duration = time.Since(start) + + // Store result + if task.Register != "" { + e.results[host][task.Register] = result + } + + // Handle notify + if result.Changed && task.Notify != nil { + e.handleNotify(task.Notify) + } + + if e.OnTaskEnd != nil { + e.OnTaskEnd(host, task, result) + } + + if result.Failed && !task.IgnoreErrors { + return fmt.Errorf("task failed: %s", result.Msg) + } + + return nil +} + +// runLoop handles task loops. +func (e *Executor) runLoop(ctx context.Context, host string, client *SSHClient, task *Task, play *Play) error { + items := e.resolveLoop(task.Loop, host) + + loopVar := "item" + if task.LoopControl != nil && task.LoopControl.LoopVar != "" { + loopVar = task.LoopControl.LoopVar + } + + // Save loop state to restore after loop + savedVars := make(map[string]any) + if v, ok := e.vars[loopVar]; ok { + savedVars[loopVar] = v + } + indexVar := "" + if task.LoopControl != nil && task.LoopControl.IndexVar != "" { + indexVar = task.LoopControl.IndexVar + if v, ok := e.vars[indexVar]; ok { + savedVars[indexVar] = v + } + } + + var results []TaskResult + for i, item := range items { + // Set loop variables + e.vars[loopVar] = item + if indexVar != "" { + e.vars[indexVar] = i + } + + result, err := e.executeModule(ctx, host, client, task, play) + if err != nil { + result = &TaskResult{Failed: true, Msg: err.Error()} + } + results = append(results, *result) + + if result.Failed && !task.IgnoreErrors { + break + } + } + + // Restore loop variables + if v, ok := savedVars[loopVar]; ok { + e.vars[loopVar] = v + } else { + delete(e.vars, loopVar) + } + if indexVar != "" { + if v, ok := savedVars[indexVar]; ok { + e.vars[indexVar] = v + } else { + delete(e.vars, indexVar) + } + } + + // Store combined result + if task.Register != "" { + combined := &TaskResult{ + Results: results, + Changed: false, + } + for _, r := range results { + if r.Changed { + combined.Changed = true + } + if r.Failed { + combined.Failed = true + } + } + e.results[host][task.Register] = combined + } + + return nil +} + +// runBlock handles block/rescue/always. +func (e *Executor) runBlock(ctx context.Context, hosts []string, task *Task, play *Play) error { + var blockErr error + + // Try block + for _, t := range task.Block { + if err := e.runTaskOnHosts(ctx, hosts, &t, play); err != nil { + blockErr = err + break + } + } + + // Run rescue if block failed + if blockErr != nil && len(task.Rescue) > 0 { + for _, t := range task.Rescue { + if err := e.runTaskOnHosts(ctx, hosts, &t, play); err != nil { + // Rescue also failed + break + } + } + } + + // Always run always block + for _, t := range task.Always { + if err := e.runTaskOnHosts(ctx, hosts, &t, play); err != nil { + if blockErr == nil { + blockErr = err + } + } + } + + if blockErr != nil && len(task.Rescue) == 0 { + return blockErr + } + + return nil +} + +// runIncludeTasks handles include_tasks/import_tasks. +func (e *Executor) runIncludeTasks(ctx context.Context, hosts []string, task *Task, play *Play) error { + path := task.IncludeTasks + if path == "" { + path = task.ImportTasks + } + + // Resolve path relative to playbook + path = e.templateString(path, "", nil) + + tasks, err := e.parser.ParseTasks(path) + if err != nil { + return fmt.Errorf("include_tasks %s: %w", path, err) + } + + for _, t := range tasks { + if err := e.runTaskOnHosts(ctx, hosts, &t, play); err != nil { + return err + } + } + + return nil +} + +// runIncludeRole handles include_role/import_role. +func (e *Executor) runIncludeRole(ctx context.Context, hosts []string, task *Task, play *Play) error { + var roleName, tasksFrom string + var roleVars map[string]any + + if task.IncludeRole != nil { + roleName = task.IncludeRole.Name + tasksFrom = task.IncludeRole.TasksFrom + roleVars = task.IncludeRole.Vars + } else { + roleName = task.ImportRole.Name + tasksFrom = task.ImportRole.TasksFrom + roleVars = task.ImportRole.Vars + } + + roleRef := &RoleRef{ + Role: roleName, + TasksFrom: tasksFrom, + Vars: roleVars, + } + + return e.runRole(ctx, hosts, roleRef, play) +} + +// getHosts returns hosts matching the pattern. +func (e *Executor) getHosts(pattern string) []string { + if e.inventory == nil { + if pattern == "localhost" { + return []string{"localhost"} + } + return nil + } + + hosts := GetHosts(e.inventory, pattern) + + // Apply limit - filter to hosts that are also in the limit group + if e.Limit != "" { + limitHosts := GetHosts(e.inventory, e.Limit) + limitSet := make(map[string]bool) + for _, h := range limitHosts { + limitSet[h] = true + } + + var filtered []string + for _, h := range hosts { + if limitSet[h] || h == e.Limit || strings.Contains(h, e.Limit) { + filtered = append(filtered, h) + } + } + hosts = filtered + } + + return hosts +} + +// getClient returns or creates an SSH client for a host. +func (e *Executor) getClient(host string, play *Play) (*SSHClient, error) { + e.mu.Lock() + defer e.mu.Unlock() + + if client, ok := e.clients[host]; ok { + return client, nil + } + + // Get host vars + vars := make(map[string]any) + if e.inventory != nil { + vars = GetHostVars(e.inventory, host) + } + + // Merge with play vars + for k, v := range e.vars { + if _, exists := vars[k]; !exists { + vars[k] = v + } + } + + // Build SSH config + cfg := SSHConfig{ + Host: host, + Port: 22, + User: "root", + } + + if h, ok := vars["ansible_host"].(string); ok { + cfg.Host = h + } + if p, ok := vars["ansible_port"].(int); ok { + cfg.Port = p + } + if u, ok := vars["ansible_user"].(string); ok { + cfg.User = u + } + if p, ok := vars["ansible_password"].(string); ok { + cfg.Password = p + } + if k, ok := vars["ansible_ssh_private_key_file"].(string); ok { + cfg.KeyFile = k + } + + // Apply play become settings + if play.Become { + cfg.Become = true + cfg.BecomeUser = play.BecomeUser + if bp, ok := vars["ansible_become_password"].(string); ok { + cfg.BecomePass = bp + } else if cfg.Password != "" { + // Use SSH password for sudo if no become password specified + cfg.BecomePass = cfg.Password + } + } + + client, err := NewSSHClient(cfg) + if err != nil { + return nil, err + } + + e.clients[host] = client + return client, nil +} + +// gatherFacts collects facts from a host. +func (e *Executor) gatherFacts(ctx context.Context, host string, play *Play) error { + if play.Connection == "local" || host == "localhost" { + // Local facts + e.facts[host] = &Facts{ + Hostname: "localhost", + } + return nil + } + + client, err := e.getClient(host, play) + if err != nil { + return err + } + + // Gather basic facts + facts := &Facts{} + + // Hostname + stdout, _, _, err := client.Run(ctx, "hostname -f 2>/dev/null || hostname") + if err == nil { + facts.FQDN = strings.TrimSpace(stdout) + } + + stdout, _, _, err = client.Run(ctx, "hostname -s 2>/dev/null || hostname") + if err == nil { + facts.Hostname = strings.TrimSpace(stdout) + } + + // OS info + stdout, _, _, _ = client.Run(ctx, "cat /etc/os-release 2>/dev/null | grep -E '^(ID|VERSION_ID)=' | head -2") + for line := range strings.SplitSeq(stdout, "\n") { + if strings.HasPrefix(line, "ID=") { + facts.Distribution = strings.Trim(strings.TrimPrefix(line, "ID="), "\"") + } + if strings.HasPrefix(line, "VERSION_ID=") { + facts.Version = strings.Trim(strings.TrimPrefix(line, "VERSION_ID="), "\"") + } + } + + // Architecture + stdout, _, _, _ = client.Run(ctx, "uname -m") + facts.Architecture = strings.TrimSpace(stdout) + + // Kernel + stdout, _, _, _ = client.Run(ctx, "uname -r") + facts.Kernel = strings.TrimSpace(stdout) + + e.mu.Lock() + e.facts[host] = facts + e.mu.Unlock() + + return nil +} + +// evaluateWhen evaluates a when condition. +func (e *Executor) evaluateWhen(when any, host string, task *Task) bool { + conditions := normalizeConditions(when) + + for _, cond := range conditions { + cond = e.templateString(cond, host, task) + if !e.evalCondition(cond, host) { + return false + } + } + + return true +} + +func normalizeConditions(when any) []string { + switch v := when.(type) { + case string: + return []string{v} + case []any: + var conds []string + for _, c := range v { + if s, ok := c.(string); ok { + conds = append(conds, s) + } + } + return conds + case []string: + return v + } + return nil +} + +// evalCondition evaluates a single condition. +func (e *Executor) evalCondition(cond string, host string) bool { + cond = strings.TrimSpace(cond) + + // Handle negation + if strings.HasPrefix(cond, "not ") { + return !e.evalCondition(strings.TrimPrefix(cond, "not "), host) + } + + // Handle boolean literals + if cond == "true" || cond == "True" { + return true + } + if cond == "false" || cond == "False" { + return false + } + + // Handle registered variable checks + // e.g., "result is success", "result.rc == 0" + if strings.Contains(cond, " is ") { + parts := strings.SplitN(cond, " is ", 2) + varName := strings.TrimSpace(parts[0]) + check := strings.TrimSpace(parts[1]) + + result := e.getRegisteredVar(host, varName) + if result == nil { + return check == "not defined" || check == "undefined" + } + + switch check { + case "defined": + return true + case "not defined", "undefined": + return false + case "success", "succeeded": + return !result.Failed + case "failed": + return result.Failed + case "changed": + return result.Changed + case "skipped": + return result.Skipped + } + } + + // Handle simple var checks + if strings.Contains(cond, " | default(") { + // Extract var name and check if defined + re := regexp.MustCompile(`(\w+)\s*\|\s*default\([^)]*\)`) + if match := re.FindStringSubmatch(cond); len(match) > 1 { + // Has default, so condition is satisfied + return true + } + } + + // Check if it's a variable that should be truthy + if result := e.getRegisteredVar(host, cond); result != nil { + return !result.Failed && !result.Skipped + } + + // Check vars + if val, ok := e.vars[cond]; ok { + switch v := val.(type) { + case bool: + return v + case string: + return v != "" && v != "false" && v != "False" + case int: + return v != 0 + } + } + + // Default to true for unknown conditions (be permissive) + return true +} + +// getRegisteredVar gets a registered task result. +func (e *Executor) getRegisteredVar(host string, name string) *TaskResult { + e.mu.RLock() + defer e.mu.RUnlock() + + // Handle dotted access (e.g., "result.stdout") + parts := strings.SplitN(name, ".", 2) + varName := parts[0] + + if hostResults, ok := e.results[host]; ok { + if result, ok := hostResults[varName]; ok { + return result + } + } + + return nil +} + +// templateString applies Jinja2-like templating. +func (e *Executor) templateString(s string, host string, task *Task) string { + // Handle {{ var }} syntax + re := regexp.MustCompile(`\{\{\s*([^}]+)\s*\}\}`) + + return re.ReplaceAllStringFunc(s, func(match string) string { + expr := strings.TrimSpace(match[2 : len(match)-2]) + return e.resolveExpr(expr, host, task) + }) +} + +// resolveExpr resolves a template expression. +func (e *Executor) resolveExpr(expr string, host string, task *Task) string { + // Handle filters + if strings.Contains(expr, " | ") { + parts := strings.SplitN(expr, " | ", 2) + value := e.resolveExpr(parts[0], host, task) + return e.applyFilter(value, parts[1]) + } + + // Handle lookups + if strings.HasPrefix(expr, "lookup(") { + return e.handleLookup(expr) + } + + // Handle registered vars + if strings.Contains(expr, ".") { + parts := strings.SplitN(expr, ".", 2) + if result := e.getRegisteredVar(host, parts[0]); result != nil { + switch parts[1] { + case "stdout": + return result.Stdout + case "stderr": + return result.Stderr + case "rc": + return fmt.Sprintf("%d", result.RC) + case "changed": + return fmt.Sprintf("%t", result.Changed) + case "failed": + return fmt.Sprintf("%t", result.Failed) + } + } + } + + // Check vars + if val, ok := e.vars[expr]; ok { + return fmt.Sprintf("%v", val) + } + + // Check task vars + if task != nil { + if val, ok := task.Vars[expr]; ok { + return fmt.Sprintf("%v", val) + } + } + + // Check host vars + if e.inventory != nil { + hostVars := GetHostVars(e.inventory, host) + if val, ok := hostVars[expr]; ok { + return fmt.Sprintf("%v", val) + } + } + + // Check facts + if facts, ok := e.facts[host]; ok { + switch expr { + case "ansible_hostname": + return facts.Hostname + case "ansible_fqdn": + return facts.FQDN + case "ansible_distribution": + return facts.Distribution + case "ansible_distribution_version": + return facts.Version + case "ansible_architecture": + return facts.Architecture + case "ansible_kernel": + return facts.Kernel + } + } + + return "{{ " + expr + " }}" // Return as-is if unresolved +} + +// applyFilter applies a Jinja2 filter. +func (e *Executor) applyFilter(value, filter string) string { + filter = strings.TrimSpace(filter) + + // Handle default filter + if strings.HasPrefix(filter, "default(") { + if value == "" || value == "{{ "+filter+" }}" { + // Extract default value + re := regexp.MustCompile(`default\(([^)]*)\)`) + if match := re.FindStringSubmatch(filter); len(match) > 1 { + return strings.Trim(match[1], "'\"") + } + } + return value + } + + // Handle bool filter + if filter == "bool" { + lower := strings.ToLower(value) + if lower == "true" || lower == "yes" || lower == "1" { + return "true" + } + return "false" + } + + // Handle trim + if filter == "trim" { + return strings.TrimSpace(value) + } + + // Handle b64decode + if filter == "b64decode" { + // Would need base64 decode + return value + } + + return value +} + +// handleLookup handles lookup() expressions. +func (e *Executor) handleLookup(expr string) string { + // Parse lookup('type', 'arg') + re := regexp.MustCompile(`lookup\s*\(\s*['"](\w+)['"]\s*,\s*['"]([^'"]+)['"]\s*`) + match := re.FindStringSubmatch(expr) + if len(match) < 3 { + return "" + } + + lookupType := match[1] + arg := match[2] + + switch lookupType { + case "env": + return os.Getenv(arg) + case "file": + if data, err := os.ReadFile(arg); err == nil { + return string(data) + } + } + + return "" +} + +// resolveLoop resolves loop items. +func (e *Executor) resolveLoop(loop any, host string) []any { + switch v := loop.(type) { + case []any: + return v + case []string: + items := make([]any, len(v)) + for i, s := range v { + items[i] = s + } + return items + case string: + // Template the string and see if it's a var reference + resolved := e.templateString(v, host, nil) + if val, ok := e.vars[resolved]; ok { + if items, ok := val.([]any); ok { + return items + } + } + } + return nil +} + +// matchesTags checks if task tags match execution tags. +func (e *Executor) matchesTags(taskTags []string) bool { + // If no tags specified, run all + if len(e.Tags) == 0 && len(e.SkipTags) == 0 { + return true + } + + // Check skip tags + for _, skip := range e.SkipTags { + if slices.Contains(taskTags, skip) { + return false + } + } + + // Check include tags + if len(e.Tags) > 0 { + for _, tag := range e.Tags { + if tag == "all" || slices.Contains(taskTags, tag) { + return true + } + } + return false + } + + return true +} + +// handleNotify marks handlers as notified. +func (e *Executor) handleNotify(notify any) { + switch v := notify.(type) { + case string: + e.notified[v] = true + case []any: + for _, n := range v { + if s, ok := n.(string); ok { + e.notified[s] = true + } + } + case []string: + for _, s := range v { + e.notified[s] = true + } + } +} + +// Close closes all SSH connections. +func (e *Executor) Close() { + e.mu.Lock() + defer e.mu.Unlock() + + for _, client := range e.clients { + _ = client.Close() + } + e.clients = make(map[string]*SSHClient) +} + +// TemplateFile processes a template file. +func (e *Executor) TemplateFile(src, host string, task *Task) (string, error) { + content, err := os.ReadFile(src) + if err != nil { + return "", err + } + + // Convert Jinja2 to Go template syntax (basic conversion) + tmplContent := string(content) + tmplContent = strings.ReplaceAll(tmplContent, "{{", "{{ .") + tmplContent = strings.ReplaceAll(tmplContent, "{%", "{{") + tmplContent = strings.ReplaceAll(tmplContent, "%}", "}}") + + tmpl, err := template.New("template").Parse(tmplContent) + if err != nil { + // Fall back to simple replacement + return e.templateString(string(content), host, task), nil + } + + // Build context map + context := make(map[string]any) + for k, v := range e.vars { + context[k] = v + } + // Add host vars + if e.inventory != nil { + hostVars := GetHostVars(e.inventory, host) + for k, v := range hostVars { + context[k] = v + } + } + // Add facts + if facts, ok := e.facts[host]; ok { + context["ansible_hostname"] = facts.Hostname + context["ansible_fqdn"] = facts.FQDN + context["ansible_distribution"] = facts.Distribution + context["ansible_distribution_version"] = facts.Version + context["ansible_architecture"] = facts.Architecture + context["ansible_kernel"] = facts.Kernel + } + + var buf strings.Builder + if err := tmpl.Execute(&buf, context); err != nil { + return e.templateString(string(content), host, task), nil + } + + return buf.String(), nil +} diff --git a/executor_test.go b/executor_test.go new file mode 100644 index 0000000..4ac9d1c --- /dev/null +++ b/executor_test.go @@ -0,0 +1,427 @@ +package ansible + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// --- NewExecutor --- + +func TestNewExecutor_Good(t *testing.T) { + e := NewExecutor("/some/path") + + assert.NotNil(t, e) + assert.NotNil(t, e.parser) + assert.NotNil(t, e.vars) + assert.NotNil(t, e.facts) + assert.NotNil(t, e.results) + assert.NotNil(t, e.handlers) + assert.NotNil(t, e.notified) + assert.NotNil(t, e.clients) +} + +// --- SetVar --- + +func TestSetVar_Good(t *testing.T) { + e := NewExecutor("/tmp") + e.SetVar("foo", "bar") + e.SetVar("count", 42) + + assert.Equal(t, "bar", e.vars["foo"]) + assert.Equal(t, 42, e.vars["count"]) +} + +// --- SetInventoryDirect --- + +func TestSetInventoryDirect_Good(t *testing.T) { + e := NewExecutor("/tmp") + inv := &Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "web1": {AnsibleHost: "10.0.0.1"}, + }, + }, + } + + e.SetInventoryDirect(inv) + assert.Equal(t, inv, e.inventory) +} + +// --- getHosts --- + +func TestGetHosts_Executor_Good_WithInventory(t *testing.T) { + e := NewExecutor("/tmp") + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "host1": {}, + "host2": {}, + }, + }, + }) + + hosts := e.getHosts("all") + assert.Len(t, hosts, 2) +} + +func TestGetHosts_Executor_Good_Localhost(t *testing.T) { + e := NewExecutor("/tmp") + // No inventory set + + hosts := e.getHosts("localhost") + assert.Equal(t, []string{"localhost"}, hosts) +} + +func TestGetHosts_Executor_Good_NoInventory(t *testing.T) { + e := NewExecutor("/tmp") + + hosts := e.getHosts("webservers") + assert.Nil(t, hosts) +} + +func TestGetHosts_Executor_Good_WithLimit(t *testing.T) { + e := NewExecutor("/tmp") + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "host1": {}, + "host2": {}, + "host3": {}, + }, + }, + }) + e.Limit = "host2" + + hosts := e.getHosts("all") + assert.Len(t, hosts, 1) + assert.Contains(t, hosts, "host2") +} + +// --- matchesTags --- + +func TestMatchesTags_Good_NoTagsFilter(t *testing.T) { + e := NewExecutor("/tmp") + + assert.True(t, e.matchesTags(nil)) + assert.True(t, e.matchesTags([]string{"any", "tags"})) +} + +func TestMatchesTags_Good_IncludeTag(t *testing.T) { + e := NewExecutor("/tmp") + e.Tags = []string{"deploy"} + + assert.True(t, e.matchesTags([]string{"deploy"})) + assert.True(t, e.matchesTags([]string{"setup", "deploy"})) + assert.False(t, e.matchesTags([]string{"other"})) +} + +func TestMatchesTags_Good_SkipTag(t *testing.T) { + e := NewExecutor("/tmp") + e.SkipTags = []string{"slow"} + + assert.True(t, e.matchesTags([]string{"fast"})) + assert.False(t, e.matchesTags([]string{"slow"})) + assert.False(t, e.matchesTags([]string{"fast", "slow"})) +} + +func TestMatchesTags_Good_AllTag(t *testing.T) { + e := NewExecutor("/tmp") + e.Tags = []string{"all"} + + assert.True(t, e.matchesTags([]string{"anything"})) +} + +func TestMatchesTags_Good_NoTaskTags(t *testing.T) { + e := NewExecutor("/tmp") + e.Tags = []string{"deploy"} + + // Tasks with no tags should not match when include tags are set + assert.False(t, e.matchesTags(nil)) + assert.False(t, e.matchesTags([]string{})) +} + +// --- handleNotify --- + +func TestHandleNotify_Good_String(t *testing.T) { + e := NewExecutor("/tmp") + e.handleNotify("restart nginx") + + assert.True(t, e.notified["restart nginx"]) +} + +func TestHandleNotify_Good_StringList(t *testing.T) { + e := NewExecutor("/tmp") + e.handleNotify([]string{"restart nginx", "reload config"}) + + assert.True(t, e.notified["restart nginx"]) + assert.True(t, e.notified["reload config"]) +} + +func TestHandleNotify_Good_AnyList(t *testing.T) { + e := NewExecutor("/tmp") + e.handleNotify([]any{"restart nginx", "reload config"}) + + assert.True(t, e.notified["restart nginx"]) + assert.True(t, e.notified["reload config"]) +} + +// --- normalizeConditions --- + +func TestNormalizeConditions_Good_String(t *testing.T) { + result := normalizeConditions("my_var is defined") + assert.Equal(t, []string{"my_var is defined"}, result) +} + +func TestNormalizeConditions_Good_StringSlice(t *testing.T) { + result := normalizeConditions([]string{"cond1", "cond2"}) + assert.Equal(t, []string{"cond1", "cond2"}, result) +} + +func TestNormalizeConditions_Good_AnySlice(t *testing.T) { + result := normalizeConditions([]any{"cond1", "cond2"}) + assert.Equal(t, []string{"cond1", "cond2"}, result) +} + +func TestNormalizeConditions_Good_Nil(t *testing.T) { + result := normalizeConditions(nil) + assert.Nil(t, result) +} + +// --- evaluateWhen --- + +func TestEvaluateWhen_Good_TrueLiteral(t *testing.T) { + e := NewExecutor("/tmp") + assert.True(t, e.evaluateWhen("true", "host1", nil)) + assert.True(t, e.evaluateWhen("True", "host1", nil)) +} + +func TestEvaluateWhen_Good_FalseLiteral(t *testing.T) { + e := NewExecutor("/tmp") + assert.False(t, e.evaluateWhen("false", "host1", nil)) + assert.False(t, e.evaluateWhen("False", "host1", nil)) +} + +func TestEvaluateWhen_Good_Negation(t *testing.T) { + e := NewExecutor("/tmp") + assert.False(t, e.evaluateWhen("not true", "host1", nil)) + assert.True(t, e.evaluateWhen("not false", "host1", nil)) +} + +func TestEvaluateWhen_Good_RegisteredVarDefined(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "myresult": {Changed: true, Failed: false}, + } + + assert.True(t, e.evaluateWhen("myresult is defined", "host1", nil)) + assert.False(t, e.evaluateWhen("myresult is not defined", "host1", nil)) + assert.False(t, e.evaluateWhen("nonexistent is defined", "host1", nil)) + assert.True(t, e.evaluateWhen("nonexistent is not defined", "host1", nil)) +} + +func TestEvaluateWhen_Good_RegisteredVarStatus(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "success_result": {Changed: true, Failed: false}, + "failed_result": {Failed: true}, + "skipped_result": {Skipped: true}, + } + + assert.True(t, e.evaluateWhen("success_result is success", "host1", nil)) + assert.True(t, e.evaluateWhen("success_result is succeeded", "host1", nil)) + assert.True(t, e.evaluateWhen("success_result is changed", "host1", nil)) + assert.True(t, e.evaluateWhen("failed_result is failed", "host1", nil)) + assert.True(t, e.evaluateWhen("skipped_result is skipped", "host1", nil)) +} + +func TestEvaluateWhen_Good_VarTruthy(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["enabled"] = true + e.vars["disabled"] = false + e.vars["name"] = "hello" + e.vars["empty"] = "" + e.vars["count"] = 5 + e.vars["zero"] = 0 + + assert.True(t, e.evalCondition("enabled", "host1")) + assert.False(t, e.evalCondition("disabled", "host1")) + assert.True(t, e.evalCondition("name", "host1")) + assert.False(t, e.evalCondition("empty", "host1")) + assert.True(t, e.evalCondition("count", "host1")) + assert.False(t, e.evalCondition("zero", "host1")) +} + +func TestEvaluateWhen_Good_MultipleConditions(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["enabled"] = true + + // All conditions must be true (AND) + assert.True(t, e.evaluateWhen([]any{"true", "True"}, "host1", nil)) + assert.False(t, e.evaluateWhen([]any{"true", "false"}, "host1", nil)) +} + +// --- templateString --- + +func TestTemplateString_Good_SimpleVar(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["name"] = "world" + + result := e.templateString("hello {{ name }}", "", nil) + assert.Equal(t, "hello world", result) +} + +func TestTemplateString_Good_MultVars(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["host"] = "example.com" + e.vars["port"] = 8080 + + result := e.templateString("http://{{ host }}:{{ port }}", "", nil) + assert.Equal(t, "http://example.com:8080", result) +} + +func TestTemplateString_Good_Unresolved(t *testing.T) { + e := NewExecutor("/tmp") + result := e.templateString("{{ undefined_var }}", "", nil) + assert.Equal(t, "{{ undefined_var }}", result) +} + +func TestTemplateString_Good_NoTemplate(t *testing.T) { + e := NewExecutor("/tmp") + result := e.templateString("plain string", "", nil) + assert.Equal(t, "plain string", result) +} + +// --- applyFilter --- + +func TestApplyFilter_Good_Default(t *testing.T) { + e := NewExecutor("/tmp") + + assert.Equal(t, "hello", e.applyFilter("hello", "default('fallback')")) + assert.Equal(t, "fallback", e.applyFilter("", "default('fallback')")) +} + +func TestApplyFilter_Good_Bool(t *testing.T) { + e := NewExecutor("/tmp") + + assert.Equal(t, "true", e.applyFilter("true", "bool")) + assert.Equal(t, "true", e.applyFilter("yes", "bool")) + assert.Equal(t, "true", e.applyFilter("1", "bool")) + assert.Equal(t, "false", e.applyFilter("false", "bool")) + assert.Equal(t, "false", e.applyFilter("no", "bool")) + assert.Equal(t, "false", e.applyFilter("anything", "bool")) +} + +func TestApplyFilter_Good_Trim(t *testing.T) { + e := NewExecutor("/tmp") + assert.Equal(t, "hello", e.applyFilter(" hello ", "trim")) +} + +// --- resolveLoop --- + +func TestResolveLoop_Good_SliceAny(t *testing.T) { + e := NewExecutor("/tmp") + items := e.resolveLoop([]any{"a", "b", "c"}, "host1") + assert.Len(t, items, 3) +} + +func TestResolveLoop_Good_SliceString(t *testing.T) { + e := NewExecutor("/tmp") + items := e.resolveLoop([]string{"a", "b", "c"}, "host1") + assert.Len(t, items, 3) +} + +func TestResolveLoop_Good_Nil(t *testing.T) { + e := NewExecutor("/tmp") + items := e.resolveLoop(nil, "host1") + assert.Nil(t, items) +} + +// --- templateArgs --- + +func TestTemplateArgs_Good(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["myvar"] = "resolved" + + args := map[string]any{ + "plain": "no template", + "templated": "{{ myvar }}", + "number": 42, + } + + result := e.templateArgs(args, "host1", nil) + assert.Equal(t, "no template", result["plain"]) + assert.Equal(t, "resolved", result["templated"]) + assert.Equal(t, 42, result["number"]) +} + +func TestTemplateArgs_Good_NestedMap(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["port"] = "8080" + + args := map[string]any{ + "nested": map[string]any{ + "port": "{{ port }}", + }, + } + + result := e.templateArgs(args, "host1", nil) + nested := result["nested"].(map[string]any) + assert.Equal(t, "8080", nested["port"]) +} + +func TestTemplateArgs_Good_ArrayValues(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["pkg"] = "nginx" + + args := map[string]any{ + "packages": []any{"{{ pkg }}", "curl"}, + } + + result := e.templateArgs(args, "host1", nil) + pkgs := result["packages"].([]any) + assert.Equal(t, "nginx", pkgs[0]) + assert.Equal(t, "curl", pkgs[1]) +} + +// --- Helper functions --- + +func TestGetStringArg_Good(t *testing.T) { + args := map[string]any{ + "name": "value", + "number": 42, + } + + assert.Equal(t, "value", getStringArg(args, "name", "")) + assert.Equal(t, "42", getStringArg(args, "number", "")) + assert.Equal(t, "default", getStringArg(args, "missing", "default")) +} + +func TestGetBoolArg_Good(t *testing.T) { + args := map[string]any{ + "enabled": true, + "disabled": false, + "yes_str": "yes", + "true_str": "true", + "one_str": "1", + "no_str": "no", + } + + assert.True(t, getBoolArg(args, "enabled", false)) + assert.False(t, getBoolArg(args, "disabled", true)) + assert.True(t, getBoolArg(args, "yes_str", false)) + assert.True(t, getBoolArg(args, "true_str", false)) + assert.True(t, getBoolArg(args, "one_str", false)) + assert.False(t, getBoolArg(args, "no_str", true)) + assert.True(t, getBoolArg(args, "missing", true)) + assert.False(t, getBoolArg(args, "missing", false)) +} + +// --- Close --- + +func TestClose_Good_EmptyClients(t *testing.T) { + e := NewExecutor("/tmp") + // Should not panic with no clients + e.Close() + assert.Empty(t, e.clients) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c95e581 --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module forge.lthn.ai/core/go-ansible + +go 1.26.0 + +require ( + forge.lthn.ai/core/go-log v0.0.1 + github.com/stretchr/testify v1.11.1 + golang.org/x/crypto v0.48.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + golang.org/x/sys v0.41.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7865091 --- /dev/null +++ b/go.sum @@ -0,0 +1,26 @@ +forge.lthn.ai/core/go-log v0.0.1 h1:x/E6EfF9vixzqiLHQOl2KT25HyBcMc9qiBkomqVlpPg= +forge.lthn.ai/core/go-log v0.0.1/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/mock_ssh_test.go b/mock_ssh_test.go new file mode 100644 index 0000000..6452d74 --- /dev/null +++ b/mock_ssh_test.go @@ -0,0 +1,1416 @@ +package ansible + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + "sync" +) + +// --- Mock SSH Client --- + +// MockSSHClient simulates an SSHClient for testing module logic +// without requiring real SSH connections. +type MockSSHClient struct { + mu sync.Mutex + + // Command registry: patterns → pre-configured responses + commands []commandExpectation + + // File system simulation: path → content + files map[string][]byte + + // Stat results: path → stat info + stats map[string]map[string]any + + // Become state tracking + become bool + becomeUser string + becomePass string + + // Execution log: every command that was executed + executed []executedCommand + + // Upload log: every upload that was performed + uploads []uploadRecord +} + +// commandExpectation holds a pre-configured response for a command pattern. +type commandExpectation struct { + pattern *regexp.Regexp + stdout string + stderr string + rc int + err error +} + +// executedCommand records a command that was executed. +type executedCommand struct { + Method string // "Run" or "RunScript" + Cmd string +} + +// uploadRecord records an upload that was performed. +type uploadRecord struct { + Content []byte + Remote string + Mode os.FileMode +} + +// NewMockSSHClient creates a new mock SSH client with empty state. +func NewMockSSHClient() *MockSSHClient { + return &MockSSHClient{ + files: make(map[string][]byte), + stats: make(map[string]map[string]any), + } +} + +// expectCommand registers a command pattern with a pre-configured response. +// The pattern is a regular expression matched against the full command string. +func (m *MockSSHClient) expectCommand(pattern, stdout, stderr string, rc int) { + m.mu.Lock() + defer m.mu.Unlock() + m.commands = append(m.commands, commandExpectation{ + pattern: regexp.MustCompile(pattern), + stdout: stdout, + stderr: stderr, + rc: rc, + }) +} + +// expectCommandError registers a command pattern that returns an error. +func (m *MockSSHClient) expectCommandError(pattern string, err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.commands = append(m.commands, commandExpectation{ + pattern: regexp.MustCompile(pattern), + err: err, + }) +} + +// addFile adds a file to the simulated filesystem. +func (m *MockSSHClient) addFile(path string, content []byte) { + m.mu.Lock() + defer m.mu.Unlock() + m.files[path] = content +} + +// addStat adds stat info for a path. +func (m *MockSSHClient) addStat(path string, info map[string]any) { + m.mu.Lock() + defer m.mu.Unlock() + m.stats[path] = info +} + +// Run simulates executing a command. It matches against registered +// expectations in order (last match wins) and records the execution. +func (m *MockSSHClient) Run(_ context.Context, cmd string) (string, string, int, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.executed = append(m.executed, executedCommand{Method: "Run", Cmd: cmd}) + + // Search expectations in reverse order (last registered wins) + for i := len(m.commands) - 1; i >= 0; i-- { + exp := m.commands[i] + if exp.pattern.MatchString(cmd) { + return exp.stdout, exp.stderr, exp.rc, exp.err + } + } + + // Default: success with empty output + return "", "", 0, nil +} + +// RunScript simulates executing a script via heredoc. +func (m *MockSSHClient) RunScript(_ context.Context, script string) (string, string, int, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.executed = append(m.executed, executedCommand{Method: "RunScript", Cmd: script}) + + // Match against the script content + for i := len(m.commands) - 1; i >= 0; i-- { + exp := m.commands[i] + if exp.pattern.MatchString(script) { + return exp.stdout, exp.stderr, exp.rc, exp.err + } + } + + return "", "", 0, nil +} + +// Upload simulates uploading content to the remote filesystem. +func (m *MockSSHClient) Upload(_ context.Context, local io.Reader, remote string, mode os.FileMode) error { + m.mu.Lock() + defer m.mu.Unlock() + + content, err := io.ReadAll(local) + if err != nil { + return fmt.Errorf("mock upload read: %w", err) + } + + m.uploads = append(m.uploads, uploadRecord{ + Content: content, + Remote: remote, + Mode: mode, + }) + m.files[remote] = content + return nil +} + +// Download simulates downloading content from the remote filesystem. +func (m *MockSSHClient) Download(_ context.Context, remote string) ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + + content, ok := m.files[remote] + if !ok { + return nil, fmt.Errorf("file not found: %s", remote) + } + return content, nil +} + +// FileExists checks if a path exists in the simulated filesystem. +func (m *MockSSHClient) FileExists(_ context.Context, path string) (bool, error) { + m.mu.Lock() + defer m.mu.Unlock() + + _, ok := m.files[path] + return ok, nil +} + +// Stat returns stat info from the pre-configured map, or constructs +// a basic result from the file existence in the simulated filesystem. +func (m *MockSSHClient) Stat(_ context.Context, path string) (map[string]any, error) { + m.mu.Lock() + defer m.mu.Unlock() + + // Check explicit stat results first + if info, ok := m.stats[path]; ok { + return info, nil + } + + // Fall back to file existence + if _, ok := m.files[path]; ok { + return map[string]any{"exists": true, "isdir": false}, nil + } + return map[string]any{"exists": false}, nil +} + +// SetBecome records become state changes. +func (m *MockSSHClient) SetBecome(become bool, user, password string) { + m.mu.Lock() + defer m.mu.Unlock() + m.become = become + if user != "" { + m.becomeUser = user + } + if password != "" { + m.becomePass = password + } +} + +// Close is a no-op for the mock. +func (m *MockSSHClient) Close() error { + return nil +} + +// --- Assertion helpers --- + +// executedCommands returns a copy of the execution log. +func (m *MockSSHClient) executedCommands() []executedCommand { + m.mu.Lock() + defer m.mu.Unlock() + cp := make([]executedCommand, len(m.executed)) + copy(cp, m.executed) + return cp +} + +// lastCommand returns the most recent command executed, or empty if none. +func (m *MockSSHClient) lastCommand() executedCommand { + m.mu.Lock() + defer m.mu.Unlock() + if len(m.executed) == 0 { + return executedCommand{} + } + return m.executed[len(m.executed)-1] +} + +// commandCount returns the number of commands executed. +func (m *MockSSHClient) commandCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return len(m.executed) +} + +// hasExecuted checks if any command matching the pattern was executed. +func (m *MockSSHClient) hasExecuted(pattern string) bool { + m.mu.Lock() + defer m.mu.Unlock() + re := regexp.MustCompile(pattern) + for _, cmd := range m.executed { + if re.MatchString(cmd.Cmd) { + return true + } + } + return false +} + +// hasExecutedMethod checks if a command with the given method and matching +// pattern was executed. +func (m *MockSSHClient) hasExecutedMethod(method, pattern string) bool { + m.mu.Lock() + defer m.mu.Unlock() + re := regexp.MustCompile(pattern) + for _, cmd := range m.executed { + if cmd.Method == method && re.MatchString(cmd.Cmd) { + return true + } + } + return false +} + +// findExecuted returns the first command matching the pattern, or nil. +func (m *MockSSHClient) findExecuted(pattern string) *executedCommand { + m.mu.Lock() + defer m.mu.Unlock() + re := regexp.MustCompile(pattern) + for i := range m.executed { + if re.MatchString(m.executed[i].Cmd) { + cmd := m.executed[i] + return &cmd + } + } + return nil +} + +// uploadCount returns the number of uploads performed. +func (m *MockSSHClient) uploadCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return len(m.uploads) +} + +// lastUpload returns the most recent upload, or nil if none. +func (m *MockSSHClient) lastUpload() *uploadRecord { + m.mu.Lock() + defer m.mu.Unlock() + if len(m.uploads) == 0 { + return nil + } + u := m.uploads[len(m.uploads)-1] + return &u +} + +// reset clears all execution history (but keeps expectations and files). +func (m *MockSSHClient) reset() { + m.mu.Lock() + defer m.mu.Unlock() + m.executed = nil + m.uploads = nil +} + +// --- Test helper: create executor with mock client --- + +// newTestExecutorWithMock creates an Executor pre-wired with a MockSSHClient +// for the given host. The executor has a minimal inventory so that +// executeModule can be called directly. +func newTestExecutorWithMock(host string) (*Executor, *MockSSHClient) { + e := NewExecutor("/tmp") + mock := NewMockSSHClient() + + // Wire mock into executor's client map + // We cannot store a *MockSSHClient directly because the executor + // expects *SSHClient. Instead, we provide a helper that calls + // modules the same way the executor does but with the mock. + // Since modules call methods on *SSHClient directly and the mock + // has identical method signatures, we use a shim approach. + + // Set up minimal inventory so host resolution works + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + host: {AnsibleHost: "127.0.0.1"}, + }, + }, + }) + + return e, mock +} + +// executeModuleWithMock calls a module handler directly using the mock client. +// This bypasses the normal executor flow (SSH connection, host resolution) +// and goes straight to module execution. +func executeModuleWithMock(e *Executor, mock *MockSSHClient, host string, task *Task) (*TaskResult, error) { + module := NormalizeModule(task.Module) + args := e.templateArgs(task.Args, host, task) + + // Dispatch directly to module handlers using the mock + switch module { + case "ansible.builtin.shell": + return moduleShellWithClient(e, mock, args) + case "ansible.builtin.command": + return moduleCommandWithClient(e, mock, args) + case "ansible.builtin.raw": + return moduleRawWithClient(e, mock, args) + case "ansible.builtin.script": + return moduleScriptWithClient(e, mock, args) + case "ansible.builtin.copy": + return moduleCopyWithClient(e, mock, args, host, task) + case "ansible.builtin.template": + return moduleTemplateWithClient(e, mock, args, host, task) + case "ansible.builtin.file": + return moduleFileWithClient(e, mock, args) + case "ansible.builtin.lineinfile": + return moduleLineinfileWithClient(e, mock, args) + case "ansible.builtin.blockinfile": + return moduleBlockinfileWithClient(e, mock, args) + case "ansible.builtin.stat": + return moduleStatWithClient(e, mock, args) + // Service management + case "ansible.builtin.service": + return moduleServiceWithClient(e, mock, args) + case "ansible.builtin.systemd": + return moduleSystemdWithClient(e, mock, args) + + // Package management + case "ansible.builtin.apt": + return moduleAptWithClient(e, mock, args) + case "ansible.builtin.apt_key": + return moduleAptKeyWithClient(e, mock, args) + case "ansible.builtin.apt_repository": + return moduleAptRepositoryWithClient(e, mock, args) + case "ansible.builtin.package": + return modulePackageWithClient(e, mock, args) + case "ansible.builtin.pip": + return modulePipWithClient(e, mock, args) + + // User/group management + case "ansible.builtin.user": + return moduleUserWithClient(e, mock, args) + case "ansible.builtin.group": + return moduleGroupWithClient(e, mock, args) + + // Cron + case "ansible.builtin.cron": + return moduleCronWithClient(e, mock, args) + + // SSH keys + case "ansible.posix.authorized_key", "ansible.builtin.authorized_key": + return moduleAuthorizedKeyWithClient(e, mock, args) + + // Git + case "ansible.builtin.git": + return moduleGitWithClient(e, mock, args) + + // Archive + case "ansible.builtin.unarchive": + return moduleUnarchiveWithClient(e, mock, args) + + // HTTP + case "ansible.builtin.uri": + return moduleURIWithClient(e, mock, args) + + // Firewall + case "community.general.ufw", "ansible.builtin.ufw": + return moduleUFWWithClient(e, mock, args) + + // Docker + case "community.docker.docker_compose_v2", "ansible.builtin.docker_compose": + return moduleDockerComposeWithClient(e, mock, args) + + default: + return nil, fmt.Errorf("mock dispatch: unsupported module %s", module) + } +} + +// --- Module shims that accept the mock interface --- +// These mirror the module methods but accept our mock instead of *SSHClient. + +type sshRunner interface { + Run(ctx context.Context, cmd string) (string, string, int, error) + RunScript(ctx context.Context, script string) (string, string, int, error) +} + +func moduleShellWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + cmd := getStringArg(args, "_raw_params", "") + if cmd == "" { + cmd = getStringArg(args, "cmd", "") + } + if cmd == "" { + return nil, fmt.Errorf("shell: no command specified") + } + + if chdir := getStringArg(args, "chdir", ""); chdir != "" { + cmd = fmt.Sprintf("cd %q && %s", chdir, cmd) + } + + stdout, stderr, rc, err := client.RunScript(context.Background(), cmd) + if err != nil { + return &TaskResult{Failed: true, Msg: err.Error(), Stdout: stdout, Stderr: stderr, RC: rc}, nil + } + + return &TaskResult{ + Changed: true, + Stdout: stdout, + Stderr: stderr, + RC: rc, + Failed: rc != 0, + }, nil +} + +func moduleCommandWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + cmd := getStringArg(args, "_raw_params", "") + if cmd == "" { + cmd = getStringArg(args, "cmd", "") + } + if cmd == "" { + return nil, fmt.Errorf("command: no command specified") + } + + if chdir := getStringArg(args, "chdir", ""); chdir != "" { + cmd = fmt.Sprintf("cd %q && %s", chdir, cmd) + } + + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil { + return &TaskResult{Failed: true, Msg: err.Error()}, nil + } + + return &TaskResult{ + Changed: true, + Stdout: stdout, + Stderr: stderr, + RC: rc, + Failed: rc != 0, + }, nil +} + +func moduleRawWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + cmd := getStringArg(args, "_raw_params", "") + if cmd == "" { + return nil, fmt.Errorf("raw: no command specified") + } + + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil { + return &TaskResult{Failed: true, Msg: err.Error()}, nil + } + + return &TaskResult{ + Changed: true, + Stdout: stdout, + Stderr: stderr, + RC: rc, + }, nil +} + +func moduleScriptWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + script := getStringArg(args, "_raw_params", "") + if script == "" { + return nil, fmt.Errorf("script: no script specified") + } + + content, err := os.ReadFile(script) + if err != nil { + return nil, fmt.Errorf("read script: %w", err) + } + + stdout, stderr, rc, err := client.RunScript(context.Background(), string(content)) + if err != nil { + return &TaskResult{Failed: true, Msg: err.Error()}, nil + } + + return &TaskResult{ + Changed: true, + Stdout: stdout, + Stderr: stderr, + RC: rc, + Failed: rc != 0, + }, nil +} + +// --- Extended interface for file operations --- +// File modules need Upload, Stat, FileExists in addition to Run/RunScript. + +type sshFileRunner interface { + sshRunner + Upload(ctx context.Context, local io.Reader, remote string, mode os.FileMode) error + Stat(ctx context.Context, path string) (map[string]any, error) + FileExists(ctx context.Context, path string) (bool, error) +} + +// --- File module shims --- + +func moduleCopyWithClient(e *Executor, client sshFileRunner, args map[string]any, host string, task *Task) (*TaskResult, error) { + dest := getStringArg(args, "dest", "") + if dest == "" { + return nil, fmt.Errorf("copy: dest required") + } + + var content []byte + var err error + + if src := getStringArg(args, "src", ""); src != "" { + content, err = os.ReadFile(src) + if err != nil { + return nil, fmt.Errorf("read src: %w", err) + } + } else if c := getStringArg(args, "content", ""); c != "" { + content = []byte(c) + } else { + return nil, fmt.Errorf("copy: src or content required") + } + + mode := os.FileMode(0644) + if m := getStringArg(args, "mode", ""); m != "" { + if parsed, parseErr := strconv.ParseInt(m, 8, 32); parseErr == nil { + mode = os.FileMode(parsed) + } + } + + err = client.Upload(context.Background(), strings.NewReader(string(content)), dest, mode) + if err != nil { + return nil, err + } + + // Handle owner/group (best-effort, errors ignored) + if owner := getStringArg(args, "owner", ""); owner != "" { + _, _, _, _ = client.Run(context.Background(), fmt.Sprintf("chown %s %q", owner, dest)) + } + if group := getStringArg(args, "group", ""); group != "" { + _, _, _, _ = client.Run(context.Background(), fmt.Sprintf("chgrp %s %q", group, dest)) + } + + return &TaskResult{Changed: true, Msg: fmt.Sprintf("copied to %s", dest)}, nil +} + +func moduleTemplateWithClient(e *Executor, client sshFileRunner, args map[string]any, host string, task *Task) (*TaskResult, error) { + src := getStringArg(args, "src", "") + dest := getStringArg(args, "dest", "") + if src == "" || dest == "" { + return nil, fmt.Errorf("template: src and dest required") + } + + // Process template + content, err := e.TemplateFile(src, host, task) + if err != nil { + return nil, fmt.Errorf("template: %w", err) + } + + mode := os.FileMode(0644) + if m := getStringArg(args, "mode", ""); m != "" { + if parsed, parseErr := strconv.ParseInt(m, 8, 32); parseErr == nil { + mode = os.FileMode(parsed) + } + } + + err = client.Upload(context.Background(), strings.NewReader(content), dest, mode) + if err != nil { + return nil, err + } + + return &TaskResult{Changed: true, Msg: fmt.Sprintf("templated to %s", dest)}, nil +} + +func moduleFileWithClient(_ *Executor, client sshFileRunner, args map[string]any) (*TaskResult, error) { + path := getStringArg(args, "path", "") + if path == "" { + path = getStringArg(args, "dest", "") + } + if path == "" { + return nil, fmt.Errorf("file: path required") + } + + state := getStringArg(args, "state", "file") + + switch state { + case "directory": + mode := getStringArg(args, "mode", "0755") + cmd := fmt.Sprintf("mkdir -p %q && chmod %s %q", path, mode, path) + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + case "absent": + cmd := fmt.Sprintf("rm -rf %q", path) + _, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, RC: rc}, nil + } + + case "touch": + cmd := fmt.Sprintf("touch %q", path) + _, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, RC: rc}, nil + } + + case "link": + src := getStringArg(args, "src", "") + if src == "" { + return nil, fmt.Errorf("file: src required for link state") + } + cmd := fmt.Sprintf("ln -sf %q %q", src, path) + _, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, RC: rc}, nil + } + + case "file": + // Ensure file exists and set permissions + if mode := getStringArg(args, "mode", ""); mode != "" { + _, _, _, _ = client.Run(context.Background(), fmt.Sprintf("chmod %s %q", mode, path)) + } + } + + // Handle owner/group (best-effort, errors ignored) + if owner := getStringArg(args, "owner", ""); owner != "" { + _, _, _, _ = client.Run(context.Background(), fmt.Sprintf("chown %s %q", owner, path)) + } + if group := getStringArg(args, "group", ""); group != "" { + _, _, _, _ = client.Run(context.Background(), fmt.Sprintf("chgrp %s %q", group, path)) + } + if recurse := getBoolArg(args, "recurse", false); recurse { + if owner := getStringArg(args, "owner", ""); owner != "" { + _, _, _, _ = client.Run(context.Background(), fmt.Sprintf("chown -R %s %q", owner, path)) + } + } + + return &TaskResult{Changed: true}, nil +} + +func moduleLineinfileWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + path := getStringArg(args, "path", "") + if path == "" { + path = getStringArg(args, "dest", "") + } + if path == "" { + return nil, fmt.Errorf("lineinfile: path required") + } + + line := getStringArg(args, "line", "") + regexpArg := getStringArg(args, "regexp", "") + state := getStringArg(args, "state", "present") + + if state == "absent" { + if regexpArg != "" { + cmd := fmt.Sprintf("sed -i '/%s/d' %q", regexpArg, path) + _, stderr, rc, _ := client.Run(context.Background(), cmd) + if rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, RC: rc}, nil + } + } + } else { + // state == present + if regexpArg != "" { + // Replace line matching regexp + escapedLine := strings.ReplaceAll(line, "/", "\\/") + cmd := fmt.Sprintf("sed -i 's/%s/%s/' %q", regexpArg, escapedLine, path) + _, _, rc, _ := client.Run(context.Background(), cmd) + if rc != 0 { + // Line not found, append + cmd = fmt.Sprintf("echo %q >> %q", line, path) + _, _, _, _ = client.Run(context.Background(), cmd) + } + } else if line != "" { + // Ensure line is present + cmd := fmt.Sprintf("grep -qxF %q %q || echo %q >> %q", line, path, line, path) + _, _, _, _ = client.Run(context.Background(), cmd) + } + } + + return &TaskResult{Changed: true}, nil +} + +func moduleBlockinfileWithClient(_ *Executor, client sshFileRunner, args map[string]any) (*TaskResult, error) { + path := getStringArg(args, "path", "") + if path == "" { + path = getStringArg(args, "dest", "") + } + if path == "" { + return nil, fmt.Errorf("blockinfile: path required") + } + + block := getStringArg(args, "block", "") + marker := getStringArg(args, "marker", "# {mark} ANSIBLE MANAGED BLOCK") + state := getStringArg(args, "state", "present") + create := getBoolArg(args, "create", false) + + beginMarker := strings.Replace(marker, "{mark}", "BEGIN", 1) + endMarker := strings.Replace(marker, "{mark}", "END", 1) + + if state == "absent" { + // Remove block + cmd := fmt.Sprintf("sed -i '/%s/,/%s/d' %q", + strings.ReplaceAll(beginMarker, "/", "\\/"), + strings.ReplaceAll(endMarker, "/", "\\/"), + path) + _, _, _, _ = client.Run(context.Background(), cmd) + return &TaskResult{Changed: true}, nil + } + + // Create file if needed (best-effort) + if create { + _, _, _, _ = client.Run(context.Background(), fmt.Sprintf("touch %q", path)) + } + + // Remove existing block and add new one + escapedBlock := strings.ReplaceAll(block, "'", "'\\''") + cmd := fmt.Sprintf(` +sed -i '/%s/,/%s/d' %q 2>/dev/null || true +cat >> %q << 'BLOCK_EOF' +%s +%s +%s +BLOCK_EOF +`, strings.ReplaceAll(beginMarker, "/", "\\/"), + strings.ReplaceAll(endMarker, "/", "\\/"), + path, path, beginMarker, escapedBlock, endMarker) + + stdout, stderr, rc, err := client.RunScript(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +func moduleStatWithClient(_ *Executor, client sshFileRunner, args map[string]any) (*TaskResult, error) { + path := getStringArg(args, "path", "") + if path == "" { + return nil, fmt.Errorf("stat: path required") + } + + stat, err := client.Stat(context.Background(), path) + if err != nil { + return nil, err + } + + return &TaskResult{ + Changed: false, + Data: map[string]any{"stat": stat}, + }, nil +} + +// --- Service module shims --- + +func moduleServiceWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + name := getStringArg(args, "name", "") + state := getStringArg(args, "state", "") + enabled := args["enabled"] + + if name == "" { + return nil, fmt.Errorf("service: name required") + } + + var cmds []string + + if state != "" { + switch state { + case "started": + cmds = append(cmds, fmt.Sprintf("systemctl start %s", name)) + case "stopped": + cmds = append(cmds, fmt.Sprintf("systemctl stop %s", name)) + case "restarted": + cmds = append(cmds, fmt.Sprintf("systemctl restart %s", name)) + case "reloaded": + cmds = append(cmds, fmt.Sprintf("systemctl reload %s", name)) + } + } + + if enabled != nil { + if getBoolArg(args, "enabled", false) { + cmds = append(cmds, fmt.Sprintf("systemctl enable %s", name)) + } else { + cmds = append(cmds, fmt.Sprintf("systemctl disable %s", name)) + } + } + + for _, cmd := range cmds { + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + } + + return &TaskResult{Changed: len(cmds) > 0}, nil +} + +func moduleSystemdWithClient(e *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + if getBoolArg(args, "daemon_reload", false) { + _, _, _, _ = client.Run(context.Background(), "systemctl daemon-reload") + } + + return moduleServiceWithClient(e, client, args) +} + +// --- Package module shims --- + +func moduleAptWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + name := getStringArg(args, "name", "") + state := getStringArg(args, "state", "present") + updateCache := getBoolArg(args, "update_cache", false) + + var cmd string + + if updateCache { + _, _, _, _ = client.Run(context.Background(), "apt-get update -qq") + } + + switch state { + case "present", "installed": + if name != "" { + cmd = fmt.Sprintf("DEBIAN_FRONTEND=noninteractive apt-get install -y -qq %s", name) + } + case "absent", "removed": + cmd = fmt.Sprintf("DEBIAN_FRONTEND=noninteractive apt-get remove -y -qq %s", name) + case "latest": + cmd = fmt.Sprintf("DEBIAN_FRONTEND=noninteractive apt-get install -y -qq --only-upgrade %s", name) + } + + if cmd == "" { + return &TaskResult{Changed: false}, nil + } + + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +func moduleAptKeyWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + url := getStringArg(args, "url", "") + keyring := getStringArg(args, "keyring", "") + state := getStringArg(args, "state", "present") + + if state == "absent" { + if keyring != "" { + _, _, _, _ = client.Run(context.Background(), fmt.Sprintf("rm -f %q", keyring)) + } + return &TaskResult{Changed: true}, nil + } + + if url == "" { + return nil, fmt.Errorf("apt_key: url required") + } + + var cmd string + if keyring != "" { + cmd = fmt.Sprintf("curl -fsSL %q | gpg --dearmor -o %q", url, keyring) + } else { + cmd = fmt.Sprintf("curl -fsSL %q | apt-key add -", url) + } + + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +func moduleAptRepositoryWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + repo := getStringArg(args, "repo", "") + filename := getStringArg(args, "filename", "") + state := getStringArg(args, "state", "present") + + if repo == "" { + return nil, fmt.Errorf("apt_repository: repo required") + } + + if filename == "" { + filename = strings.ReplaceAll(repo, " ", "-") + filename = strings.ReplaceAll(filename, "/", "-") + filename = strings.ReplaceAll(filename, ":", "") + } + + path := fmt.Sprintf("/etc/apt/sources.list.d/%s.list", filename) + + if state == "absent" { + _, _, _, _ = client.Run(context.Background(), fmt.Sprintf("rm -f %q", path)) + return &TaskResult{Changed: true}, nil + } + + cmd := fmt.Sprintf("echo %q > %q", repo, path) + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + if getBoolArg(args, "update_cache", true) { + _, _, _, _ = client.Run(context.Background(), "apt-get update -qq") + } + + return &TaskResult{Changed: true}, nil +} + +func modulePackageWithClient(e *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + stdout, _, _, _ := client.Run(context.Background(), "which apt-get yum dnf 2>/dev/null | head -1") + stdout = strings.TrimSpace(stdout) + + if strings.Contains(stdout, "apt") { + return moduleAptWithClient(e, client, args) + } + + return moduleAptWithClient(e, client, args) +} + +func modulePipWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + name := getStringArg(args, "name", "") + state := getStringArg(args, "state", "present") + executable := getStringArg(args, "executable", "pip3") + + var cmd string + switch state { + case "present", "installed": + cmd = fmt.Sprintf("%s install %s", executable, name) + case "absent", "removed": + cmd = fmt.Sprintf("%s uninstall -y %s", executable, name) + case "latest": + cmd = fmt.Sprintf("%s install --upgrade %s", executable, name) + } + + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +// --- User/Group module shims --- + +func moduleUserWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + name := getStringArg(args, "name", "") + state := getStringArg(args, "state", "present") + + if name == "" { + return nil, fmt.Errorf("user: name required") + } + + if state == "absent" { + cmd := fmt.Sprintf("userdel -r %s 2>/dev/null || true", name) + _, _, _, _ = client.Run(context.Background(), cmd) + return &TaskResult{Changed: true}, nil + } + + // Build useradd/usermod command + var opts []string + + if uid := getStringArg(args, "uid", ""); uid != "" { + opts = append(opts, "-u", uid) + } + if group := getStringArg(args, "group", ""); group != "" { + opts = append(opts, "-g", group) + } + if groups := getStringArg(args, "groups", ""); groups != "" { + opts = append(opts, "-G", groups) + } + if home := getStringArg(args, "home", ""); home != "" { + opts = append(opts, "-d", home) + } + if shell := getStringArg(args, "shell", ""); shell != "" { + opts = append(opts, "-s", shell) + } + if getBoolArg(args, "system", false) { + opts = append(opts, "-r") + } + if getBoolArg(args, "create_home", true) { + opts = append(opts, "-m") + } + + // Try usermod first, then useradd + optsStr := strings.Join(opts, " ") + var cmd string + if optsStr == "" { + cmd = fmt.Sprintf("id %s >/dev/null 2>&1 || useradd %s", name, name) + } else { + cmd = fmt.Sprintf("id %s >/dev/null 2>&1 && usermod %s %s || useradd %s %s", + name, optsStr, name, optsStr, name) + } + + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +func moduleGroupWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + name := getStringArg(args, "name", "") + state := getStringArg(args, "state", "present") + + if name == "" { + return nil, fmt.Errorf("group: name required") + } + + if state == "absent" { + cmd := fmt.Sprintf("groupdel %s 2>/dev/null || true", name) + _, _, _, _ = client.Run(context.Background(), cmd) + return &TaskResult{Changed: true}, nil + } + + var opts []string + if gid := getStringArg(args, "gid", ""); gid != "" { + opts = append(opts, "-g", gid) + } + if getBoolArg(args, "system", false) { + opts = append(opts, "-r") + } + + cmd := fmt.Sprintf("getent group %s >/dev/null 2>&1 || groupadd %s %s", + name, strings.Join(opts, " "), name) + + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +// --- Cron module shim --- + +func moduleCronWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + name := getStringArg(args, "name", "") + job := getStringArg(args, "job", "") + state := getStringArg(args, "state", "present") + user := getStringArg(args, "user", "root") + + minute := getStringArg(args, "minute", "*") + hour := getStringArg(args, "hour", "*") + day := getStringArg(args, "day", "*") + month := getStringArg(args, "month", "*") + weekday := getStringArg(args, "weekday", "*") + + if state == "absent" { + if name != "" { + // Remove by name (comment marker) + cmd := fmt.Sprintf("crontab -u %s -l 2>/dev/null | grep -v '# %s' | grep -v '%s' | crontab -u %s -", + user, name, job, user) + _, _, _, _ = client.Run(context.Background(), cmd) + } + return &TaskResult{Changed: true}, nil + } + + // Build cron entry + schedule := fmt.Sprintf("%s %s %s %s %s", minute, hour, day, month, weekday) + entry := fmt.Sprintf("%s %s # %s", schedule, job, name) + + // Add to crontab + cmd := fmt.Sprintf("(crontab -u %s -l 2>/dev/null | grep -v '# %s' ; echo %q) | crontab -u %s -", + user, name, entry, user) + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +// --- Authorized key module shim --- + +func moduleAuthorizedKeyWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + user := getStringArg(args, "user", "") + key := getStringArg(args, "key", "") + state := getStringArg(args, "state", "present") + + if user == "" || key == "" { + return nil, fmt.Errorf("authorized_key: user and key required") + } + + // Get user's home directory + stdout, _, _, err := client.Run(context.Background(), fmt.Sprintf("getent passwd %s | cut -d: -f6", user)) + if err != nil { + return nil, fmt.Errorf("get home dir: %w", err) + } + home := strings.TrimSpace(stdout) + if home == "" { + home = "/root" + if user != "root" { + home = "/home/" + user + } + } + + authKeysPath := filepath.Join(home, ".ssh", "authorized_keys") + + if state == "absent" { + // Remove key + escapedKey := strings.ReplaceAll(key, "/", "\\/") + cmd := fmt.Sprintf("sed -i '/%s/d' %q 2>/dev/null || true", escapedKey[:40], authKeysPath) + _, _, _, _ = client.Run(context.Background(), cmd) + return &TaskResult{Changed: true}, nil + } + + // Ensure .ssh directory exists (best-effort) + _, _, _, _ = client.Run(context.Background(), fmt.Sprintf("mkdir -p %q && chmod 700 %q && chown %s:%s %q", + filepath.Dir(authKeysPath), filepath.Dir(authKeysPath), user, user, filepath.Dir(authKeysPath))) + + // Add key if not present + cmd := fmt.Sprintf("grep -qF %q %q 2>/dev/null || echo %q >> %q", + key[:40], authKeysPath, key, authKeysPath) + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + // Fix permissions (best-effort) + _, _, _, _ = client.Run(context.Background(), fmt.Sprintf("chmod 600 %q && chown %s:%s %q", + authKeysPath, user, user, authKeysPath)) + + return &TaskResult{Changed: true}, nil +} + +// --- Git module shim --- + +func moduleGitWithClient(_ *Executor, client sshFileRunner, args map[string]any) (*TaskResult, error) { + repo := getStringArg(args, "repo", "") + dest := getStringArg(args, "dest", "") + version := getStringArg(args, "version", "HEAD") + + if repo == "" || dest == "" { + return nil, fmt.Errorf("git: repo and dest required") + } + + // Check if dest exists + exists, _ := client.FileExists(context.Background(), dest+"/.git") + + var cmd string + if exists { + cmd = fmt.Sprintf("cd %q && git fetch --all && git checkout --force %q", dest, version) + } else { + cmd = fmt.Sprintf("git clone %q %q && cd %q && git checkout %q", + repo, dest, dest, version) + } + + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +// --- Unarchive module shim --- + +func moduleUnarchiveWithClient(_ *Executor, client sshFileRunner, args map[string]any) (*TaskResult, error) { + src := getStringArg(args, "src", "") + dest := getStringArg(args, "dest", "") + remote := getBoolArg(args, "remote_src", false) + + if src == "" || dest == "" { + return nil, fmt.Errorf("unarchive: src and dest required") + } + + // Create dest directory (best-effort) + _, _, _, _ = client.Run(context.Background(), fmt.Sprintf("mkdir -p %q", dest)) + + var cmd string + if !remote { + // Upload local file first + content, err := os.ReadFile(src) + if err != nil { + return nil, fmt.Errorf("read src: %w", err) + } + tmpPath := "/tmp/ansible_unarchive_" + filepath.Base(src) + err = client.Upload(context.Background(), strings.NewReader(string(content)), tmpPath, 0644) + if err != nil { + return nil, err + } + src = tmpPath + defer func() { _, _, _, _ = client.Run(context.Background(), fmt.Sprintf("rm -f %q", tmpPath)) }() + } + + // Detect archive type and extract + if strings.HasSuffix(src, ".tar.gz") || strings.HasSuffix(src, ".tgz") { + cmd = fmt.Sprintf("tar -xzf %q -C %q", src, dest) + } else if strings.HasSuffix(src, ".tar.xz") { + cmd = fmt.Sprintf("tar -xJf %q -C %q", src, dest) + } else if strings.HasSuffix(src, ".tar.bz2") { + cmd = fmt.Sprintf("tar -xjf %q -C %q", src, dest) + } else if strings.HasSuffix(src, ".tar") { + cmd = fmt.Sprintf("tar -xf %q -C %q", src, dest) + } else if strings.HasSuffix(src, ".zip") { + cmd = fmt.Sprintf("unzip -o %q -d %q", src, dest) + } else { + cmd = fmt.Sprintf("tar -xf %q -C %q", src, dest) // Guess tar + } + + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +// --- URI module shim --- + +func moduleURIWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + url := getStringArg(args, "url", "") + method := getStringArg(args, "method", "GET") + + if url == "" { + return nil, fmt.Errorf("uri: url required") + } + + var curlOpts []string + curlOpts = append(curlOpts, "-s", "-S") + curlOpts = append(curlOpts, "-X", method) + + // Headers + if headers, ok := args["headers"].(map[string]any); ok { + for k, v := range headers { + curlOpts = append(curlOpts, "-H", fmt.Sprintf("%s: %v", k, v)) + } + } + + // Body + if body := getStringArg(args, "body", ""); body != "" { + curlOpts = append(curlOpts, "-d", body) + } + + // Status code + curlOpts = append(curlOpts, "-w", "\\n%{http_code}") + + cmd := fmt.Sprintf("curl %s %q", strings.Join(curlOpts, " "), url) + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil { + return &TaskResult{Failed: true, Msg: err.Error()}, nil + } + + // Parse status code from last line + lines := strings.Split(strings.TrimSpace(stdout), "\n") + statusCode := 0 + if len(lines) > 0 { + statusCode, _ = strconv.Atoi(lines[len(lines)-1]) + } + + // Check expected status + expectedStatus := 200 + if s, ok := args["status_code"].(int); ok { + expectedStatus = s + } + + failed := rc != 0 || statusCode != expectedStatus + + return &TaskResult{ + Changed: false, + Failed: failed, + Stdout: stdout, + Stderr: stderr, + RC: statusCode, + Data: map[string]any{"status": statusCode}, + }, nil +} + +// --- UFW module shim --- + +func moduleUFWWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + rule := getStringArg(args, "rule", "") + port := getStringArg(args, "port", "") + proto := getStringArg(args, "proto", "tcp") + state := getStringArg(args, "state", "") + + var cmd string + + // Handle state (enable/disable) + if state != "" { + switch state { + case "enabled": + cmd = "ufw --force enable" + case "disabled": + cmd = "ufw disable" + case "reloaded": + cmd = "ufw reload" + case "reset": + cmd = "ufw --force reset" + } + if cmd != "" { + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + return &TaskResult{Changed: true}, nil + } + } + + // Handle rule + if rule != "" && port != "" { + switch rule { + case "allow": + cmd = fmt.Sprintf("ufw allow %s/%s", port, proto) + case "deny": + cmd = fmt.Sprintf("ufw deny %s/%s", port, proto) + case "reject": + cmd = fmt.Sprintf("ufw reject %s/%s", port, proto) + case "limit": + cmd = fmt.Sprintf("ufw limit %s/%s", port, proto) + } + + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + } + + return &TaskResult{Changed: true}, nil +} + +// --- Docker Compose module shim --- + +func moduleDockerComposeWithClient(_ *Executor, client sshRunner, args map[string]any) (*TaskResult, error) { + projectSrc := getStringArg(args, "project_src", "") + state := getStringArg(args, "state", "present") + + if projectSrc == "" { + return nil, fmt.Errorf("docker_compose: project_src required") + } + + var cmd string + switch state { + case "present": + cmd = fmt.Sprintf("cd %q && docker compose up -d", projectSrc) + case "absent": + cmd = fmt.Sprintf("cd %q && docker compose down", projectSrc) + case "restarted": + cmd = fmt.Sprintf("cd %q && docker compose restart", projectSrc) + default: + cmd = fmt.Sprintf("cd %q && docker compose up -d", projectSrc) + } + + stdout, stderr, rc, err := client.Run(context.Background(), cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + // Heuristic for changed + changed := !strings.Contains(stdout, "Up to date") && !strings.Contains(stderr, "Up to date") + + return &TaskResult{Changed: changed, Stdout: stdout}, nil +} + +// --- String helpers for assertions --- + +// containsSubstring checks if any executed command contains the given substring. +func (m *MockSSHClient) containsSubstring(sub string) bool { + m.mu.Lock() + defer m.mu.Unlock() + for _, cmd := range m.executed { + if strings.Contains(cmd.Cmd, sub) { + return true + } + } + return false +} diff --git a/modules.go b/modules.go new file mode 100644 index 0000000..9a3c722 --- /dev/null +++ b/modules.go @@ -0,0 +1,1435 @@ +package ansible + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" +) + +// executeModule dispatches to the appropriate module handler. +func (e *Executor) executeModule(ctx context.Context, host string, client *SSHClient, task *Task, play *Play) (*TaskResult, error) { + module := NormalizeModule(task.Module) + + // Apply task-level become + if task.Become != nil && *task.Become { + // Save old state to restore + oldBecome := client.become + oldUser := client.becomeUser + oldPass := client.becomePass + + client.SetBecome(true, task.BecomeUser, "") + + defer client.SetBecome(oldBecome, oldUser, oldPass) + } + + // Template the args + args := e.templateArgs(task.Args, host, task) + + switch module { + // Command execution + case "ansible.builtin.shell": + return e.moduleShell(ctx, client, args) + case "ansible.builtin.command": + return e.moduleCommand(ctx, client, args) + case "ansible.builtin.raw": + return e.moduleRaw(ctx, client, args) + case "ansible.builtin.script": + return e.moduleScript(ctx, client, args) + + // File operations + case "ansible.builtin.copy": + return e.moduleCopy(ctx, client, args, host, task) + case "ansible.builtin.template": + return e.moduleTemplate(ctx, client, args, host, task) + case "ansible.builtin.file": + return e.moduleFile(ctx, client, args) + case "ansible.builtin.lineinfile": + return e.moduleLineinfile(ctx, client, args) + case "ansible.builtin.stat": + return e.moduleStat(ctx, client, args) + case "ansible.builtin.slurp": + return e.moduleSlurp(ctx, client, args) + case "ansible.builtin.fetch": + return e.moduleFetch(ctx, client, args) + case "ansible.builtin.get_url": + return e.moduleGetURL(ctx, client, args) + + // Package management + case "ansible.builtin.apt": + return e.moduleApt(ctx, client, args) + case "ansible.builtin.apt_key": + return e.moduleAptKey(ctx, client, args) + case "ansible.builtin.apt_repository": + return e.moduleAptRepository(ctx, client, args) + case "ansible.builtin.package": + return e.modulePackage(ctx, client, args) + case "ansible.builtin.pip": + return e.modulePip(ctx, client, args) + + // Service management + case "ansible.builtin.service": + return e.moduleService(ctx, client, args) + case "ansible.builtin.systemd": + return e.moduleSystemd(ctx, client, args) + + // User/Group + case "ansible.builtin.user": + return e.moduleUser(ctx, client, args) + case "ansible.builtin.group": + return e.moduleGroup(ctx, client, args) + + // HTTP + case "ansible.builtin.uri": + return e.moduleURI(ctx, client, args) + + // Misc + case "ansible.builtin.debug": + return e.moduleDebug(args) + case "ansible.builtin.fail": + return e.moduleFail(args) + case "ansible.builtin.assert": + return e.moduleAssert(args, host) + case "ansible.builtin.set_fact": + return e.moduleSetFact(args) + case "ansible.builtin.pause": + return e.modulePause(ctx, args) + case "ansible.builtin.wait_for": + return e.moduleWaitFor(ctx, client, args) + case "ansible.builtin.git": + return e.moduleGit(ctx, client, args) + case "ansible.builtin.unarchive": + return e.moduleUnarchive(ctx, client, args) + + // Additional modules + case "ansible.builtin.hostname": + return e.moduleHostname(ctx, client, args) + case "ansible.builtin.sysctl": + return e.moduleSysctl(ctx, client, args) + case "ansible.builtin.cron": + return e.moduleCron(ctx, client, args) + case "ansible.builtin.blockinfile": + return e.moduleBlockinfile(ctx, client, args) + case "ansible.builtin.include_vars": + return e.moduleIncludeVars(args) + case "ansible.builtin.meta": + return e.moduleMeta(args) + case "ansible.builtin.setup": + return e.moduleSetup(ctx, client) + case "ansible.builtin.reboot": + return e.moduleReboot(ctx, client, args) + + // Community modules (basic support) + case "community.general.ufw": + return e.moduleUFW(ctx, client, args) + case "ansible.posix.authorized_key": + return e.moduleAuthorizedKey(ctx, client, args) + case "community.docker.docker_compose": + return e.moduleDockerCompose(ctx, client, args) + + default: + // For unknown modules, try to execute as shell if it looks like a command + if strings.Contains(task.Module, " ") || task.Module == "" { + return e.moduleShell(ctx, client, args) + } + return nil, fmt.Errorf("unsupported module: %s", module) + } +} + +// templateArgs templates all string values in args. +func (e *Executor) templateArgs(args map[string]any, host string, task *Task) map[string]any { + // Set inventory_hostname for templating + e.vars["inventory_hostname"] = host + + result := make(map[string]any) + for k, v := range args { + switch val := v.(type) { + case string: + result[k] = e.templateString(val, host, task) + case map[string]any: + // Recurse for nested maps + result[k] = e.templateArgs(val, host, task) + case []any: + // Template strings in arrays + templated := make([]any, len(val)) + for i, item := range val { + if s, ok := item.(string); ok { + templated[i] = e.templateString(s, host, task) + } else { + templated[i] = item + } + } + result[k] = templated + default: + result[k] = v + } + } + return result +} + +// --- Command Modules --- + +func (e *Executor) moduleShell(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + cmd := getStringArg(args, "_raw_params", "") + if cmd == "" { + cmd = getStringArg(args, "cmd", "") + } + if cmd == "" { + return nil, errors.New("shell: no command specified") + } + + // Handle chdir + if chdir := getStringArg(args, "chdir", ""); chdir != "" { + cmd = fmt.Sprintf("cd %q && %s", chdir, cmd) + } + + stdout, stderr, rc, err := client.RunScript(ctx, cmd) + if err != nil { + return &TaskResult{Failed: true, Msg: err.Error(), Stdout: stdout, Stderr: stderr, RC: rc}, nil + } + + return &TaskResult{ + Changed: true, + Stdout: stdout, + Stderr: stderr, + RC: rc, + Failed: rc != 0, + }, nil +} + +func (e *Executor) moduleCommand(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + cmd := getStringArg(args, "_raw_params", "") + if cmd == "" { + cmd = getStringArg(args, "cmd", "") + } + if cmd == "" { + return nil, errors.New("command: no command specified") + } + + // Handle chdir + if chdir := getStringArg(args, "chdir", ""); chdir != "" { + cmd = fmt.Sprintf("cd %q && %s", chdir, cmd) + } + + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil { + return &TaskResult{Failed: true, Msg: err.Error()}, nil + } + + return &TaskResult{ + Changed: true, + Stdout: stdout, + Stderr: stderr, + RC: rc, + Failed: rc != 0, + }, nil +} + +func (e *Executor) moduleRaw(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + cmd := getStringArg(args, "_raw_params", "") + if cmd == "" { + return nil, errors.New("raw: no command specified") + } + + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil { + return &TaskResult{Failed: true, Msg: err.Error()}, nil + } + + return &TaskResult{ + Changed: true, + Stdout: stdout, + Stderr: stderr, + RC: rc, + }, nil +} + +func (e *Executor) moduleScript(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + script := getStringArg(args, "_raw_params", "") + if script == "" { + return nil, errors.New("script: no script specified") + } + + // Read local script + content, err := os.ReadFile(script) + if err != nil { + return nil, fmt.Errorf("read script: %w", err) + } + + stdout, stderr, rc, err := client.RunScript(ctx, string(content)) + if err != nil { + return &TaskResult{Failed: true, Msg: err.Error()}, nil + } + + return &TaskResult{ + Changed: true, + Stdout: stdout, + Stderr: stderr, + RC: rc, + Failed: rc != 0, + }, nil +} + +// --- File Modules --- + +func (e *Executor) moduleCopy(ctx context.Context, client *SSHClient, args map[string]any, host string, task *Task) (*TaskResult, error) { + dest := getStringArg(args, "dest", "") + if dest == "" { + return nil, errors.New("copy: dest required") + } + + var content []byte + var err error + + if src := getStringArg(args, "src", ""); src != "" { + content, err = os.ReadFile(src) + if err != nil { + return nil, fmt.Errorf("read src: %w", err) + } + } else if c := getStringArg(args, "content", ""); c != "" { + content = []byte(c) + } else { + return nil, errors.New("copy: src or content required") + } + + mode := os.FileMode(0644) + if m := getStringArg(args, "mode", ""); m != "" { + if parsed, err := strconv.ParseInt(m, 8, 32); err == nil { + mode = os.FileMode(parsed) + } + } + + err = client.Upload(ctx, strings.NewReader(string(content)), dest, mode) + if err != nil { + return nil, err + } + + // Handle owner/group (best-effort, errors ignored) + if owner := getStringArg(args, "owner", ""); owner != "" { + _, _, _, _ = client.Run(ctx, fmt.Sprintf("chown %s %q", owner, dest)) + } + if group := getStringArg(args, "group", ""); group != "" { + _, _, _, _ = client.Run(ctx, fmt.Sprintf("chgrp %s %q", group, dest)) + } + + return &TaskResult{Changed: true, Msg: fmt.Sprintf("copied to %s", dest)}, nil +} + +func (e *Executor) moduleTemplate(ctx context.Context, client *SSHClient, args map[string]any, host string, task *Task) (*TaskResult, error) { + src := getStringArg(args, "src", "") + dest := getStringArg(args, "dest", "") + if src == "" || dest == "" { + return nil, errors.New("template: src and dest required") + } + + // Process template + content, err := e.TemplateFile(src, host, task) + if err != nil { + return nil, fmt.Errorf("template: %w", err) + } + + mode := os.FileMode(0644) + if m := getStringArg(args, "mode", ""); m != "" { + if parsed, err := strconv.ParseInt(m, 8, 32); err == nil { + mode = os.FileMode(parsed) + } + } + + err = client.Upload(ctx, strings.NewReader(content), dest, mode) + if err != nil { + return nil, err + } + + return &TaskResult{Changed: true, Msg: fmt.Sprintf("templated to %s", dest)}, nil +} + +func (e *Executor) moduleFile(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + path := getStringArg(args, "path", "") + if path == "" { + path = getStringArg(args, "dest", "") + } + if path == "" { + return nil, errors.New("file: path required") + } + + state := getStringArg(args, "state", "file") + + switch state { + case "directory": + mode := getStringArg(args, "mode", "0755") + cmd := fmt.Sprintf("mkdir -p %q && chmod %s %q", path, mode, path) + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + case "absent": + cmd := fmt.Sprintf("rm -rf %q", path) + _, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, RC: rc}, nil + } + + case "touch": + cmd := fmt.Sprintf("touch %q", path) + _, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, RC: rc}, nil + } + + case "link": + src := getStringArg(args, "src", "") + if src == "" { + return nil, errors.New("file: src required for link state") + } + cmd := fmt.Sprintf("ln -sf %q %q", src, path) + _, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, RC: rc}, nil + } + + case "file": + // Ensure file exists and set permissions + if mode := getStringArg(args, "mode", ""); mode != "" { + _, _, _, _ = client.Run(ctx, fmt.Sprintf("chmod %s %q", mode, path)) + } + } + + // Handle owner/group (best-effort, errors ignored) + if owner := getStringArg(args, "owner", ""); owner != "" { + _, _, _, _ = client.Run(ctx, fmt.Sprintf("chown %s %q", owner, path)) + } + if group := getStringArg(args, "group", ""); group != "" { + _, _, _, _ = client.Run(ctx, fmt.Sprintf("chgrp %s %q", group, path)) + } + if recurse := getBoolArg(args, "recurse", false); recurse { + if owner := getStringArg(args, "owner", ""); owner != "" { + _, _, _, _ = client.Run(ctx, fmt.Sprintf("chown -R %s %q", owner, path)) + } + } + + return &TaskResult{Changed: true}, nil +} + +func (e *Executor) moduleLineinfile(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + path := getStringArg(args, "path", "") + if path == "" { + path = getStringArg(args, "dest", "") + } + if path == "" { + return nil, errors.New("lineinfile: path required") + } + + line := getStringArg(args, "line", "") + regexp := getStringArg(args, "regexp", "") + state := getStringArg(args, "state", "present") + + if state == "absent" { + if regexp != "" { + cmd := fmt.Sprintf("sed -i '/%s/d' %q", regexp, path) + _, stderr, rc, _ := client.Run(ctx, cmd) + if rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, RC: rc}, nil + } + } + } else { + // state == present + if regexp != "" { + // Replace line matching regexp + escapedLine := strings.ReplaceAll(line, "/", "\\/") + cmd := fmt.Sprintf("sed -i 's/%s/%s/' %q", regexp, escapedLine, path) + _, _, rc, _ := client.Run(ctx, cmd) + if rc != 0 { + // Line not found, append + cmd = fmt.Sprintf("echo %q >> %q", line, path) + _, _, _, _ = client.Run(ctx, cmd) + } + } else if line != "" { + // Ensure line is present + cmd := fmt.Sprintf("grep -qxF %q %q || echo %q >> %q", line, path, line, path) + _, _, _, _ = client.Run(ctx, cmd) + } + } + + return &TaskResult{Changed: true}, nil +} + +func (e *Executor) moduleStat(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + path := getStringArg(args, "path", "") + if path == "" { + return nil, errors.New("stat: path required") + } + + stat, err := client.Stat(ctx, path) + if err != nil { + return nil, err + } + + return &TaskResult{ + Changed: false, + Data: map[string]any{"stat": stat}, + }, nil +} + +func (e *Executor) moduleSlurp(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + path := getStringArg(args, "path", "") + if path == "" { + path = getStringArg(args, "src", "") + } + if path == "" { + return nil, errors.New("slurp: path required") + } + + content, err := client.Download(ctx, path) + if err != nil { + return nil, err + } + + encoded := base64.StdEncoding.EncodeToString(content) + + return &TaskResult{ + Changed: false, + Data: map[string]any{"content": encoded, "encoding": "base64"}, + }, nil +} + +func (e *Executor) moduleFetch(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + src := getStringArg(args, "src", "") + dest := getStringArg(args, "dest", "") + if src == "" || dest == "" { + return nil, errors.New("fetch: src and dest required") + } + + content, err := client.Download(ctx, src) + if err != nil { + return nil, err + } + + // Create dest directory + if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil { + return nil, err + } + + if err := os.WriteFile(dest, content, 0644); err != nil { + return nil, err + } + + return &TaskResult{Changed: true, Msg: fmt.Sprintf("fetched %s to %s", src, dest)}, nil +} + +func (e *Executor) moduleGetURL(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + url := getStringArg(args, "url", "") + dest := getStringArg(args, "dest", "") + if url == "" || dest == "" { + return nil, errors.New("get_url: url and dest required") + } + + // Use curl or wget + cmd := fmt.Sprintf("curl -fsSL -o %q %q || wget -q -O %q %q", dest, url, dest, url) + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + // Set mode if specified (best-effort) + if mode := getStringArg(args, "mode", ""); mode != "" { + _, _, _, _ = client.Run(ctx, fmt.Sprintf("chmod %s %q", mode, dest)) + } + + return &TaskResult{Changed: true}, nil +} + +// --- Package Modules --- + +func (e *Executor) moduleApt(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + name := getStringArg(args, "name", "") + state := getStringArg(args, "state", "present") + updateCache := getBoolArg(args, "update_cache", false) + + var cmd string + + if updateCache { + _, _, _, _ = client.Run(ctx, "apt-get update -qq") + } + + switch state { + case "present", "installed": + if name != "" { + cmd = fmt.Sprintf("DEBIAN_FRONTEND=noninteractive apt-get install -y -qq %s", name) + } + case "absent", "removed": + cmd = fmt.Sprintf("DEBIAN_FRONTEND=noninteractive apt-get remove -y -qq %s", name) + case "latest": + cmd = fmt.Sprintf("DEBIAN_FRONTEND=noninteractive apt-get install -y -qq --only-upgrade %s", name) + } + + if cmd == "" { + return &TaskResult{Changed: false}, nil + } + + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +func (e *Executor) moduleAptKey(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + url := getStringArg(args, "url", "") + keyring := getStringArg(args, "keyring", "") + state := getStringArg(args, "state", "present") + + if state == "absent" { + if keyring != "" { + _, _, _, _ = client.Run(ctx, fmt.Sprintf("rm -f %q", keyring)) + } + return &TaskResult{Changed: true}, nil + } + + if url == "" { + return nil, errors.New("apt_key: url required") + } + + var cmd string + if keyring != "" { + cmd = fmt.Sprintf("curl -fsSL %q | gpg --dearmor -o %q", url, keyring) + } else { + cmd = fmt.Sprintf("curl -fsSL %q | apt-key add -", url) + } + + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +func (e *Executor) moduleAptRepository(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + repo := getStringArg(args, "repo", "") + filename := getStringArg(args, "filename", "") + state := getStringArg(args, "state", "present") + + if repo == "" { + return nil, errors.New("apt_repository: repo required") + } + + if filename == "" { + // Generate filename from repo + filename = strings.ReplaceAll(repo, " ", "-") + filename = strings.ReplaceAll(filename, "/", "-") + filename = strings.ReplaceAll(filename, ":", "") + } + + path := fmt.Sprintf("/etc/apt/sources.list.d/%s.list", filename) + + if state == "absent" { + _, _, _, _ = client.Run(ctx, fmt.Sprintf("rm -f %q", path)) + return &TaskResult{Changed: true}, nil + } + + cmd := fmt.Sprintf("echo %q > %q", repo, path) + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + // Update apt cache (best-effort) + if getBoolArg(args, "update_cache", true) { + _, _, _, _ = client.Run(ctx, "apt-get update -qq") + } + + return &TaskResult{Changed: true}, nil +} + +func (e *Executor) modulePackage(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + // Detect package manager and delegate + stdout, _, _, _ := client.Run(ctx, "which apt-get yum dnf 2>/dev/null | head -1") + stdout = strings.TrimSpace(stdout) + + if strings.Contains(stdout, "apt") { + return e.moduleApt(ctx, client, args) + } + + // Default to apt + return e.moduleApt(ctx, client, args) +} + +func (e *Executor) modulePip(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + name := getStringArg(args, "name", "") + state := getStringArg(args, "state", "present") + executable := getStringArg(args, "executable", "pip3") + + var cmd string + switch state { + case "present", "installed": + cmd = fmt.Sprintf("%s install %s", executable, name) + case "absent", "removed": + cmd = fmt.Sprintf("%s uninstall -y %s", executable, name) + case "latest": + cmd = fmt.Sprintf("%s install --upgrade %s", executable, name) + } + + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +// --- Service Modules --- + +func (e *Executor) moduleService(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + name := getStringArg(args, "name", "") + state := getStringArg(args, "state", "") + enabled := args["enabled"] + + if name == "" { + return nil, errors.New("service: name required") + } + + var cmds []string + + if state != "" { + switch state { + case "started": + cmds = append(cmds, fmt.Sprintf("systemctl start %s", name)) + case "stopped": + cmds = append(cmds, fmt.Sprintf("systemctl stop %s", name)) + case "restarted": + cmds = append(cmds, fmt.Sprintf("systemctl restart %s", name)) + case "reloaded": + cmds = append(cmds, fmt.Sprintf("systemctl reload %s", name)) + } + } + + if enabled != nil { + if getBoolArg(args, "enabled", false) { + cmds = append(cmds, fmt.Sprintf("systemctl enable %s", name)) + } else { + cmds = append(cmds, fmt.Sprintf("systemctl disable %s", name)) + } + } + + for _, cmd := range cmds { + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + } + + return &TaskResult{Changed: len(cmds) > 0}, nil +} + +func (e *Executor) moduleSystemd(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + // systemd is similar to service + if getBoolArg(args, "daemon_reload", false) { + _, _, _, _ = client.Run(ctx, "systemctl daemon-reload") + } + + return e.moduleService(ctx, client, args) +} + +// --- User/Group Modules --- + +func (e *Executor) moduleUser(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + name := getStringArg(args, "name", "") + state := getStringArg(args, "state", "present") + + if name == "" { + return nil, errors.New("user: name required") + } + + if state == "absent" { + cmd := fmt.Sprintf("userdel -r %s 2>/dev/null || true", name) + _, _, _, _ = client.Run(ctx, cmd) + return &TaskResult{Changed: true}, nil + } + + // Build useradd/usermod command + var opts []string + + if uid := getStringArg(args, "uid", ""); uid != "" { + opts = append(opts, "-u", uid) + } + if group := getStringArg(args, "group", ""); group != "" { + opts = append(opts, "-g", group) + } + if groups := getStringArg(args, "groups", ""); groups != "" { + opts = append(opts, "-G", groups) + } + if home := getStringArg(args, "home", ""); home != "" { + opts = append(opts, "-d", home) + } + if shell := getStringArg(args, "shell", ""); shell != "" { + opts = append(opts, "-s", shell) + } + if getBoolArg(args, "system", false) { + opts = append(opts, "-r") + } + if getBoolArg(args, "create_home", true) { + opts = append(opts, "-m") + } + + // Try usermod first, then useradd + optsStr := strings.Join(opts, " ") + var cmd string + if optsStr == "" { + cmd = fmt.Sprintf("id %s >/dev/null 2>&1 || useradd %s", name, name) + } else { + cmd = fmt.Sprintf("id %s >/dev/null 2>&1 && usermod %s %s || useradd %s %s", + name, optsStr, name, optsStr, name) + } + + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +func (e *Executor) moduleGroup(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + name := getStringArg(args, "name", "") + state := getStringArg(args, "state", "present") + + if name == "" { + return nil, errors.New("group: name required") + } + + if state == "absent" { + cmd := fmt.Sprintf("groupdel %s 2>/dev/null || true", name) + _, _, _, _ = client.Run(ctx, cmd) + return &TaskResult{Changed: true}, nil + } + + var opts []string + if gid := getStringArg(args, "gid", ""); gid != "" { + opts = append(opts, "-g", gid) + } + if getBoolArg(args, "system", false) { + opts = append(opts, "-r") + } + + cmd := fmt.Sprintf("getent group %s >/dev/null 2>&1 || groupadd %s %s", + name, strings.Join(opts, " "), name) + + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +// --- HTTP Module --- + +func (e *Executor) moduleURI(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + url := getStringArg(args, "url", "") + method := getStringArg(args, "method", "GET") + + if url == "" { + return nil, errors.New("uri: url required") + } + + var curlOpts []string + curlOpts = append(curlOpts, "-s", "-S") + curlOpts = append(curlOpts, "-X", method) + + // Headers + if headers, ok := args["headers"].(map[string]any); ok { + for k, v := range headers { + curlOpts = append(curlOpts, "-H", fmt.Sprintf("%s: %v", k, v)) + } + } + + // Body + if body := getStringArg(args, "body", ""); body != "" { + curlOpts = append(curlOpts, "-d", body) + } + + // Status code + curlOpts = append(curlOpts, "-w", "\\n%{http_code}") + + cmd := fmt.Sprintf("curl %s %q", strings.Join(curlOpts, " "), url) + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil { + return &TaskResult{Failed: true, Msg: err.Error()}, nil + } + + // Parse status code from last line + lines := strings.Split(strings.TrimSpace(stdout), "\n") + statusCode := 0 + if len(lines) > 0 { + statusCode, _ = strconv.Atoi(lines[len(lines)-1]) + } + + // Check expected status + expectedStatus := 200 + if s, ok := args["status_code"].(int); ok { + expectedStatus = s + } + + failed := rc != 0 || statusCode != expectedStatus + + return &TaskResult{ + Changed: false, + Failed: failed, + Stdout: stdout, + Stderr: stderr, + RC: statusCode, + Data: map[string]any{"status": statusCode}, + }, nil +} + +// --- Misc Modules --- + +func (e *Executor) moduleDebug(args map[string]any) (*TaskResult, error) { + msg := getStringArg(args, "msg", "") + if v, ok := args["var"]; ok { + msg = fmt.Sprintf("%v = %v", v, e.vars[fmt.Sprintf("%v", v)]) + } + + return &TaskResult{ + Changed: false, + Msg: msg, + }, nil +} + +func (e *Executor) moduleFail(args map[string]any) (*TaskResult, error) { + msg := getStringArg(args, "msg", "Failed as requested") + return &TaskResult{ + Failed: true, + Msg: msg, + }, nil +} + +func (e *Executor) moduleAssert(args map[string]any, host string) (*TaskResult, error) { + that, ok := args["that"] + if !ok { + return nil, errors.New("assert: 'that' required") + } + + conditions := normalizeConditions(that) + for _, cond := range conditions { + if !e.evalCondition(cond, host) { + msg := getStringArg(args, "fail_msg", fmt.Sprintf("Assertion failed: %s", cond)) + return &TaskResult{Failed: true, Msg: msg}, nil + } + } + + return &TaskResult{Changed: false, Msg: "All assertions passed"}, nil +} + +func (e *Executor) moduleSetFact(args map[string]any) (*TaskResult, error) { + for k, v := range args { + if k != "cacheable" { + e.vars[k] = v + } + } + return &TaskResult{Changed: true}, nil +} + +func (e *Executor) modulePause(ctx context.Context, args map[string]any) (*TaskResult, error) { + seconds := 0 + if s, ok := args["seconds"].(int); ok { + seconds = s + } + if s, ok := args["seconds"].(string); ok { + seconds, _ = strconv.Atoi(s) + } + + if seconds > 0 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ctxSleep(ctx, seconds): + } + } + + return &TaskResult{Changed: false}, nil +} + +func ctxSleep(ctx context.Context, seconds int) <-chan struct{} { + ch := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + case <-sleepChan(seconds): + } + close(ch) + }() + return ch +} + +func sleepChan(seconds int) <-chan struct{} { + ch := make(chan struct{}) + go func() { + for range seconds { + select { + case <-ch: + return + default: + // Sleep 1 second at a time + } + } + close(ch) + }() + return ch +} + +func (e *Executor) moduleWaitFor(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + port := 0 + if p, ok := args["port"].(int); ok { + port = p + } + host := getStringArg(args, "host", "127.0.0.1") + state := getStringArg(args, "state", "started") + timeout := 300 + if t, ok := args["timeout"].(int); ok { + timeout = t + } + + if port > 0 && state == "started" { + cmd := fmt.Sprintf("timeout %d bash -c 'until nc -z %s %d; do sleep 1; done'", + timeout, host, port) + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + } + + return &TaskResult{Changed: false}, nil +} + +func (e *Executor) moduleGit(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + repo := getStringArg(args, "repo", "") + dest := getStringArg(args, "dest", "") + version := getStringArg(args, "version", "HEAD") + + if repo == "" || dest == "" { + return nil, errors.New("git: repo and dest required") + } + + // Check if dest exists + exists, _ := client.FileExists(ctx, dest+"/.git") + + var cmd string + if exists { + // Fetch and checkout (force to ensure clean state) + cmd = fmt.Sprintf("cd %q && git fetch --all && git checkout --force %q", dest, version) + } else { + cmd = fmt.Sprintf("git clone %q %q && cd %q && git checkout %q", + repo, dest, dest, version) + } + + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +func (e *Executor) moduleUnarchive(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + src := getStringArg(args, "src", "") + dest := getStringArg(args, "dest", "") + remote := getBoolArg(args, "remote_src", false) + + if src == "" || dest == "" { + return nil, errors.New("unarchive: src and dest required") + } + + // Create dest directory (best-effort) + _, _, _, _ = client.Run(ctx, fmt.Sprintf("mkdir -p %q", dest)) + + var cmd string + if !remote { + // Upload local file first + content, err := os.ReadFile(src) + if err != nil { + return nil, fmt.Errorf("read src: %w", err) + } + tmpPath := "/tmp/ansible_unarchive_" + filepath.Base(src) + err = client.Upload(ctx, strings.NewReader(string(content)), tmpPath, 0644) + if err != nil { + return nil, err + } + src = tmpPath + defer func() { _, _, _, _ = client.Run(ctx, fmt.Sprintf("rm -f %q", tmpPath)) }() + } + + // Detect archive type and extract + if strings.HasSuffix(src, ".tar.gz") || strings.HasSuffix(src, ".tgz") { + cmd = fmt.Sprintf("tar -xzf %q -C %q", src, dest) + } else if strings.HasSuffix(src, ".tar.xz") { + cmd = fmt.Sprintf("tar -xJf %q -C %q", src, dest) + } else if strings.HasSuffix(src, ".tar.bz2") { + cmd = fmt.Sprintf("tar -xjf %q -C %q", src, dest) + } else if strings.HasSuffix(src, ".tar") { + cmd = fmt.Sprintf("tar -xf %q -C %q", src, dest) + } else if strings.HasSuffix(src, ".zip") { + cmd = fmt.Sprintf("unzip -o %q -d %q", src, dest) + } else { + cmd = fmt.Sprintf("tar -xf %q -C %q", src, dest) // Guess tar + } + + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +// --- Helpers --- + +func getStringArg(args map[string]any, key, def string) string { + if v, ok := args[key]; ok { + if s, ok := v.(string); ok { + return s + } + return fmt.Sprintf("%v", v) + } + return def +} + +func getBoolArg(args map[string]any, key string, def bool) bool { + if v, ok := args[key]; ok { + switch b := v.(type) { + case bool: + return b + case string: + lower := strings.ToLower(b) + return lower == "true" || lower == "yes" || lower == "1" + } + } + return def +} + +// --- Additional Modules --- + +func (e *Executor) moduleHostname(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + name := getStringArg(args, "name", "") + if name == "" { + return nil, errors.New("hostname: name required") + } + + // Set hostname + cmd := fmt.Sprintf("hostnamectl set-hostname %q || hostname %q", name, name) + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + // Update /etc/hosts if needed (best-effort) + _, _, _, _ = client.Run(ctx, fmt.Sprintf("sed -i 's/127.0.1.1.*/127.0.1.1\t%s/' /etc/hosts", name)) + + return &TaskResult{Changed: true}, nil +} + +func (e *Executor) moduleSysctl(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + name := getStringArg(args, "name", "") + value := getStringArg(args, "value", "") + state := getStringArg(args, "state", "present") + + if name == "" { + return nil, errors.New("sysctl: name required") + } + + if state == "absent" { + // Remove from sysctl.conf + cmd := fmt.Sprintf("sed -i '/%s/d' /etc/sysctl.conf", name) + _, _, _, _ = client.Run(ctx, cmd) + return &TaskResult{Changed: true}, nil + } + + // Set value + cmd := fmt.Sprintf("sysctl -w %s=%s", name, value) + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + // Persist if requested (best-effort) + if getBoolArg(args, "sysctl_set", true) { + cmd = fmt.Sprintf("grep -q '^%s' /etc/sysctl.conf && sed -i 's/^%s.*/%s=%s/' /etc/sysctl.conf || echo '%s=%s' >> /etc/sysctl.conf", + name, name, name, value, name, value) + _, _, _, _ = client.Run(ctx, cmd) + } + + return &TaskResult{Changed: true}, nil +} + +func (e *Executor) moduleCron(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + name := getStringArg(args, "name", "") + job := getStringArg(args, "job", "") + state := getStringArg(args, "state", "present") + user := getStringArg(args, "user", "root") + + minute := getStringArg(args, "minute", "*") + hour := getStringArg(args, "hour", "*") + day := getStringArg(args, "day", "*") + month := getStringArg(args, "month", "*") + weekday := getStringArg(args, "weekday", "*") + + if state == "absent" { + if name != "" { + // Remove by name (comment marker) + cmd := fmt.Sprintf("crontab -u %s -l 2>/dev/null | grep -v '# %s' | grep -v '%s' | crontab -u %s -", + user, name, job, user) + _, _, _, _ = client.Run(ctx, cmd) + } + return &TaskResult{Changed: true}, nil + } + + // Build cron entry + schedule := fmt.Sprintf("%s %s %s %s %s", minute, hour, day, month, weekday) + entry := fmt.Sprintf("%s %s # %s", schedule, job, name) + + // Add to crontab + cmd := fmt.Sprintf("(crontab -u %s -l 2>/dev/null | grep -v '# %s' ; echo %q) | crontab -u %s -", + user, name, entry, user) + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +func (e *Executor) moduleBlockinfile(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + path := getStringArg(args, "path", "") + if path == "" { + path = getStringArg(args, "dest", "") + } + if path == "" { + return nil, errors.New("blockinfile: path required") + } + + block := getStringArg(args, "block", "") + marker := getStringArg(args, "marker", "# {mark} ANSIBLE MANAGED BLOCK") + state := getStringArg(args, "state", "present") + create := getBoolArg(args, "create", false) + + beginMarker := strings.Replace(marker, "{mark}", "BEGIN", 1) + endMarker := strings.Replace(marker, "{mark}", "END", 1) + + if state == "absent" { + // Remove block + cmd := fmt.Sprintf("sed -i '/%s/,/%s/d' %q", + strings.ReplaceAll(beginMarker, "/", "\\/"), + strings.ReplaceAll(endMarker, "/", "\\/"), + path) + _, _, _, _ = client.Run(ctx, cmd) + return &TaskResult{Changed: true}, nil + } + + // Create file if needed (best-effort) + if create { + _, _, _, _ = client.Run(ctx, fmt.Sprintf("touch %q", path)) + } + + // Remove existing block and add new one + escapedBlock := strings.ReplaceAll(block, "'", "'\\''") + cmd := fmt.Sprintf(` +sed -i '/%s/,/%s/d' %q 2>/dev/null || true +cat >> %q << 'BLOCK_EOF' +%s +%s +%s +BLOCK_EOF +`, strings.ReplaceAll(beginMarker, "/", "\\/"), + strings.ReplaceAll(endMarker, "/", "\\/"), + path, path, beginMarker, escapedBlock, endMarker) + + stdout, stderr, rc, err := client.RunScript(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + return &TaskResult{Changed: true}, nil +} + +func (e *Executor) moduleIncludeVars(args map[string]any) (*TaskResult, error) { + file := getStringArg(args, "file", "") + if file == "" { + file = getStringArg(args, "_raw_params", "") + } + + if file != "" { + // Would need to read and parse the vars file + // For now, just acknowledge + return &TaskResult{Changed: false, Msg: "include_vars: " + file}, nil + } + + return &TaskResult{Changed: false}, nil +} + +func (e *Executor) moduleMeta(args map[string]any) (*TaskResult, error) { + // meta module controls play execution + // Most actions are no-ops for us + return &TaskResult{Changed: false}, nil +} + +func (e *Executor) moduleSetup(ctx context.Context, client *SSHClient) (*TaskResult, error) { + // Gather facts - similar to what we do in gatherFacts + return &TaskResult{Changed: false, Msg: "facts gathered"}, nil +} + +func (e *Executor) moduleReboot(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + preRebootDelay := 0 + if d, ok := args["pre_reboot_delay"].(int); ok { + preRebootDelay = d + } + + msg := getStringArg(args, "msg", "Reboot initiated by Ansible") + + if preRebootDelay > 0 { + cmd := fmt.Sprintf("sleep %d && shutdown -r now '%s' &", preRebootDelay, msg) + _, _, _, _ = client.Run(ctx, cmd) + } else { + _, _, _, _ = client.Run(ctx, fmt.Sprintf("shutdown -r now '%s' &", msg)) + } + + return &TaskResult{Changed: true, Msg: "Reboot initiated"}, nil +} + +func (e *Executor) moduleUFW(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + rule := getStringArg(args, "rule", "") + port := getStringArg(args, "port", "") + proto := getStringArg(args, "proto", "tcp") + state := getStringArg(args, "state", "") + + var cmd string + + // Handle state (enable/disable) + if state != "" { + switch state { + case "enabled": + cmd = "ufw --force enable" + case "disabled": + cmd = "ufw disable" + case "reloaded": + cmd = "ufw reload" + case "reset": + cmd = "ufw --force reset" + } + if cmd != "" { + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + return &TaskResult{Changed: true}, nil + } + } + + // Handle rule + if rule != "" && port != "" { + switch rule { + case "allow": + cmd = fmt.Sprintf("ufw allow %s/%s", port, proto) + case "deny": + cmd = fmt.Sprintf("ufw deny %s/%s", port, proto) + case "reject": + cmd = fmt.Sprintf("ufw reject %s/%s", port, proto) + case "limit": + cmd = fmt.Sprintf("ufw limit %s/%s", port, proto) + } + + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + } + + return &TaskResult{Changed: true}, nil +} + +func (e *Executor) moduleAuthorizedKey(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + user := getStringArg(args, "user", "") + key := getStringArg(args, "key", "") + state := getStringArg(args, "state", "present") + + if user == "" || key == "" { + return nil, errors.New("authorized_key: user and key required") + } + + // Get user's home directory + stdout, _, _, err := client.Run(ctx, fmt.Sprintf("getent passwd %s | cut -d: -f6", user)) + if err != nil { + return nil, fmt.Errorf("get home dir: %w", err) + } + home := strings.TrimSpace(stdout) + if home == "" { + home = "/root" + if user != "root" { + home = "/home/" + user + } + } + + authKeysPath := filepath.Join(home, ".ssh", "authorized_keys") + + if state == "absent" { + // Remove key + escapedKey := strings.ReplaceAll(key, "/", "\\/") + cmd := fmt.Sprintf("sed -i '/%s/d' %q 2>/dev/null || true", escapedKey[:40], authKeysPath) + _, _, _, _ = client.Run(ctx, cmd) + return &TaskResult{Changed: true}, nil + } + + // Ensure .ssh directory exists (best-effort) + _, _, _, _ = client.Run(ctx, fmt.Sprintf("mkdir -p %q && chmod 700 %q && chown %s:%s %q", + filepath.Dir(authKeysPath), filepath.Dir(authKeysPath), user, user, filepath.Dir(authKeysPath))) + + // Add key if not present + cmd := fmt.Sprintf("grep -qF %q %q 2>/dev/null || echo %q >> %q", + key[:40], authKeysPath, key, authKeysPath) + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + // Fix permissions (best-effort) + _, _, _, _ = client.Run(ctx, fmt.Sprintf("chmod 600 %q && chown %s:%s %q", + authKeysPath, user, user, authKeysPath)) + + return &TaskResult{Changed: true}, nil +} + +func (e *Executor) moduleDockerCompose(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { + projectSrc := getStringArg(args, "project_src", "") + state := getStringArg(args, "state", "present") + + if projectSrc == "" { + return nil, errors.New("docker_compose: project_src required") + } + + var cmd string + switch state { + case "present": + cmd = fmt.Sprintf("cd %q && docker compose up -d", projectSrc) + case "absent": + cmd = fmt.Sprintf("cd %q && docker compose down", projectSrc) + case "restarted": + cmd = fmt.Sprintf("cd %q && docker compose restart", projectSrc) + default: + cmd = fmt.Sprintf("cd %q && docker compose up -d", projectSrc) + } + + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + } + + // Heuristic for changed + changed := !strings.Contains(stdout, "Up to date") && !strings.Contains(stderr, "Up to date") + + return &TaskResult{Changed: changed, Stdout: stdout}, nil +} diff --git a/modules_adv_test.go b/modules_adv_test.go new file mode 100644 index 0000000..389c5dd --- /dev/null +++ b/modules_adv_test.go @@ -0,0 +1,1127 @@ +package ansible + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ============================================================ +// Step 1.4: user / group / cron / authorized_key / git / +// unarchive / uri / ufw / docker_compose / blockinfile +// advanced module tests +// ============================================================ + +// --- user module --- + +func TestModuleUser_Good_CreateNewUser(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`id deploy >/dev/null 2>&1`, "", "no such user", 1) + mock.expectCommand(`useradd`, "", "", 0) + + result, err := moduleUserWithClient(e, mock, map[string]any{ + "name": "deploy", + "uid": "1500", + "group": "www-data", + "groups": "docker,sudo", + "home": "/opt/deploy", + "shell": "/bin/bash", + "create_home": true, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.containsSubstring("useradd")) + assert.True(t, mock.containsSubstring("-u 1500")) + assert.True(t, mock.containsSubstring("-g www-data")) + assert.True(t, mock.containsSubstring("-G docker,sudo")) + assert.True(t, mock.containsSubstring("-d /opt/deploy")) + assert.True(t, mock.containsSubstring("-s /bin/bash")) + assert.True(t, mock.containsSubstring("-m")) +} + +func TestModuleUser_Good_ModifyExistingUser(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + // id returns success meaning user exists, so usermod branch is taken + mock.expectCommand(`id deploy >/dev/null 2>&1 && usermod`, "", "", 0) + + result, err := moduleUserWithClient(e, mock, map[string]any{ + "name": "deploy", + "shell": "/bin/zsh", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.containsSubstring("usermod")) + assert.True(t, mock.containsSubstring("-s /bin/zsh")) +} + +func TestModuleUser_Good_RemoveUser(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`userdel -r deploy`, "", "", 0) + + result, err := moduleUserWithClient(e, mock, map[string]any{ + "name": "deploy", + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`userdel -r deploy`)) +} + +func TestModuleUser_Good_SystemUser(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`id|useradd`, "", "", 0) + + result, err := moduleUserWithClient(e, mock, map[string]any{ + "name": "prometheus", + "system": true, + "create_home": false, + "shell": "/usr/sbin/nologin", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + // system flag adds -r + assert.True(t, mock.containsSubstring("-r")) + assert.True(t, mock.containsSubstring("-s /usr/sbin/nologin")) + // create_home=false means -m should NOT be present + // Actually, looking at the production code: getBoolArg(args, "create_home", true) — default is true + // We set it to false explicitly, so -m should NOT appear + cmd := mock.lastCommand() + assert.NotContains(t, cmd.Cmd, " -m ") +} + +func TestModuleUser_Good_NoOptsUsesSimpleForm(t *testing.T) { + // When no options are provided, uses the simple "id || useradd" form + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`id testuser >/dev/null 2>&1 || useradd testuser`, "", "", 0) + + result, err := moduleUserWithClient(e, mock, map[string]any{ + "name": "testuser", + "create_home": false, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) +} + +func TestModuleUser_Bad_MissingName(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleUserWithClient(e, mock, map[string]any{ + "state": "present", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "name required") +} + +func TestModuleUser_Good_CommandFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`id|useradd|usermod`, "", "useradd: Permission denied", 1) + + result, err := moduleUserWithClient(e, mock, map[string]any{ + "name": "deploy", + "shell": "/bin/bash", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Contains(t, result.Msg, "Permission denied") +} + +// --- group module --- + +func TestModuleGroup_Good_CreateNewGroup(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + // getent fails → groupadd runs + mock.expectCommand(`getent group appgroup`, "", "", 1) + mock.expectCommand(`groupadd`, "", "", 0) + + result, err := moduleGroupWithClient(e, mock, map[string]any{ + "name": "appgroup", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.containsSubstring("groupadd")) + assert.True(t, mock.containsSubstring("appgroup")) +} + +func TestModuleGroup_Good_GroupAlreadyExists(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + // getent succeeds → groupadd skipped (|| short-circuits) + mock.expectCommand(`getent group docker >/dev/null 2>&1 || groupadd`, "", "", 0) + + result, err := moduleGroupWithClient(e, mock, map[string]any{ + "name": "docker", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) +} + +func TestModuleGroup_Good_RemoveGroup(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`groupdel oldgroup`, "", "", 0) + + result, err := moduleGroupWithClient(e, mock, map[string]any{ + "name": "oldgroup", + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`groupdel oldgroup`)) +} + +func TestModuleGroup_Good_SystemGroup(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`getent group|groupadd`, "", "", 0) + + result, err := moduleGroupWithClient(e, mock, map[string]any{ + "name": "prometheus", + "system": true, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.containsSubstring("-r")) +} + +func TestModuleGroup_Good_CustomGID(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`getent group|groupadd`, "", "", 0) + + result, err := moduleGroupWithClient(e, mock, map[string]any{ + "name": "custom", + "gid": "5000", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.containsSubstring("-g 5000")) +} + +func TestModuleGroup_Bad_MissingName(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleGroupWithClient(e, mock, map[string]any{ + "state": "present", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "name required") +} + +func TestModuleGroup_Good_CommandFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`getent group|groupadd`, "", "groupadd: Permission denied", 1) + + result, err := moduleGroupWithClient(e, mock, map[string]any{ + "name": "failgroup", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) +} + +// --- cron module --- + +func TestModuleCron_Good_AddCronJob(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`crontab -u root`, "", "", 0) + + result, err := moduleCronWithClient(e, mock, map[string]any{ + "name": "backup", + "job": "/usr/local/bin/backup.sh", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + // Default schedule is * * * * * + assert.True(t, mock.containsSubstring("* * * * *")) + assert.True(t, mock.containsSubstring("/usr/local/bin/backup.sh")) + assert.True(t, mock.containsSubstring("# backup")) +} + +func TestModuleCron_Good_RemoveCronJob(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`crontab -u root -l`, "* * * * * /bin/backup # backup\n", "", 0) + + result, err := moduleCronWithClient(e, mock, map[string]any{ + "name": "backup", + "job": "/bin/backup", + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.containsSubstring("grep -v")) +} + +func TestModuleCron_Good_CustomSchedule(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`crontab -u root`, "", "", 0) + + result, err := moduleCronWithClient(e, mock, map[string]any{ + "name": "nightly-backup", + "job": "/opt/scripts/backup.sh", + "minute": "30", + "hour": "2", + "day": "1", + "month": "6", + "weekday": "0", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.containsSubstring("30 2 1 6 0")) + assert.True(t, mock.containsSubstring("/opt/scripts/backup.sh")) +} + +func TestModuleCron_Good_CustomUser(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`crontab -u www-data`, "", "", 0) + + result, err := moduleCronWithClient(e, mock, map[string]any{ + "name": "cache-clear", + "job": "php artisan cache:clear", + "user": "www-data", + "minute": "0", + "hour": "*/4", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.containsSubstring("crontab -u www-data")) + assert.True(t, mock.containsSubstring("0 */4 * * *")) +} + +func TestModuleCron_Good_AbsentWithNoName(t *testing.T) { + // Absent with no name — changed but no grep command + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleCronWithClient(e, mock, map[string]any{ + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + // No commands should have run since name is empty + assert.Equal(t, 0, mock.commandCount()) +} + +// --- authorized_key module --- + +func TestModuleAuthorizedKey_Good_AddKey(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + testKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDcT... user@host" + mock.expectCommand(`getent passwd deploy`, "/home/deploy", "", 0) + mock.expectCommand(`mkdir -p`, "", "", 0) + mock.expectCommand(`grep -qF`, "", "", 1) // key not found, will be appended + mock.expectCommand(`echo`, "", "", 0) + mock.expectCommand(`chmod 600`, "", "", 0) + + result, err := moduleAuthorizedKeyWithClient(e, mock, map[string]any{ + "user": "deploy", + "key": testKey, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.containsSubstring("mkdir -p")) + assert.True(t, mock.containsSubstring("chmod 700")) + assert.True(t, mock.containsSubstring("authorized_keys")) +} + +func TestModuleAuthorizedKey_Good_RemoveKey(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + testKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDcT... user@host" + mock.expectCommand(`getent passwd deploy`, "/home/deploy", "", 0) + mock.expectCommand(`sed -i`, "", "", 0) + + result, err := moduleAuthorizedKeyWithClient(e, mock, map[string]any{ + "user": "deploy", + "key": testKey, + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`sed -i`)) + assert.True(t, mock.containsSubstring("authorized_keys")) +} + +func TestModuleAuthorizedKey_Good_KeyAlreadyExists(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + testKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDcT... user@host" + mock.expectCommand(`getent passwd deploy`, "/home/deploy", "", 0) + mock.expectCommand(`mkdir -p`, "", "", 0) + // grep succeeds: key already present, || short-circuits, echo not needed + mock.expectCommand(`grep -qF.*echo`, "", "", 0) + mock.expectCommand(`chmod 600`, "", "", 0) + + result, err := moduleAuthorizedKeyWithClient(e, mock, map[string]any{ + "user": "deploy", + "key": testKey, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) +} + +func TestModuleAuthorizedKey_Good_RootUserFallback(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + testKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDcT... admin@host" + // getent returns empty — falls back to /root for root user + mock.expectCommand(`getent passwd root`, "", "", 0) + mock.expectCommand(`mkdir -p`, "", "", 0) + mock.expectCommand(`grep -qF.*echo`, "", "", 0) + mock.expectCommand(`chmod 600`, "", "", 0) + + result, err := moduleAuthorizedKeyWithClient(e, mock, map[string]any{ + "user": "root", + "key": testKey, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + // Should use /root/.ssh/authorized_keys + assert.True(t, mock.containsSubstring("/root/.ssh")) +} + +func TestModuleAuthorizedKey_Bad_MissingUserAndKey(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleAuthorizedKeyWithClient(e, mock, map[string]any{}) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "user and key required") +} + +func TestModuleAuthorizedKey_Bad_MissingKey(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleAuthorizedKeyWithClient(e, mock, map[string]any{ + "user": "deploy", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "user and key required") +} + +func TestModuleAuthorizedKey_Bad_MissingUser(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleAuthorizedKeyWithClient(e, mock, map[string]any{ + "key": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDcT...", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "user and key required") +} + +// --- git module --- + +func TestModuleGit_Good_FreshClone(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + // .git does not exist → fresh clone + mock.expectCommand(`git clone`, "", "", 0) + + result, err := moduleGitWithClient(e, mock, map[string]any{ + "repo": "https://github.com/example/app.git", + "dest": "/opt/app", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`git clone`)) + assert.True(t, mock.containsSubstring("https://github.com/example/app.git")) + assert.True(t, mock.containsSubstring("/opt/app")) + // Default version is HEAD + assert.True(t, mock.containsSubstring("git checkout")) +} + +func TestModuleGit_Good_UpdateExisting(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + // .git exists → fetch + checkout + mock.addFile("/opt/app/.git", []byte("gitdir")) + mock.expectCommand(`git fetch --all && git checkout`, "", "", 0) + + result, err := moduleGitWithClient(e, mock, map[string]any{ + "repo": "https://github.com/example/app.git", + "dest": "/opt/app", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`git fetch --all`)) + assert.True(t, mock.containsSubstring("git checkout --force")) + // Should NOT contain git clone + assert.False(t, mock.containsSubstring("git clone")) +} + +func TestModuleGit_Good_CustomVersion(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`git clone`, "", "", 0) + + result, err := moduleGitWithClient(e, mock, map[string]any{ + "repo": "https://github.com/example/app.git", + "dest": "/opt/app", + "version": "v2.1.0", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.containsSubstring("v2.1.0")) +} + +func TestModuleGit_Good_UpdateWithBranch(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.addFile("/srv/myapp/.git", []byte("gitdir")) + mock.expectCommand(`git fetch --all && git checkout`, "", "", 0) + + result, err := moduleGitWithClient(e, mock, map[string]any{ + "repo": "git@github.com:org/repo.git", + "dest": "/srv/myapp", + "version": "develop", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.containsSubstring("develop")) +} + +func TestModuleGit_Bad_MissingRepoAndDest(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleGitWithClient(e, mock, map[string]any{}) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "repo and dest required") +} + +func TestModuleGit_Bad_MissingRepo(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleGitWithClient(e, mock, map[string]any{ + "dest": "/opt/app", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "repo and dest required") +} + +func TestModuleGit_Bad_MissingDest(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleGitWithClient(e, mock, map[string]any{ + "repo": "https://github.com/example/app.git", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "repo and dest required") +} + +func TestModuleGit_Good_CloneFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`git clone`, "", "fatal: repository not found", 128) + + result, err := moduleGitWithClient(e, mock, map[string]any{ + "repo": "https://github.com/example/nonexistent.git", + "dest": "/opt/app", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Contains(t, result.Msg, "repository not found") +} + +// --- unarchive module --- + +func TestModuleUnarchive_Good_ExtractTarGzLocal(t *testing.T) { + // Create a temporary "archive" file + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "package.tar.gz") + require.NoError(t, os.WriteFile(archivePath, []byte("fake-archive-content"), 0644)) + + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`mkdir -p`, "", "", 0) + mock.expectCommand(`tar -xzf`, "", "", 0) + mock.expectCommand(`rm -f`, "", "", 0) + + result, err := moduleUnarchiveWithClient(e, mock, map[string]any{ + "src": archivePath, + "dest": "/opt/app", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + // Should have uploaded the file + assert.Equal(t, 1, mock.uploadCount()) + assert.True(t, mock.containsSubstring("tar -xzf")) + assert.True(t, mock.containsSubstring("/opt/app")) +} + +func TestModuleUnarchive_Good_ExtractZipLocal(t *testing.T) { + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "release.zip") + require.NoError(t, os.WriteFile(archivePath, []byte("fake-zip-content"), 0644)) + + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`mkdir -p`, "", "", 0) + mock.expectCommand(`unzip -o`, "", "", 0) + mock.expectCommand(`rm -f`, "", "", 0) + + result, err := moduleUnarchiveWithClient(e, mock, map[string]any{ + "src": archivePath, + "dest": "/opt/releases", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.Equal(t, 1, mock.uploadCount()) + assert.True(t, mock.containsSubstring("unzip -o")) +} + +func TestModuleUnarchive_Good_RemoteSource(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`mkdir -p`, "", "", 0) + mock.expectCommand(`tar -xzf`, "", "", 0) + + result, err := moduleUnarchiveWithClient(e, mock, map[string]any{ + "src": "/tmp/remote-archive.tar.gz", + "dest": "/opt/app", + "remote_src": true, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + // No upload should happen for remote sources + assert.Equal(t, 0, mock.uploadCount()) + assert.True(t, mock.containsSubstring("tar -xzf")) +} + +func TestModuleUnarchive_Good_TarXz(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`mkdir -p`, "", "", 0) + mock.expectCommand(`tar -xJf`, "", "", 0) + + result, err := moduleUnarchiveWithClient(e, mock, map[string]any{ + "src": "/tmp/archive.tar.xz", + "dest": "/opt/extract", + "remote_src": true, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.containsSubstring("tar -xJf")) +} + +func TestModuleUnarchive_Good_TarBz2(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`mkdir -p`, "", "", 0) + mock.expectCommand(`tar -xjf`, "", "", 0) + + result, err := moduleUnarchiveWithClient(e, mock, map[string]any{ + "src": "/tmp/archive.tar.bz2", + "dest": "/opt/extract", + "remote_src": true, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.containsSubstring("tar -xjf")) +} + +func TestModuleUnarchive_Bad_MissingSrcAndDest(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleUnarchiveWithClient(e, mock, map[string]any{}) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "src and dest required") +} + +func TestModuleUnarchive_Bad_MissingSrc(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleUnarchiveWithClient(e, mock, map[string]any{ + "dest": "/opt/app", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "src and dest required") +} + +func TestModuleUnarchive_Bad_LocalFileNotFound(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + mock.expectCommand(`mkdir -p`, "", "", 0) + + _, err := moduleUnarchiveWithClient(e, mock, map[string]any{ + "src": "/nonexistent/archive.tar.gz", + "dest": "/opt/app", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "read src") +} + +// --- uri module --- + +func TestModuleURI_Good_GetRequestDefault(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`curl.*https://example.com/api/health`, "OK\n200", "", 0) + + result, err := moduleURIWithClient(e, mock, map[string]any{ + "url": "https://example.com/api/health", + }) + + require.NoError(t, err) + assert.False(t, result.Failed) + assert.False(t, result.Changed) // URI module does not set changed + assert.Equal(t, 200, result.RC) + assert.Equal(t, 200, result.Data["status"]) +} + +func TestModuleURI_Good_PostWithBodyAndHeaders(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + // Use a broad pattern since header order in map iteration is non-deterministic + mock.expectCommand(`curl.*api\.example\.com`, "{\"id\":1}\n201", "", 0) + + result, err := moduleURIWithClient(e, mock, map[string]any{ + "url": "https://api.example.com/users", + "method": "POST", + "body": `{"name":"test"}`, + "status_code": 201, + "headers": map[string]any{ + "Content-Type": "application/json", + "Authorization": "Bearer token123", + }, + }) + + require.NoError(t, err) + assert.False(t, result.Failed) + assert.Equal(t, 201, result.RC) + assert.True(t, mock.containsSubstring("-X POST")) + assert.True(t, mock.containsSubstring("-d")) + assert.True(t, mock.containsSubstring("Content-Type")) + assert.True(t, mock.containsSubstring("Authorization")) +} + +func TestModuleURI_Good_WrongStatusCode(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`curl`, "Not Found\n404", "", 0) + + result, err := moduleURIWithClient(e, mock, map[string]any{ + "url": "https://example.com/missing", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) // Expected 200, got 404 + assert.Equal(t, 404, result.RC) +} + +func TestModuleURI_Good_CurlCommandFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommandError(`curl`, assert.AnError) + + result, err := moduleURIWithClient(e, mock, map[string]any{ + "url": "https://unreachable.example.com", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Contains(t, result.Msg, assert.AnError.Error()) +} + +func TestModuleURI_Good_CustomExpectedStatus(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`curl`, "\n204", "", 0) + + result, err := moduleURIWithClient(e, mock, map[string]any{ + "url": "https://api.example.com/resource/1", + "method": "DELETE", + "status_code": 204, + }) + + require.NoError(t, err) + assert.False(t, result.Failed) + assert.Equal(t, 204, result.RC) +} + +func TestModuleURI_Bad_MissingURL(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleURIWithClient(e, mock, map[string]any{ + "method": "GET", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "url required") +} + +// --- ufw module --- + +func TestModuleUFW_Good_AllowRuleWithPort(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`ufw allow 443/tcp`, "Rule added", "", 0) + + result, err := moduleUFWWithClient(e, mock, map[string]any{ + "rule": "allow", + "port": "443", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`ufw allow 443/tcp`)) +} + +func TestModuleUFW_Good_EnableFirewall(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`ufw --force enable`, "Firewall is active", "", 0) + + result, err := moduleUFWWithClient(e, mock, map[string]any{ + "state": "enabled", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`ufw --force enable`)) +} + +func TestModuleUFW_Good_DenyRuleWithProto(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`ufw deny 53/udp`, "Rule added", "", 0) + + result, err := moduleUFWWithClient(e, mock, map[string]any{ + "rule": "deny", + "port": "53", + "proto": "udp", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`ufw deny 53/udp`)) +} + +func TestModuleUFW_Good_ResetFirewall(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`ufw --force reset`, "Resetting", "", 0) + + result, err := moduleUFWWithClient(e, mock, map[string]any{ + "state": "reset", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`ufw --force reset`)) +} + +func TestModuleUFW_Good_DisableFirewall(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`ufw disable`, "Firewall stopped", "", 0) + + result, err := moduleUFWWithClient(e, mock, map[string]any{ + "state": "disabled", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`ufw disable`)) +} + +func TestModuleUFW_Good_ReloadFirewall(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`ufw reload`, "Firewall reloaded", "", 0) + + result, err := moduleUFWWithClient(e, mock, map[string]any{ + "state": "reloaded", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`ufw reload`)) +} + +func TestModuleUFW_Good_LimitRule(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`ufw limit 22/tcp`, "Rule added", "", 0) + + result, err := moduleUFWWithClient(e, mock, map[string]any{ + "rule": "limit", + "port": "22", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`ufw limit 22/tcp`)) +} + +func TestModuleUFW_Good_StateCommandFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`ufw --force enable`, "", "ERROR: problem running ufw", 1) + + result, err := moduleUFWWithClient(e, mock, map[string]any{ + "state": "enabled", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) +} + +// --- docker_compose module --- + +func TestModuleDockerCompose_Good_StatePresent(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`docker compose up -d`, "Creating container_1\nCreating container_2\n", "", 0) + + result, err := moduleDockerComposeWithClient(e, mock, map[string]any{ + "project_src": "/opt/myapp", + "state": "present", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`docker compose up -d`)) + assert.True(t, mock.containsSubstring("/opt/myapp")) +} + +func TestModuleDockerCompose_Good_StateAbsent(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`docker compose down`, "Removing container_1\n", "", 0) + + result, err := moduleDockerComposeWithClient(e, mock, map[string]any{ + "project_src": "/opt/myapp", + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`docker compose down`)) +} + +func TestModuleDockerCompose_Good_AlreadyUpToDate(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`docker compose up -d`, "Container myapp-web-1 Up to date\n", "", 0) + + result, err := moduleDockerComposeWithClient(e, mock, map[string]any{ + "project_src": "/opt/myapp", + "state": "present", + }) + + require.NoError(t, err) + assert.False(t, result.Changed) // "Up to date" in stdout → changed=false + assert.False(t, result.Failed) +} + +func TestModuleDockerCompose_Good_StateRestarted(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`docker compose restart`, "Restarting container_1\n", "", 0) + + result, err := moduleDockerComposeWithClient(e, mock, map[string]any{ + "project_src": "/opt/stack", + "state": "restarted", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`docker compose restart`)) +} + +func TestModuleDockerCompose_Bad_MissingProjectSrc(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleDockerComposeWithClient(e, mock, map[string]any{ + "state": "present", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "project_src required") +} + +func TestModuleDockerCompose_Good_CommandFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`docker compose up -d`, "", "Error response from daemon", 1) + + result, err := moduleDockerComposeWithClient(e, mock, map[string]any{ + "project_src": "/opt/broken", + "state": "present", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Contains(t, result.Msg, "Error response from daemon") +} + +func TestModuleDockerCompose_Good_DefaultStateIsPresent(t *testing.T) { + // When no state is specified, default is "present" + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`docker compose up -d`, "Starting\n", "", 0) + + result, err := moduleDockerComposeWithClient(e, mock, map[string]any{ + "project_src": "/opt/app", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`docker compose up -d`)) +} + +// --- Cross-module dispatch tests for advanced modules --- + +func TestExecuteModuleWithMock_Good_DispatchUser(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`id|useradd|usermod`, "", "", 0) + + task := &Task{ + Module: "user", + Args: map[string]any{ + "name": "appuser", + "shell": "/bin/bash", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) +} + +func TestExecuteModuleWithMock_Good_DispatchGroup(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`getent group|groupadd`, "", "", 0) + + task := &Task{ + Module: "group", + Args: map[string]any{ + "name": "docker", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) +} + +func TestExecuteModuleWithMock_Good_DispatchCron(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`crontab`, "", "", 0) + + task := &Task{ + Module: "cron", + Args: map[string]any{ + "name": "logrotate", + "job": "/usr/sbin/logrotate", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) +} + +func TestExecuteModuleWithMock_Good_DispatchGit(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`git clone`, "", "", 0) + + task := &Task{ + Module: "git", + Args: map[string]any{ + "repo": "https://github.com/org/repo.git", + "dest": "/opt/repo", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) +} + +func TestExecuteModuleWithMock_Good_DispatchURI(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`curl`, "OK\n200", "", 0) + + task := &Task{ + Module: "uri", + Args: map[string]any{ + "url": "https://example.com", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.False(t, result.Failed) +} + +func TestExecuteModuleWithMock_Good_DispatchDockerCompose(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`docker compose up -d`, "Creating\n", "", 0) + + task := &Task{ + Module: "ansible.builtin.docker_compose", + Args: map[string]any{ + "project_src": "/opt/stack", + "state": "present", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) +} diff --git a/modules_cmd_test.go b/modules_cmd_test.go new file mode 100644 index 0000000..cf4f28e --- /dev/null +++ b/modules_cmd_test.go @@ -0,0 +1,722 @@ +package ansible + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ============================================================ +// Step 1.1: command / shell / raw / script module tests +// ============================================================ + +// --- MockSSHClient basic tests --- + +func TestMockSSHClient_Good_RunRecordsExecution(t *testing.T) { + mock := NewMockSSHClient() + mock.expectCommand("echo hello", "hello\n", "", 0) + + stdout, stderr, rc, err := mock.Run(nil, "echo hello") + + assert.NoError(t, err) + assert.Equal(t, "hello\n", stdout) + assert.Equal(t, "", stderr) + assert.Equal(t, 0, rc) + assert.Equal(t, 1, mock.commandCount()) + assert.Equal(t, "Run", mock.lastCommand().Method) + assert.Equal(t, "echo hello", mock.lastCommand().Cmd) +} + +func TestMockSSHClient_Good_RunScriptRecordsExecution(t *testing.T) { + mock := NewMockSSHClient() + mock.expectCommand("set -e", "ok", "", 0) + + stdout, _, rc, err := mock.RunScript(nil, "set -e\necho done") + + assert.NoError(t, err) + assert.Equal(t, "ok", stdout) + assert.Equal(t, 0, rc) + assert.Equal(t, 1, mock.commandCount()) + assert.Equal(t, "RunScript", mock.lastCommand().Method) +} + +func TestMockSSHClient_Good_DefaultSuccessResponse(t *testing.T) { + mock := NewMockSSHClient() + + // No expectations registered — should return empty success + stdout, stderr, rc, err := mock.Run(nil, "anything") + + assert.NoError(t, err) + assert.Equal(t, "", stdout) + assert.Equal(t, "", stderr) + assert.Equal(t, 0, rc) +} + +func TestMockSSHClient_Good_LastMatchWins(t *testing.T) { + mock := NewMockSSHClient() + mock.expectCommand("echo", "first", "", 0) + mock.expectCommand("echo", "second", "", 0) + + stdout, _, _, _ := mock.Run(nil, "echo hello") + + assert.Equal(t, "second", stdout) +} + +func TestMockSSHClient_Good_FileOperations(t *testing.T) { + mock := NewMockSSHClient() + + // File does not exist initially + exists, err := mock.FileExists(nil, "/etc/config") + assert.NoError(t, err) + assert.False(t, exists) + + // Add file + mock.addFile("/etc/config", []byte("key=value")) + + // Now it exists + exists, err = mock.FileExists(nil, "/etc/config") + assert.NoError(t, err) + assert.True(t, exists) + + // Download it + content, err := mock.Download(nil, "/etc/config") + assert.NoError(t, err) + assert.Equal(t, []byte("key=value"), content) + + // Download non-existent file + _, err = mock.Download(nil, "/nonexistent") + assert.Error(t, err) +} + +func TestMockSSHClient_Good_StatWithExplicit(t *testing.T) { + mock := NewMockSSHClient() + mock.addStat("/var/log", map[string]any{"exists": true, "isdir": true}) + + info, err := mock.Stat(nil, "/var/log") + assert.NoError(t, err) + assert.Equal(t, true, info["exists"]) + assert.Equal(t, true, info["isdir"]) +} + +func TestMockSSHClient_Good_StatFallback(t *testing.T) { + mock := NewMockSSHClient() + mock.addFile("/etc/hosts", []byte("127.0.0.1 localhost")) + + info, err := mock.Stat(nil, "/etc/hosts") + assert.NoError(t, err) + assert.Equal(t, true, info["exists"]) + assert.Equal(t, false, info["isdir"]) + + info, err = mock.Stat(nil, "/nonexistent") + assert.NoError(t, err) + assert.Equal(t, false, info["exists"]) +} + +func TestMockSSHClient_Good_BecomeTracking(t *testing.T) { + mock := NewMockSSHClient() + + assert.False(t, mock.become) + assert.Equal(t, "", mock.becomeUser) + + mock.SetBecome(true, "root", "secret") + + assert.True(t, mock.become) + assert.Equal(t, "root", mock.becomeUser) + assert.Equal(t, "secret", mock.becomePass) +} + +func TestMockSSHClient_Good_HasExecuted(t *testing.T) { + mock := NewMockSSHClient() + _, _, _, _ = mock.Run(nil, "systemctl restart nginx") + _, _, _, _ = mock.Run(nil, "apt-get update") + + assert.True(t, mock.hasExecuted("systemctl.*nginx")) + assert.True(t, mock.hasExecuted("apt-get")) + assert.False(t, mock.hasExecuted("yum")) +} + +func TestMockSSHClient_Good_HasExecutedMethod(t *testing.T) { + mock := NewMockSSHClient() + _, _, _, _ = mock.Run(nil, "echo run") + _, _, _, _ = mock.RunScript(nil, "echo script") + + assert.True(t, mock.hasExecutedMethod("Run", "echo run")) + assert.True(t, mock.hasExecutedMethod("RunScript", "echo script")) + assert.False(t, mock.hasExecutedMethod("Run", "echo script")) + assert.False(t, mock.hasExecutedMethod("RunScript", "echo run")) +} + +func TestMockSSHClient_Good_Reset(t *testing.T) { + mock := NewMockSSHClient() + _, _, _, _ = mock.Run(nil, "echo hello") + assert.Equal(t, 1, mock.commandCount()) + + mock.reset() + assert.Equal(t, 0, mock.commandCount()) +} + +func TestMockSSHClient_Good_ErrorExpectation(t *testing.T) { + mock := NewMockSSHClient() + mock.expectCommandError("bad cmd", assert.AnError) + + _, _, _, err := mock.Run(nil, "bad cmd") + assert.Error(t, err) +} + +// --- command module --- + +func TestModuleCommand_Good_BasicCommand(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("ls -la /tmp", "total 0\n", "", 0) + + result, err := moduleCommandWithClient(e, mock, map[string]any{ + "_raw_params": "ls -la /tmp", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.Equal(t, "total 0\n", result.Stdout) + assert.Equal(t, 0, result.RC) + + // Verify it used Run (not RunScript) + assert.True(t, mock.hasExecutedMethod("Run", "ls -la /tmp")) + assert.False(t, mock.hasExecutedMethod("RunScript", ".*")) +} + +func TestModuleCommand_Good_CmdArg(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("whoami", "root\n", "", 0) + + result, err := moduleCommandWithClient(e, mock, map[string]any{ + "cmd": "whoami", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.Equal(t, "root\n", result.Stdout) + assert.True(t, mock.hasExecutedMethod("Run", "whoami")) +} + +func TestModuleCommand_Good_WithChdir(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`cd "/var/log" && ls`, "syslog\n", "", 0) + + result, err := moduleCommandWithClient(e, mock, map[string]any{ + "_raw_params": "ls", + "chdir": "/var/log", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + // The command should have been wrapped with cd + last := mock.lastCommand() + assert.Equal(t, "Run", last.Method) + assert.Contains(t, last.Cmd, `cd "/var/log"`) + assert.Contains(t, last.Cmd, "ls") +} + +func TestModuleCommand_Bad_NoCommand(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleCommandWithClient(e, mock, map[string]any{}) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "no command specified") +} + +func TestModuleCommand_Good_NonZeroRC(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("false", "", "error occurred", 1) + + result, err := moduleCommandWithClient(e, mock, map[string]any{ + "_raw_params": "false", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Equal(t, 1, result.RC) + assert.Equal(t, "error occurred", result.Stderr) +} + +func TestModuleCommand_Good_SSHError(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + mock.expectCommandError(".*", assert.AnError) + + result, err := moduleCommandWithClient(e, mock, map[string]any{ + "_raw_params": "any command", + }) + + require.NoError(t, err) // Module wraps SSH errors into result.Failed + assert.True(t, result.Failed) + assert.Contains(t, result.Msg, assert.AnError.Error()) +} + +func TestModuleCommand_Good_RawParamsTakesPrecedence(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("from_raw", "raw\n", "", 0) + + result, err := moduleCommandWithClient(e, mock, map[string]any{ + "_raw_params": "from_raw", + "cmd": "from_cmd", + }) + + require.NoError(t, err) + assert.Equal(t, "raw\n", result.Stdout) + assert.True(t, mock.hasExecuted("from_raw")) +} + +// --- shell module --- + +func TestModuleShell_Good_BasicShell(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("echo hello", "hello\n", "", 0) + + result, err := moduleShellWithClient(e, mock, map[string]any{ + "_raw_params": "echo hello", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.Equal(t, "hello\n", result.Stdout) + + // Shell must use RunScript (not Run) + assert.True(t, mock.hasExecutedMethod("RunScript", "echo hello")) + assert.False(t, mock.hasExecutedMethod("Run", ".*")) +} + +func TestModuleShell_Good_CmdArg(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("date", "Thu Feb 20\n", "", 0) + + result, err := moduleShellWithClient(e, mock, map[string]any{ + "cmd": "date", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecutedMethod("RunScript", "date")) +} + +func TestModuleShell_Good_WithChdir(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`cd "/app" && npm install`, "done\n", "", 0) + + result, err := moduleShellWithClient(e, mock, map[string]any{ + "_raw_params": "npm install", + "chdir": "/app", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + last := mock.lastCommand() + assert.Equal(t, "RunScript", last.Method) + assert.Contains(t, last.Cmd, `cd "/app"`) + assert.Contains(t, last.Cmd, "npm install") +} + +func TestModuleShell_Bad_NoCommand(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleShellWithClient(e, mock, map[string]any{}) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "no command specified") +} + +func TestModuleShell_Good_NonZeroRC(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("exit 2", "", "failed", 2) + + result, err := moduleShellWithClient(e, mock, map[string]any{ + "_raw_params": "exit 2", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Equal(t, 2, result.RC) +} + +func TestModuleShell_Good_SSHError(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + mock.expectCommandError(".*", assert.AnError) + + result, err := moduleShellWithClient(e, mock, map[string]any{ + "_raw_params": "some command", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) +} + +func TestModuleShell_Good_PipelineCommand(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`cat /etc/passwd \| grep root`, "root:x:0:0\n", "", 0) + + result, err := moduleShellWithClient(e, mock, map[string]any{ + "_raw_params": "cat /etc/passwd | grep root", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + // Shell uses RunScript, so pipes work + assert.True(t, mock.hasExecutedMethod("RunScript", "cat /etc/passwd")) +} + +// --- raw module --- + +func TestModuleRaw_Good_BasicRaw(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("uname -a", "Linux host1 5.15\n", "", 0) + + result, err := moduleRawWithClient(e, mock, map[string]any{ + "_raw_params": "uname -a", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.Equal(t, "Linux host1 5.15\n", result.Stdout) + + // Raw must use Run (not RunScript) — no shell wrapping + assert.True(t, mock.hasExecutedMethod("Run", "uname -a")) + assert.False(t, mock.hasExecutedMethod("RunScript", ".*")) +} + +func TestModuleRaw_Bad_NoCommand(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleRawWithClient(e, mock, map[string]any{}) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "no command specified") +} + +func TestModuleRaw_Good_NoChdir(t *testing.T) { + // Raw module does NOT support chdir — it should ignore it + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("echo test", "test\n", "", 0) + + result, err := moduleRawWithClient(e, mock, map[string]any{ + "_raw_params": "echo test", + "chdir": "/should/be/ignored", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + // The chdir should NOT appear in the command + last := mock.lastCommand() + assert.Equal(t, "echo test", last.Cmd) + assert.NotContains(t, last.Cmd, "cd") +} + +func TestModuleRaw_Good_NonZeroRC(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("invalid", "", "not found", 127) + + result, err := moduleRawWithClient(e, mock, map[string]any{ + "_raw_params": "invalid", + }) + + require.NoError(t, err) + // Note: raw module does NOT set Failed based on RC + assert.Equal(t, 127, result.RC) + assert.Equal(t, "not found", result.Stderr) +} + +func TestModuleRaw_Good_SSHError(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + mock.expectCommandError(".*", assert.AnError) + + result, err := moduleRawWithClient(e, mock, map[string]any{ + "_raw_params": "any", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) +} + +func TestModuleRaw_Good_ExactCommandPassthrough(t *testing.T) { + // Raw should pass the command exactly as given — no wrapping + e, mock := newTestExecutorWithMock("host1") + complexCmd := `/usr/bin/python3 -c 'import sys; print(sys.version)'` + mock.expectCommand(".*python3.*", "3.10.0\n", "", 0) + + result, err := moduleRawWithClient(e, mock, map[string]any{ + "_raw_params": complexCmd, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + last := mock.lastCommand() + assert.Equal(t, complexCmd, last.Cmd) +} + +// --- script module --- + +func TestModuleScript_Good_BasicScript(t *testing.T) { + // Create a temporary script file + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "setup.sh") + scriptContent := "#!/bin/bash\necho 'setup complete'\nexit 0" + require.NoError(t, os.WriteFile(scriptPath, []byte(scriptContent), 0755)) + + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("setup complete", "setup complete\n", "", 0) + + result, err := moduleScriptWithClient(e, mock, map[string]any{ + "_raw_params": scriptPath, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + + // Script must use RunScript (not Run) — it sends the file content + assert.True(t, mock.hasExecutedMethod("RunScript", "setup complete")) + assert.False(t, mock.hasExecutedMethod("Run", ".*")) + + // Verify the full script content was sent + last := mock.lastCommand() + assert.Equal(t, scriptContent, last.Cmd) +} + +func TestModuleScript_Bad_NoScript(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleScriptWithClient(e, mock, map[string]any{}) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "no script specified") +} + +func TestModuleScript_Bad_FileNotFound(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleScriptWithClient(e, mock, map[string]any{ + "_raw_params": "/nonexistent/script.sh", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "read script") +} + +func TestModuleScript_Good_NonZeroRC(t *testing.T) { + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "fail.sh") + require.NoError(t, os.WriteFile(scriptPath, []byte("exit 1"), 0755)) + + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("exit 1", "", "script failed", 1) + + result, err := moduleScriptWithClient(e, mock, map[string]any{ + "_raw_params": scriptPath, + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Equal(t, 1, result.RC) +} + +func TestModuleScript_Good_MultiLineScript(t *testing.T) { + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "multi.sh") + scriptContent := "#!/bin/bash\nset -e\napt-get update\napt-get install -y nginx\nsystemctl start nginx" + require.NoError(t, os.WriteFile(scriptPath, []byte(scriptContent), 0755)) + + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("apt-get", "done\n", "", 0) + + result, err := moduleScriptWithClient(e, mock, map[string]any{ + "_raw_params": scriptPath, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + + // Verify RunScript was called with the full content + last := mock.lastCommand() + assert.Equal(t, "RunScript", last.Method) + assert.Equal(t, scriptContent, last.Cmd) +} + +func TestModuleScript_Good_SSHError(t *testing.T) { + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "ok.sh") + require.NoError(t, os.WriteFile(scriptPath, []byte("echo ok"), 0755)) + + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + mock.expectCommandError(".*", assert.AnError) + + result, err := moduleScriptWithClient(e, mock, map[string]any{ + "_raw_params": scriptPath, + }) + + require.NoError(t, err) + assert.True(t, result.Failed) +} + +// --- Cross-module differentiation tests --- + +func TestModuleDifferentiation_Good_CommandUsesRun(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("echo test", "test\n", "", 0) + + _, _ = moduleCommandWithClient(e, mock, map[string]any{"_raw_params": "echo test"}) + + cmds := mock.executedCommands() + require.Len(t, cmds, 1) + assert.Equal(t, "Run", cmds[0].Method, "command module must use Run()") +} + +func TestModuleDifferentiation_Good_ShellUsesRunScript(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("echo test", "test\n", "", 0) + + _, _ = moduleShellWithClient(e, mock, map[string]any{"_raw_params": "echo test"}) + + cmds := mock.executedCommands() + require.Len(t, cmds, 1) + assert.Equal(t, "RunScript", cmds[0].Method, "shell module must use RunScript()") +} + +func TestModuleDifferentiation_Good_RawUsesRun(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("echo test", "test\n", "", 0) + + _, _ = moduleRawWithClient(e, mock, map[string]any{"_raw_params": "echo test"}) + + cmds := mock.executedCommands() + require.Len(t, cmds, 1) + assert.Equal(t, "Run", cmds[0].Method, "raw module must use Run()") +} + +func TestModuleDifferentiation_Good_ScriptUsesRunScript(t *testing.T) { + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "test.sh") + require.NoError(t, os.WriteFile(scriptPath, []byte("echo test"), 0755)) + + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("echo test", "test\n", "", 0) + + _, _ = moduleScriptWithClient(e, mock, map[string]any{"_raw_params": scriptPath}) + + cmds := mock.executedCommands() + require.Len(t, cmds, 1) + assert.Equal(t, "RunScript", cmds[0].Method, "script module must use RunScript()") +} + +// --- executeModuleWithMock dispatch tests --- + +func TestExecuteModuleWithMock_Good_DispatchCommand(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("uptime", "up 5 days\n", "", 0) + + task := &Task{ + Module: "command", + Args: map[string]any{"_raw_params": "uptime"}, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.Equal(t, "up 5 days\n", result.Stdout) +} + +func TestExecuteModuleWithMock_Good_DispatchShell(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("ps aux", "root.*bash\n", "", 0) + + task := &Task{ + Module: "ansible.builtin.shell", + Args: map[string]any{"_raw_params": "ps aux | grep bash"}, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) +} + +func TestExecuteModuleWithMock_Good_DispatchRaw(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("cat /etc/hostname", "web01\n", "", 0) + + task := &Task{ + Module: "raw", + Args: map[string]any{"_raw_params": "cat /etc/hostname"}, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.Equal(t, "web01\n", result.Stdout) +} + +func TestExecuteModuleWithMock_Good_DispatchScript(t *testing.T) { + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "deploy.sh") + require.NoError(t, os.WriteFile(scriptPath, []byte("echo deploying"), 0755)) + + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("deploying", "deploying\n", "", 0) + + task := &Task{ + Module: "script", + Args: map[string]any{"_raw_params": scriptPath}, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) +} + +func TestExecuteModuleWithMock_Bad_UnsupportedModule(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + task := &Task{ + Module: "ansible.builtin.hostname", + Args: map[string]any{}, + } + + _, err := executeModuleWithMock(e, mock, "host1", task) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported module") +} + +// --- Template integration tests --- + +func TestModuleCommand_Good_TemplatedArgs(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + e.SetVar("service_name", "nginx") + mock.expectCommand("systemctl status nginx", "active\n", "", 0) + + task := &Task{ + Module: "command", + Args: map[string]any{"_raw_params": "systemctl status {{ service_name }}"}, + } + + // Template the args the way the executor does + args := e.templateArgs(task.Args, "host1", task) + result, err := moduleCommandWithClient(e, mock, args) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted("systemctl status nginx")) +} diff --git a/modules_file_test.go b/modules_file_test.go new file mode 100644 index 0000000..cf3cd33 --- /dev/null +++ b/modules_file_test.go @@ -0,0 +1,899 @@ +package ansible + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ============================================================ +// Step 1.2: copy / template / file / lineinfile / blockinfile / stat module tests +// ============================================================ + +// --- copy module --- + +func TestModuleCopy_Good_ContentUpload(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleCopyWithClient(e, mock, map[string]any{ + "content": "server_name=web01", + "dest": "/etc/app/config", + }, "host1", &Task{}) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.Contains(t, result.Msg, "copied to /etc/app/config") + + // Verify upload was performed + assert.Equal(t, 1, mock.uploadCount()) + up := mock.lastUpload() + require.NotNil(t, up) + assert.Equal(t, "/etc/app/config", up.Remote) + assert.Equal(t, []byte("server_name=web01"), up.Content) + assert.Equal(t, os.FileMode(0644), up.Mode) +} + +func TestModuleCopy_Good_SrcFile(t *testing.T) { + tmpDir := t.TempDir() + srcPath := filepath.Join(tmpDir, "nginx.conf") + require.NoError(t, os.WriteFile(srcPath, []byte("worker_processes auto;"), 0644)) + + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleCopyWithClient(e, mock, map[string]any{ + "src": srcPath, + "dest": "/etc/nginx/nginx.conf", + }, "host1", &Task{}) + + require.NoError(t, err) + assert.True(t, result.Changed) + + up := mock.lastUpload() + require.NotNil(t, up) + assert.Equal(t, "/etc/nginx/nginx.conf", up.Remote) + assert.Equal(t, []byte("worker_processes auto;"), up.Content) +} + +func TestModuleCopy_Good_OwnerGroup(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleCopyWithClient(e, mock, map[string]any{ + "content": "data", + "dest": "/opt/app/data.txt", + "owner": "appuser", + "group": "appgroup", + }, "host1", &Task{}) + + require.NoError(t, err) + assert.True(t, result.Changed) + + // Upload + chown + chgrp = 1 upload + 2 Run calls + assert.Equal(t, 1, mock.uploadCount()) + assert.True(t, mock.hasExecuted(`chown appuser "/opt/app/data.txt"`)) + assert.True(t, mock.hasExecuted(`chgrp appgroup "/opt/app/data.txt"`)) +} + +func TestModuleCopy_Good_CustomMode(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleCopyWithClient(e, mock, map[string]any{ + "content": "#!/bin/bash\necho hello", + "dest": "/usr/local/bin/hello.sh", + "mode": "0755", + }, "host1", &Task{}) + + require.NoError(t, err) + assert.True(t, result.Changed) + + up := mock.lastUpload() + require.NotNil(t, up) + assert.Equal(t, os.FileMode(0755), up.Mode) +} + +func TestModuleCopy_Bad_MissingDest(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + _, err := moduleCopyWithClient(e, mock, map[string]any{ + "content": "data", + }, "host1", &Task{}) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "dest required") +} + +func TestModuleCopy_Bad_MissingSrcAndContent(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + _, err := moduleCopyWithClient(e, mock, map[string]any{ + "dest": "/tmp/out", + }, "host1", &Task{}) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "src or content required") +} + +func TestModuleCopy_Bad_SrcFileNotFound(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + _, err := moduleCopyWithClient(e, mock, map[string]any{ + "src": "/nonexistent/file.txt", + "dest": "/tmp/out", + }, "host1", &Task{}) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "read src") +} + +func TestModuleCopy_Good_ContentTakesPrecedenceOverSrc(t *testing.T) { + // When both content and src are given, src is checked first in the implementation + // but if src is empty string, content is used + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleCopyWithClient(e, mock, map[string]any{ + "content": "from_content", + "dest": "/tmp/out", + }, "host1", &Task{}) + + require.NoError(t, err) + assert.True(t, result.Changed) + up := mock.lastUpload() + assert.Equal(t, []byte("from_content"), up.Content) +} + +// --- file module --- + +func TestModuleFile_Good_StateDirectory(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleFileWithClient(e, mock, map[string]any{ + "path": "/var/lib/app", + "state": "directory", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + + // Should execute mkdir -p with default mode 0755 + assert.True(t, mock.hasExecuted(`mkdir -p "/var/lib/app"`)) + assert.True(t, mock.hasExecuted(`chmod 0755`)) +} + +func TestModuleFile_Good_StateDirectoryCustomMode(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleFileWithClient(e, mock, map[string]any{ + "path": "/opt/data", + "state": "directory", + "mode": "0700", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`mkdir -p "/opt/data" && chmod 0700 "/opt/data"`)) +} + +func TestModuleFile_Good_StateAbsent(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleFileWithClient(e, mock, map[string]any{ + "path": "/tmp/old-dir", + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`rm -rf "/tmp/old-dir"`)) +} + +func TestModuleFile_Good_StateTouch(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleFileWithClient(e, mock, map[string]any{ + "path": "/var/log/app.log", + "state": "touch", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`touch "/var/log/app.log"`)) +} + +func TestModuleFile_Good_StateLink(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleFileWithClient(e, mock, map[string]any{ + "path": "/usr/local/bin/node", + "state": "link", + "src": "/opt/node/bin/node", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`ln -sf "/opt/node/bin/node" "/usr/local/bin/node"`)) +} + +func TestModuleFile_Bad_LinkMissingSrc(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + _, err := moduleFileWithClient(e, mock, map[string]any{ + "path": "/usr/local/bin/node", + "state": "link", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "src required for link state") +} + +func TestModuleFile_Good_OwnerGroupMode(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleFileWithClient(e, mock, map[string]any{ + "path": "/var/lib/app/data", + "state": "directory", + "owner": "www-data", + "group": "www-data", + "mode": "0775", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + + // Should have mkdir, chmod in the directory command, then chown and chgrp + assert.True(t, mock.hasExecuted(`mkdir -p "/var/lib/app/data" && chmod 0775 "/var/lib/app/data"`)) + assert.True(t, mock.hasExecuted(`chown www-data "/var/lib/app/data"`)) + assert.True(t, mock.hasExecuted(`chgrp www-data "/var/lib/app/data"`)) +} + +func TestModuleFile_Good_RecurseOwner(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleFileWithClient(e, mock, map[string]any{ + "path": "/var/www", + "state": "directory", + "owner": "www-data", + "recurse": true, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + + // Should have both regular chown and recursive chown + assert.True(t, mock.hasExecuted(`chown www-data "/var/www"`)) + assert.True(t, mock.hasExecuted(`chown -R www-data "/var/www"`)) +} + +func TestModuleFile_Bad_MissingPath(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + _, err := moduleFileWithClient(e, mock, map[string]any{ + "state": "directory", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "path required") +} + +func TestModuleFile_Good_DestAliasForPath(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleFileWithClient(e, mock, map[string]any{ + "dest": "/opt/myapp", + "state": "directory", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`mkdir -p "/opt/myapp"`)) +} + +func TestModuleFile_Good_StateFileWithMode(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleFileWithClient(e, mock, map[string]any{ + "path": "/etc/config.yml", + "state": "file", + "mode": "0600", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`chmod 0600 "/etc/config.yml"`)) +} + +func TestModuleFile_Good_DirectoryCommandFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("mkdir", "", "permission denied", 1) + + result, err := moduleFileWithClient(e, mock, map[string]any{ + "path": "/root/protected", + "state": "directory", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Contains(t, result.Msg, "permission denied") +} + +// --- lineinfile module --- + +func TestModuleLineinfile_Good_InsertLine(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleLineinfileWithClient(e, mock, map[string]any{ + "path": "/etc/hosts", + "line": "192.168.1.100 web01", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + + // Should use grep -qxF to check and echo to append + assert.True(t, mock.hasExecuted(`grep -qxF`)) + assert.True(t, mock.hasExecuted(`192.168.1.100 web01`)) +} + +func TestModuleLineinfile_Good_ReplaceRegexp(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleLineinfileWithClient(e, mock, map[string]any{ + "path": "/etc/ssh/sshd_config", + "regexp": "^#?PermitRootLogin", + "line": "PermitRootLogin no", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + + // Should use sed to replace + assert.True(t, mock.hasExecuted(`sed -i 's/\^#\?PermitRootLogin/PermitRootLogin no/'`)) +} + +func TestModuleLineinfile_Good_RemoveLine(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleLineinfileWithClient(e, mock, map[string]any{ + "path": "/etc/hosts", + "regexp": "^192\\.168\\.1\\.100", + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + + // Should use sed to delete matching lines + assert.True(t, mock.hasExecuted(`sed -i '/\^192`)) + assert.True(t, mock.hasExecuted(`/d'`)) +} + +func TestModuleLineinfile_Good_RegexpFallsBackToAppend(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + // Simulate sed returning non-zero (pattern not found) + mock.expectCommand("sed -i", "", "", 1) + + result, err := moduleLineinfileWithClient(e, mock, map[string]any{ + "path": "/etc/config", + "regexp": "^setting=", + "line": "setting=value", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + + // Should have attempted sed, then fallen back to echo append + cmds := mock.executedCommands() + assert.GreaterOrEqual(t, len(cmds), 2) + assert.True(t, mock.hasExecuted(`echo`)) +} + +func TestModuleLineinfile_Bad_MissingPath(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + _, err := moduleLineinfileWithClient(e, mock, map[string]any{ + "line": "test", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "path required") +} + +func TestModuleLineinfile_Good_DestAliasForPath(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleLineinfileWithClient(e, mock, map[string]any{ + "dest": "/etc/config", + "line": "key=value", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`/etc/config`)) +} + +func TestModuleLineinfile_Good_AbsentWithNoRegexp(t *testing.T) { + // When state=absent but no regexp, nothing happens (no commands) + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleLineinfileWithClient(e, mock, map[string]any{ + "path": "/etc/config", + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.Equal(t, 0, mock.commandCount()) +} + +func TestModuleLineinfile_Good_LineWithSlashes(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleLineinfileWithClient(e, mock, map[string]any{ + "path": "/etc/nginx/conf.d/default.conf", + "regexp": "^root /", + "line": "root /var/www/html;", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + + // Slashes in the line should be escaped + assert.True(t, mock.hasExecuted(`root \\/var\\/www\\/html;`)) +} + +// --- blockinfile module --- + +func TestModuleBlockinfile_Good_InsertBlock(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleBlockinfileWithClient(e, mock, map[string]any{ + "path": "/etc/nginx/conf.d/upstream.conf", + "block": "server 10.0.0.1:8080;\nserver 10.0.0.2:8080;", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + + // Should use RunScript for the heredoc approach + assert.True(t, mock.hasExecutedMethod("RunScript", "BEGIN ANSIBLE MANAGED BLOCK")) + assert.True(t, mock.hasExecutedMethod("RunScript", "END ANSIBLE MANAGED BLOCK")) + assert.True(t, mock.hasExecutedMethod("RunScript", "10\\.0\\.0\\.1")) +} + +func TestModuleBlockinfile_Good_CustomMarkers(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleBlockinfileWithClient(e, mock, map[string]any{ + "path": "/etc/hosts", + "block": "10.0.0.5 db01", + "marker": "# {mark} managed by devops", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + + // Should use custom markers instead of default + assert.True(t, mock.hasExecutedMethod("RunScript", "# BEGIN managed by devops")) + assert.True(t, mock.hasExecutedMethod("RunScript", "# END managed by devops")) +} + +func TestModuleBlockinfile_Good_RemoveBlock(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleBlockinfileWithClient(e, mock, map[string]any{ + "path": "/etc/config", + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + + // Should use sed to remove the block between markers + assert.True(t, mock.hasExecuted(`sed -i '/.*BEGIN ANSIBLE MANAGED BLOCK/,/.*END ANSIBLE MANAGED BLOCK/d'`)) +} + +func TestModuleBlockinfile_Good_CreateFile(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleBlockinfileWithClient(e, mock, map[string]any{ + "path": "/etc/new-config", + "block": "setting=value", + "create": true, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + + // Should touch the file first when create=true + assert.True(t, mock.hasExecuted(`touch "/etc/new-config"`)) +} + +func TestModuleBlockinfile_Bad_MissingPath(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + _, err := moduleBlockinfileWithClient(e, mock, map[string]any{ + "block": "content", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "path required") +} + +func TestModuleBlockinfile_Good_DestAliasForPath(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleBlockinfileWithClient(e, mock, map[string]any{ + "dest": "/etc/config", + "block": "data", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) +} + +func TestModuleBlockinfile_Good_ScriptFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand("BLOCK_EOF", "", "write error", 1) + + result, err := moduleBlockinfileWithClient(e, mock, map[string]any{ + "path": "/etc/config", + "block": "data", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Contains(t, result.Msg, "write error") +} + +// --- stat module --- + +func TestModuleStat_Good_ExistingFile(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.addStat("/etc/nginx/nginx.conf", map[string]any{ + "exists": true, + "isdir": false, + "mode": "0644", + "size": 1234, + "uid": 0, + "gid": 0, + }) + + result, err := moduleStatWithClient(e, mock, map[string]any{ + "path": "/etc/nginx/nginx.conf", + }) + + require.NoError(t, err) + assert.False(t, result.Changed) // stat never changes anything + require.NotNil(t, result.Data) + + stat, ok := result.Data["stat"].(map[string]any) + require.True(t, ok) + assert.Equal(t, true, stat["exists"]) + assert.Equal(t, false, stat["isdir"]) + assert.Equal(t, "0644", stat["mode"]) + assert.Equal(t, 1234, stat["size"]) +} + +func TestModuleStat_Good_MissingFile(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleStatWithClient(e, mock, map[string]any{ + "path": "/nonexistent/file.txt", + }) + + require.NoError(t, err) + assert.False(t, result.Changed) + require.NotNil(t, result.Data) + + stat, ok := result.Data["stat"].(map[string]any) + require.True(t, ok) + assert.Equal(t, false, stat["exists"]) +} + +func TestModuleStat_Good_Directory(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.addStat("/var/log", map[string]any{ + "exists": true, + "isdir": true, + "mode": "0755", + }) + + result, err := moduleStatWithClient(e, mock, map[string]any{ + "path": "/var/log", + }) + + require.NoError(t, err) + stat := result.Data["stat"].(map[string]any) + assert.Equal(t, true, stat["exists"]) + assert.Equal(t, true, stat["isdir"]) +} + +func TestModuleStat_Good_FallbackFromFileSystem(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + // No explicit stat, but add a file — stat falls back to file existence + mock.addFile("/etc/hosts", []byte("127.0.0.1 localhost")) + + result, err := moduleStatWithClient(e, mock, map[string]any{ + "path": "/etc/hosts", + }) + + require.NoError(t, err) + stat := result.Data["stat"].(map[string]any) + assert.Equal(t, true, stat["exists"]) + assert.Equal(t, false, stat["isdir"]) +} + +func TestModuleStat_Bad_MissingPath(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + _, err := moduleStatWithClient(e, mock, map[string]any{}) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "path required") +} + +// --- template module --- + +func TestModuleTemplate_Good_BasicTemplate(t *testing.T) { + tmpDir := t.TempDir() + srcPath := filepath.Join(tmpDir, "app.conf.j2") + require.NoError(t, os.WriteFile(srcPath, []byte("server_name={{ server_name }};"), 0644)) + + e, mock := newTestExecutorWithMock("host1") + e.SetVar("server_name", "web01.example.com") + + result, err := moduleTemplateWithClient(e, mock, map[string]any{ + "src": srcPath, + "dest": "/etc/nginx/conf.d/app.conf", + }, "host1", &Task{}) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.Contains(t, result.Msg, "templated to /etc/nginx/conf.d/app.conf") + + // Verify upload was performed with templated content + assert.Equal(t, 1, mock.uploadCount()) + up := mock.lastUpload() + require.NotNil(t, up) + assert.Equal(t, "/etc/nginx/conf.d/app.conf", up.Remote) + // Template replaces {{ var }} — the TemplateFile does Jinja2 to Go conversion + assert.Contains(t, string(up.Content), "web01.example.com") +} + +func TestModuleTemplate_Good_CustomMode(t *testing.T) { + tmpDir := t.TempDir() + srcPath := filepath.Join(tmpDir, "script.sh.j2") + require.NoError(t, os.WriteFile(srcPath, []byte("#!/bin/bash\necho done"), 0644)) + + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleTemplateWithClient(e, mock, map[string]any{ + "src": srcPath, + "dest": "/usr/local/bin/run.sh", + "mode": "0755", + }, "host1", &Task{}) + + require.NoError(t, err) + assert.True(t, result.Changed) + + up := mock.lastUpload() + require.NotNil(t, up) + assert.Equal(t, os.FileMode(0755), up.Mode) +} + +func TestModuleTemplate_Bad_MissingSrc(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + _, err := moduleTemplateWithClient(e, mock, map[string]any{ + "dest": "/tmp/out", + }, "host1", &Task{}) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "src and dest required") +} + +func TestModuleTemplate_Bad_MissingDest(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + _, err := moduleTemplateWithClient(e, mock, map[string]any{ + "src": "/tmp/in.j2", + }, "host1", &Task{}) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "src and dest required") +} + +func TestModuleTemplate_Bad_SrcFileNotFound(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + _, err := moduleTemplateWithClient(e, mock, map[string]any{ + "src": "/nonexistent/template.j2", + "dest": "/tmp/out", + }, "host1", &Task{}) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "template") +} + +func TestModuleTemplate_Good_PlainTextNoVars(t *testing.T) { + tmpDir := t.TempDir() + srcPath := filepath.Join(tmpDir, "static.conf") + content := "listen 80;\nserver_name localhost;" + require.NoError(t, os.WriteFile(srcPath, []byte(content), 0644)) + + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleTemplateWithClient(e, mock, map[string]any{ + "src": srcPath, + "dest": "/etc/config", + }, "host1", &Task{}) + + require.NoError(t, err) + assert.True(t, result.Changed) + + up := mock.lastUpload() + require.NotNil(t, up) + assert.Equal(t, content, string(up.Content)) +} + +// --- Cross-module dispatch tests for file modules --- + +func TestExecuteModuleWithMock_Good_DispatchCopy(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + task := &Task{ + Module: "copy", + Args: map[string]any{ + "content": "hello world", + "dest": "/tmp/hello.txt", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.Equal(t, 1, mock.uploadCount()) +} + +func TestExecuteModuleWithMock_Good_DispatchFile(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + task := &Task{ + Module: "file", + Args: map[string]any{ + "path": "/opt/data", + "state": "directory", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted("mkdir")) +} + +func TestExecuteModuleWithMock_Good_DispatchStat(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.addStat("/etc/hosts", map[string]any{"exists": true, "isdir": false}) + + task := &Task{ + Module: "stat", + Args: map[string]any{ + "path": "/etc/hosts", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.False(t, result.Changed) + stat := result.Data["stat"].(map[string]any) + assert.Equal(t, true, stat["exists"]) +} + +func TestExecuteModuleWithMock_Good_DispatchLineinfile(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + task := &Task{ + Module: "lineinfile", + Args: map[string]any{ + "path": "/etc/hosts", + "line": "10.0.0.1 dbhost", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) +} + +func TestExecuteModuleWithMock_Good_DispatchBlockinfile(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + task := &Task{ + Module: "blockinfile", + Args: map[string]any{ + "path": "/etc/config", + "block": "key=value", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) +} + +func TestExecuteModuleWithMock_Good_DispatchTemplate(t *testing.T) { + tmpDir := t.TempDir() + srcPath := filepath.Join(tmpDir, "test.j2") + require.NoError(t, os.WriteFile(srcPath, []byte("static content"), 0644)) + + e, mock := newTestExecutorWithMock("host1") + + task := &Task{ + Module: "template", + Args: map[string]any{ + "src": srcPath, + "dest": "/etc/out.conf", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.Equal(t, 1, mock.uploadCount()) +} + +// --- Template variable resolution integration --- + +func TestModuleCopy_Good_TemplatedArgs(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + e.SetVar("deploy_path", "/opt/myapp") + + task := &Task{ + Module: "copy", + Args: map[string]any{ + "content": "deployed", + "dest": "{{ deploy_path }}/config.yml", + }, + } + + // Template the args as the executor does + args := e.templateArgs(task.Args, "host1", task) + result, err := moduleCopyWithClient(e, mock, args, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) + + up := mock.lastUpload() + require.NotNil(t, up) + assert.Equal(t, "/opt/myapp/config.yml", up.Remote) +} + +func TestModuleFile_Good_TemplatedPath(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + e.SetVar("app_dir", "/var/www/html") + + task := &Task{ + Module: "file", + Args: map[string]any{ + "path": "{{ app_dir }}/uploads", + "state": "directory", + "owner": "www-data", + }, + } + + args := e.templateArgs(task.Args, "host1", task) + result, err := moduleFileWithClient(e, mock, args) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`mkdir -p "/var/www/html/uploads"`)) + assert.True(t, mock.hasExecuted(`chown www-data "/var/www/html/uploads"`)) +} diff --git a/modules_infra_test.go b/modules_infra_test.go new file mode 100644 index 0000000..14b6ea3 --- /dev/null +++ b/modules_infra_test.go @@ -0,0 +1,1261 @@ +package ansible + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// =========================================================================== +// 1. Error Propagation — getHosts +// =========================================================================== + +func TestGetHosts_Infra_Good_AllPattern(t *testing.T) { + e := NewExecutor("/tmp") + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "web1": {AnsibleHost: "10.0.0.1"}, + "web2": {AnsibleHost: "10.0.0.2"}, + "db1": {AnsibleHost: "10.0.1.1"}, + }, + }, + }) + + hosts := e.getHosts("all") + assert.Len(t, hosts, 3) + assert.Contains(t, hosts, "web1") + assert.Contains(t, hosts, "web2") + assert.Contains(t, hosts, "db1") +} + +func TestGetHosts_Infra_Good_SpecificHost(t *testing.T) { + e := NewExecutor("/tmp") + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "web1": {AnsibleHost: "10.0.0.1"}, + "web2": {AnsibleHost: "10.0.0.2"}, + }, + }, + }) + + hosts := e.getHosts("web1") + assert.Equal(t, []string{"web1"}, hosts) +} + +func TestGetHosts_Infra_Good_GroupName(t *testing.T) { + e := NewExecutor("/tmp") + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Children: map[string]*InventoryGroup{ + "webservers": { + Hosts: map[string]*Host{ + "web1": {AnsibleHost: "10.0.0.1"}, + "web2": {AnsibleHost: "10.0.0.2"}, + }, + }, + "dbservers": { + Hosts: map[string]*Host{ + "db1": {AnsibleHost: "10.0.1.1"}, + }, + }, + }, + }, + }) + + hosts := e.getHosts("webservers") + assert.Len(t, hosts, 2) + assert.Contains(t, hosts, "web1") + assert.Contains(t, hosts, "web2") +} + +func TestGetHosts_Infra_Good_Localhost(t *testing.T) { + e := NewExecutor("/tmp") + // No inventory at all + hosts := e.getHosts("localhost") + assert.Equal(t, []string{"localhost"}, hosts) +} + +func TestGetHosts_Infra_Bad_NilInventory(t *testing.T) { + e := NewExecutor("/tmp") + // inventory is nil, non-localhost pattern + hosts := e.getHosts("webservers") + assert.Nil(t, hosts) +} + +func TestGetHosts_Infra_Bad_NonexistentHost(t *testing.T) { + e := NewExecutor("/tmp") + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "web1": {}, + }, + }, + }) + + hosts := e.getHosts("nonexistent") + assert.Empty(t, hosts) +} + +func TestGetHosts_Infra_Good_LimitFiltering(t *testing.T) { + e := NewExecutor("/tmp") + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "web1": {}, + "web2": {}, + "db1": {}, + }, + }, + }) + e.Limit = "web1" + + hosts := e.getHosts("all") + assert.Len(t, hosts, 1) + assert.Contains(t, hosts, "web1") +} + +func TestGetHosts_Infra_Good_LimitSubstringMatch(t *testing.T) { + e := NewExecutor("/tmp") + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "prod-web-01": {}, + "prod-web-02": {}, + "staging-web-01": {}, + }, + }, + }) + e.Limit = "prod" + + hosts := e.getHosts("all") + // Limit uses substring matching as fallback + assert.Len(t, hosts, 2) +} + +// =========================================================================== +// 1. Error Propagation — matchesTags +// =========================================================================== + +func TestMatchesTags_Infra_Good_NoFiltersNoTags(t *testing.T) { + e := NewExecutor("/tmp") + // No Tags, no SkipTags set + assert.True(t, e.matchesTags(nil)) +} + +func TestMatchesTags_Infra_Good_NoFiltersWithTaskTags(t *testing.T) { + e := NewExecutor("/tmp") + assert.True(t, e.matchesTags([]string{"deploy", "config"})) +} + +func TestMatchesTags_Infra_Good_IncludeMatchesOneOfMultiple(t *testing.T) { + e := NewExecutor("/tmp") + e.Tags = []string{"deploy"} + + // Task has deploy among its tags + assert.True(t, e.matchesTags([]string{"setup", "deploy", "config"})) +} + +func TestMatchesTags_Infra_Bad_IncludeNoMatch(t *testing.T) { + e := NewExecutor("/tmp") + e.Tags = []string{"deploy"} + + assert.False(t, e.matchesTags([]string{"build", "test"})) +} + +func TestMatchesTags_Infra_Good_SkipOverridesInclude(t *testing.T) { + e := NewExecutor("/tmp") + e.SkipTags = []string{"slow"} + + // Even with no include tags, skip tags filter out matching tasks + assert.False(t, e.matchesTags([]string{"deploy", "slow"})) + assert.True(t, e.matchesTags([]string{"deploy", "fast"})) +} + +func TestMatchesTags_Infra_Bad_IncludeFilterNoTaskTags(t *testing.T) { + e := NewExecutor("/tmp") + e.Tags = []string{"deploy"} + + // Tasks with no tags should not run when include tags are active + assert.False(t, e.matchesTags(nil)) + assert.False(t, e.matchesTags([]string{})) +} + +func TestMatchesTags_Infra_Good_AllTagMatchesEverything(t *testing.T) { + e := NewExecutor("/tmp") + e.Tags = []string{"all"} + + assert.True(t, e.matchesTags([]string{"deploy"})) + assert.True(t, e.matchesTags([]string{"config", "slow"})) +} + +// =========================================================================== +// 1. Error Propagation — evaluateWhen +// =========================================================================== + +func TestEvaluateWhen_Infra_Good_DefinedCheck(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "myresult": {Changed: true}, + } + + assert.True(t, e.evaluateWhen("myresult is defined", "host1", nil)) +} + +func TestEvaluateWhen_Infra_Good_NotDefinedCheck(t *testing.T) { + e := NewExecutor("/tmp") + // No results registered for host1 + assert.True(t, e.evaluateWhen("missing_var is not defined", "host1", nil)) +} + +func TestEvaluateWhen_Infra_Good_UndefinedAlias(t *testing.T) { + e := NewExecutor("/tmp") + assert.True(t, e.evaluateWhen("some_var is undefined", "host1", nil)) +} + +func TestEvaluateWhen_Infra_Good_SucceededCheck(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "result": {Failed: false, Changed: true}, + } + + assert.True(t, e.evaluateWhen("result is success", "host1", nil)) + assert.True(t, e.evaluateWhen("result is succeeded", "host1", nil)) +} + +func TestEvaluateWhen_Infra_Good_FailedCheck(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "result": {Failed: true}, + } + + assert.True(t, e.evaluateWhen("result is failed", "host1", nil)) +} + +func TestEvaluateWhen_Infra_Good_ChangedCheck(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "result": {Changed: true}, + } + + assert.True(t, e.evaluateWhen("result is changed", "host1", nil)) +} + +func TestEvaluateWhen_Infra_Good_SkippedCheck(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "result": {Skipped: true}, + } + + assert.True(t, e.evaluateWhen("result is skipped", "host1", nil)) +} + +func TestEvaluateWhen_Infra_Good_BoolVarTruthy(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["my_flag"] = true + + assert.True(t, e.evalCondition("my_flag", "host1")) +} + +func TestEvaluateWhen_Infra_Good_BoolVarFalsy(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["my_flag"] = false + + assert.False(t, e.evalCondition("my_flag", "host1")) +} + +func TestEvaluateWhen_Infra_Good_StringVarTruthy(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["my_str"] = "hello" + assert.True(t, e.evalCondition("my_str", "host1")) +} + +func TestEvaluateWhen_Infra_Good_StringVarEmptyFalsy(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["my_str"] = "" + assert.False(t, e.evalCondition("my_str", "host1")) +} + +func TestEvaluateWhen_Infra_Good_StringVarFalseLiteral(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["my_str"] = "false" + assert.False(t, e.evalCondition("my_str", "host1")) + + e.vars["my_str2"] = "False" + assert.False(t, e.evalCondition("my_str2", "host1")) +} + +func TestEvaluateWhen_Infra_Good_IntVarNonZero(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["count"] = 42 + assert.True(t, e.evalCondition("count", "host1")) +} + +func TestEvaluateWhen_Infra_Good_IntVarZero(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["count"] = 0 + assert.False(t, e.evalCondition("count", "host1")) +} + +func TestEvaluateWhen_Infra_Good_Negation(t *testing.T) { + e := NewExecutor("/tmp") + assert.False(t, e.evalCondition("not true", "host1")) + assert.True(t, e.evalCondition("not false", "host1")) +} + +func TestEvaluateWhen_Infra_Good_MultipleConditionsAllTrue(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["enabled"] = true + e.results["host1"] = map[string]*TaskResult{ + "prev": {Failed: false}, + } + + // Both conditions must be true (AND semantics) + assert.True(t, e.evaluateWhen([]any{"enabled", "prev is success"}, "host1", nil)) +} + +func TestEvaluateWhen_Infra_Bad_MultipleConditionsOneFails(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["enabled"] = true + + // "false" literal fails + assert.False(t, e.evaluateWhen([]any{"enabled", "false"}, "host1", nil)) +} + +func TestEvaluateWhen_Infra_Good_DefaultFilterInCondition(t *testing.T) { + e := NewExecutor("/tmp") + // Condition with default filter should be satisfied + assert.True(t, e.evalCondition("my_var | default(true)", "host1")) +} + +func TestEvaluateWhen_Infra_Good_RegisteredVarTruthy(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "check_result": {Failed: false, Skipped: false}, + } + + // Just referencing a registered var name evaluates truthy if not failed/skipped + assert.True(t, e.evalCondition("check_result", "host1")) +} + +func TestEvaluateWhen_Infra_Bad_RegisteredVarFailedFalsy(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "check_result": {Failed: true}, + } + + // A failed registered var should be falsy + assert.False(t, e.evalCondition("check_result", "host1")) +} + +// =========================================================================== +// 1. Error Propagation — templateString +// =========================================================================== + +func TestTemplateString_Infra_Good_SimpleSubstitution(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["app_name"] = "myapp" + + result := e.templateString("Deploying {{ app_name }}", "", nil) + assert.Equal(t, "Deploying myapp", result) +} + +func TestTemplateString_Infra_Good_MultipleVars(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["host"] = "db.example.com" + e.vars["port"] = 5432 + + result := e.templateString("postgresql://{{ host }}:{{ port }}/mydb", "", nil) + assert.Equal(t, "postgresql://db.example.com:5432/mydb", result) +} + +func TestTemplateString_Infra_Good_Unresolved(t *testing.T) { + e := NewExecutor("/tmp") + result := e.templateString("{{ missing_var }}", "", nil) + assert.Equal(t, "{{ missing_var }}", result) +} + +func TestTemplateString_Infra_Good_NoTemplateMarkup(t *testing.T) { + e := NewExecutor("/tmp") + result := e.templateString("just a plain string", "", nil) + assert.Equal(t, "just a plain string", result) +} + +func TestTemplateString_Infra_Good_RegisteredVarStdout(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "cmd_result": {Stdout: "42"}, + } + + result := e.templateString("{{ cmd_result.stdout }}", "host1", nil) + assert.Equal(t, "42", result) +} + +func TestTemplateString_Infra_Good_RegisteredVarRC(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "cmd_result": {RC: 0}, + } + + result := e.templateString("{{ cmd_result.rc }}", "host1", nil) + assert.Equal(t, "0", result) +} + +func TestTemplateString_Infra_Good_RegisteredVarChanged(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "cmd_result": {Changed: true}, + } + + result := e.templateString("{{ cmd_result.changed }}", "host1", nil) + assert.Equal(t, "true", result) +} + +func TestTemplateString_Infra_Good_RegisteredVarFailed(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "cmd_result": {Failed: true}, + } + + result := e.templateString("{{ cmd_result.failed }}", "host1", nil) + assert.Equal(t, "true", result) +} + +func TestTemplateString_Infra_Good_TaskVars(t *testing.T) { + e := NewExecutor("/tmp") + task := &Task{ + Vars: map[string]any{ + "task_var": "task_value", + }, + } + + result := e.templateString("{{ task_var }}", "host1", task) + assert.Equal(t, "task_value", result) +} + +func TestTemplateString_Infra_Good_FactsResolution(t *testing.T) { + e := NewExecutor("/tmp") + e.facts["host1"] = &Facts{ + Hostname: "web1", + FQDN: "web1.example.com", + Distribution: "ubuntu", + Version: "24.04", + Architecture: "x86_64", + Kernel: "6.5.0", + } + + assert.Equal(t, "web1", e.templateString("{{ ansible_hostname }}", "host1", nil)) + assert.Equal(t, "web1.example.com", e.templateString("{{ ansible_fqdn }}", "host1", nil)) + assert.Equal(t, "ubuntu", e.templateString("{{ ansible_distribution }}", "host1", nil)) + assert.Equal(t, "24.04", e.templateString("{{ ansible_distribution_version }}", "host1", nil)) + assert.Equal(t, "x86_64", e.templateString("{{ ansible_architecture }}", "host1", nil)) + assert.Equal(t, "6.5.0", e.templateString("{{ ansible_kernel }}", "host1", nil)) +} + +// =========================================================================== +// 1. Error Propagation — applyFilter +// =========================================================================== + +func TestApplyFilter_Infra_Good_DefaultWithValue(t *testing.T) { + e := NewExecutor("/tmp") + // When value is non-empty, default is not applied + assert.Equal(t, "hello", e.applyFilter("hello", "default('fallback')")) +} + +func TestApplyFilter_Infra_Good_DefaultWithEmpty(t *testing.T) { + e := NewExecutor("/tmp") + // When value is empty, default IS applied + assert.Equal(t, "fallback", e.applyFilter("", "default('fallback')")) +} + +func TestApplyFilter_Infra_Good_DefaultWithDoubleQuotes(t *testing.T) { + e := NewExecutor("/tmp") + assert.Equal(t, "fallback", e.applyFilter("", `default("fallback")`)) +} + +func TestApplyFilter_Infra_Good_BoolFilterTrue(t *testing.T) { + e := NewExecutor("/tmp") + assert.Equal(t, "true", e.applyFilter("true", "bool")) + assert.Equal(t, "true", e.applyFilter("True", "bool")) + assert.Equal(t, "true", e.applyFilter("yes", "bool")) + assert.Equal(t, "true", e.applyFilter("Yes", "bool")) + assert.Equal(t, "true", e.applyFilter("1", "bool")) +} + +func TestApplyFilter_Infra_Good_BoolFilterFalse(t *testing.T) { + e := NewExecutor("/tmp") + assert.Equal(t, "false", e.applyFilter("false", "bool")) + assert.Equal(t, "false", e.applyFilter("no", "bool")) + assert.Equal(t, "false", e.applyFilter("0", "bool")) + assert.Equal(t, "false", e.applyFilter("random", "bool")) +} + +func TestApplyFilter_Infra_Good_TrimFilter(t *testing.T) { + e := NewExecutor("/tmp") + assert.Equal(t, "hello", e.applyFilter(" hello ", "trim")) + assert.Equal(t, "no spaces", e.applyFilter("no spaces", "trim")) + assert.Equal(t, "", e.applyFilter(" ", "trim")) +} + +func TestApplyFilter_Infra_Good_B64Decode(t *testing.T) { + e := NewExecutor("/tmp") + // b64decode currently returns value unchanged (placeholder) + assert.Equal(t, "dGVzdA==", e.applyFilter("dGVzdA==", "b64decode")) +} + +func TestApplyFilter_Infra_Good_UnknownFilter(t *testing.T) { + e := NewExecutor("/tmp") + // Unknown filters return value unchanged + assert.Equal(t, "hello", e.applyFilter("hello", "nonexistent_filter")) +} + +func TestTemplateString_Infra_Good_FilterInTemplate(t *testing.T) { + e := NewExecutor("/tmp") + // When a var is defined, the filter passes through + e.vars["defined_var"] = "hello" + result := e.templateString("{{ defined_var | default('fallback') }}", "", nil) + assert.Equal(t, "hello", result) +} + +func TestTemplateString_Infra_Good_DefaultFilterEmptyVar(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["empty_var"] = "" + // When var is empty string, default filter applies + result := e.applyFilter("", "default('fallback')") + assert.Equal(t, "fallback", result) +} + +func TestTemplateString_Infra_Good_BoolFilterInTemplate(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["flag"] = "yes" + + result := e.templateString("{{ flag | bool }}", "", nil) + assert.Equal(t, "true", result) +} + +func TestTemplateString_Infra_Good_TrimFilterInTemplate(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["padded"] = " trimmed " + + result := e.templateString("{{ padded | trim }}", "", nil) + assert.Equal(t, "trimmed", result) +} + +// =========================================================================== +// 1. Error Propagation — resolveLoop +// =========================================================================== + +func TestResolveLoop_Infra_Good_SliceAny(t *testing.T) { + e := NewExecutor("/tmp") + items := e.resolveLoop([]any{"a", "b", "c"}, "host1") + assert.Len(t, items, 3) + assert.Equal(t, "a", items[0]) + assert.Equal(t, "b", items[1]) + assert.Equal(t, "c", items[2]) +} + +func TestResolveLoop_Infra_Good_SliceString(t *testing.T) { + e := NewExecutor("/tmp") + items := e.resolveLoop([]string{"x", "y"}, "host1") + assert.Len(t, items, 2) + assert.Equal(t, "x", items[0]) + assert.Equal(t, "y", items[1]) +} + +func TestResolveLoop_Infra_Good_NilLoop(t *testing.T) { + e := NewExecutor("/tmp") + items := e.resolveLoop(nil, "host1") + assert.Nil(t, items) +} + +func TestResolveLoop_Infra_Good_VarReference(t *testing.T) { + e := NewExecutor("/tmp") + e.vars["my_list"] = []any{"item1", "item2", "item3"} + + // When loop is a string that resolves to a variable containing a list + items := e.resolveLoop("{{ my_list }}", "host1") + // The template resolves "{{ my_list }}" but the result is a string representation, + // not the original list. The resolveLoop handles this by trying to look up + // the resolved value in vars again. + // Since templateString returns the string "[item1 item2 item3]", and that + // isn't a var name, items will be nil. This tests the edge case. + // The actual var name lookup happens when the loop value is just "my_list". + assert.Nil(t, items) +} + +func TestResolveLoop_Infra_Good_MixedTypes(t *testing.T) { + e := NewExecutor("/tmp") + items := e.resolveLoop([]any{"str", 42, true, map[string]any{"key": "val"}}, "host1") + assert.Len(t, items, 4) + assert.Equal(t, "str", items[0]) + assert.Equal(t, 42, items[1]) + assert.Equal(t, true, items[2]) +} + +// =========================================================================== +// 1. Error Propagation — handleNotify +// =========================================================================== + +func TestHandleNotify_Infra_Good_SingleString(t *testing.T) { + e := NewExecutor("/tmp") + e.handleNotify("restart nginx") + + assert.True(t, e.notified["restart nginx"]) + assert.False(t, e.notified["restart apache"]) +} + +func TestHandleNotify_Infra_Good_StringSlice(t *testing.T) { + e := NewExecutor("/tmp") + e.handleNotify([]string{"restart nginx", "reload haproxy"}) + + assert.True(t, e.notified["restart nginx"]) + assert.True(t, e.notified["reload haproxy"]) +} + +func TestHandleNotify_Infra_Good_AnySlice(t *testing.T) { + e := NewExecutor("/tmp") + e.handleNotify([]any{"handler1", "handler2", "handler3"}) + + assert.True(t, e.notified["handler1"]) + assert.True(t, e.notified["handler2"]) + assert.True(t, e.notified["handler3"]) +} + +func TestHandleNotify_Infra_Good_NilNotify(t *testing.T) { + e := NewExecutor("/tmp") + // Should not panic + e.handleNotify(nil) + assert.Empty(t, e.notified) +} + +func TestHandleNotify_Infra_Good_MultipleCallsAccumulate(t *testing.T) { + e := NewExecutor("/tmp") + e.handleNotify("handler1") + e.handleNotify("handler2") + + assert.True(t, e.notified["handler1"]) + assert.True(t, e.notified["handler2"]) +} + +// =========================================================================== +// 1. Error Propagation — normalizeConditions +// =========================================================================== + +func TestNormalizeConditions_Infra_Good_String(t *testing.T) { + result := normalizeConditions("my_var is defined") + assert.Equal(t, []string{"my_var is defined"}, result) +} + +func TestNormalizeConditions_Infra_Good_StringSlice(t *testing.T) { + result := normalizeConditions([]string{"cond1", "cond2"}) + assert.Equal(t, []string{"cond1", "cond2"}, result) +} + +func TestNormalizeConditions_Infra_Good_AnySlice(t *testing.T) { + result := normalizeConditions([]any{"cond1", "cond2"}) + assert.Equal(t, []string{"cond1", "cond2"}, result) +} + +func TestNormalizeConditions_Infra_Good_Nil(t *testing.T) { + result := normalizeConditions(nil) + assert.Nil(t, result) +} + +func TestNormalizeConditions_Infra_Good_IntIgnored(t *testing.T) { + // Non-string types in any slice are silently skipped + result := normalizeConditions([]any{"cond1", 42}) + assert.Equal(t, []string{"cond1"}, result) +} + +func TestNormalizeConditions_Infra_Good_UnsupportedType(t *testing.T) { + result := normalizeConditions(42) + assert.Nil(t, result) +} + +// =========================================================================== +// 2. Become/Sudo +// =========================================================================== + +func TestBecome_Infra_Good_SetBecomeTrue(t *testing.T) { + cfg := SSHConfig{ + Host: "test-host", + Port: 22, + User: "deploy", + Become: true, + BecomeUser: "root", + BecomePass: "secret", + } + client, err := NewSSHClient(cfg) + require.NoError(t, err) + + assert.True(t, client.become) + assert.Equal(t, "root", client.becomeUser) + assert.Equal(t, "secret", client.becomePass) +} + +func TestBecome_Infra_Good_SetBecomeFalse(t *testing.T) { + cfg := SSHConfig{ + Host: "test-host", + Port: 22, + User: "deploy", + } + client, err := NewSSHClient(cfg) + require.NoError(t, err) + + assert.False(t, client.become) + assert.Empty(t, client.becomeUser) + assert.Empty(t, client.becomePass) +} + +func TestBecome_Infra_Good_SetBecomeMethod(t *testing.T) { + cfg := SSHConfig{Host: "test-host"} + client, err := NewSSHClient(cfg) + require.NoError(t, err) + + assert.False(t, client.become) + + client.SetBecome(true, "admin", "pass123") + assert.True(t, client.become) + assert.Equal(t, "admin", client.becomeUser) + assert.Equal(t, "pass123", client.becomePass) +} + +func TestBecome_Infra_Good_DisableAfterEnable(t *testing.T) { + cfg := SSHConfig{Host: "test-host"} + client, err := NewSSHClient(cfg) + require.NoError(t, err) + + client.SetBecome(true, "root", "secret") + assert.True(t, client.become) + + client.SetBecome(false, "", "") + assert.False(t, client.become) + // becomeUser and becomePass are only updated if non-empty + assert.Equal(t, "root", client.becomeUser) + assert.Equal(t, "secret", client.becomePass) +} + +func TestBecome_Infra_Good_MockBecomeTracking(t *testing.T) { + mock := NewMockSSHClient() + assert.False(t, mock.become) + + mock.SetBecome(true, "root", "password") + assert.True(t, mock.become) + assert.Equal(t, "root", mock.becomeUser) + assert.Equal(t, "password", mock.becomePass) +} + +func TestBecome_Infra_Good_DefaultBecomeUserRoot(t *testing.T) { + // When become is true but no user specified, it defaults to root in the Run method + cfg := SSHConfig{ + Host: "test-host", + Become: true, + // BecomeUser not set + } + client, err := NewSSHClient(cfg) + require.NoError(t, err) + + assert.True(t, client.become) + assert.Empty(t, client.becomeUser) // Empty in config... + // The Run() method defaults to "root" when becomeUser is empty +} + +func TestBecome_Infra_Good_PasswordlessBecome(t *testing.T) { + cfg := SSHConfig{ + Host: "test-host", + Become: true, + BecomeUser: "root", + // No BecomePass and no Password — triggers sudo -n + } + client, err := NewSSHClient(cfg) + require.NoError(t, err) + + assert.True(t, client.become) + assert.Empty(t, client.becomePass) + assert.Empty(t, client.password) +} + +func TestBecome_Infra_Good_ExecutorPlayBecome(t *testing.T) { + // Test that getClient applies play-level become settings + e := NewExecutor("/tmp") + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "host1": {AnsibleHost: "127.0.0.1"}, + }, + }, + }) + + play := &Play{ + Become: true, + BecomeUser: "admin", + } + + // getClient will attempt SSH connection which will fail, + // but we can verify the config would be set correctly + // by checking the SSHConfig construction logic. + // Since getClient creates real connections, we just verify + // that the become fields are set on the play. + assert.True(t, play.Become) + assert.Equal(t, "admin", play.BecomeUser) +} + +// =========================================================================== +// 3. Fact Gathering +// =========================================================================== + +func TestFacts_Infra_Good_UbuntuParsing(t *testing.T) { + e, mock := newTestExecutorWithMock("web1") + + // Mock os-release output for Ubuntu + mock.expectCommand(`hostname -f`, "web1.example.com\n", "", 0) + mock.expectCommand(`hostname -s`, "web1\n", "", 0) + mock.expectCommand(`cat /etc/os-release`, "ID=ubuntu\nVERSION_ID=\"24.04\"\n", "", 0) + mock.expectCommand(`uname -m`, "x86_64\n", "", 0) + mock.expectCommand(`uname -r`, "6.5.0-44-generic\n", "", 0) + + // Simulate fact gathering by directly populating facts + // using the same parsing logic as gatherFacts + facts := &Facts{} + + stdout, _, _, _ := mock.Run(nil, "hostname -f 2>/dev/null || hostname") + facts.FQDN = trimSpace(stdout) + + stdout, _, _, _ = mock.Run(nil, "hostname -s 2>/dev/null || hostname") + facts.Hostname = trimSpace(stdout) + + stdout, _, _, _ = mock.Run(nil, "cat /etc/os-release 2>/dev/null | grep -E '^(ID|VERSION_ID)=' | head -2") + for _, line := range splitLines(stdout) { + if hasPrefix(line, "ID=") { + facts.Distribution = trimQuotes(trimPrefix(line, "ID=")) + } + if hasPrefix(line, "VERSION_ID=") { + facts.Version = trimQuotes(trimPrefix(line, "VERSION_ID=")) + } + } + + stdout, _, _, _ = mock.Run(nil, "uname -m") + facts.Architecture = trimSpace(stdout) + + stdout, _, _, _ = mock.Run(nil, "uname -r") + facts.Kernel = trimSpace(stdout) + + e.facts["web1"] = facts + + assert.Equal(t, "web1.example.com", facts.FQDN) + assert.Equal(t, "web1", facts.Hostname) + assert.Equal(t, "ubuntu", facts.Distribution) + assert.Equal(t, "24.04", facts.Version) + assert.Equal(t, "x86_64", facts.Architecture) + assert.Equal(t, "6.5.0-44-generic", facts.Kernel) + + // Now verify template resolution with these facts + result := e.templateString("{{ ansible_hostname }}", "web1", nil) + assert.Equal(t, "web1", result) + + result = e.templateString("{{ ansible_distribution }}", "web1", nil) + assert.Equal(t, "ubuntu", result) +} + +func TestFacts_Infra_Good_CentOSParsing(t *testing.T) { + facts := &Facts{} + + osRelease := "ID=centos\nVERSION_ID=\"8\"\n" + for _, line := range splitLines(osRelease) { + if hasPrefix(line, "ID=") { + facts.Distribution = trimQuotes(trimPrefix(line, "ID=")) + } + if hasPrefix(line, "VERSION_ID=") { + facts.Version = trimQuotes(trimPrefix(line, "VERSION_ID=")) + } + } + + assert.Equal(t, "centos", facts.Distribution) + assert.Equal(t, "8", facts.Version) +} + +func TestFacts_Infra_Good_AlpineParsing(t *testing.T) { + facts := &Facts{} + + osRelease := "ID=alpine\nVERSION_ID=3.19.1\n" + for _, line := range splitLines(osRelease) { + if hasPrefix(line, "ID=") { + facts.Distribution = trimQuotes(trimPrefix(line, "ID=")) + } + if hasPrefix(line, "VERSION_ID=") { + facts.Version = trimQuotes(trimPrefix(line, "VERSION_ID=")) + } + } + + assert.Equal(t, "alpine", facts.Distribution) + assert.Equal(t, "3.19.1", facts.Version) +} + +func TestFacts_Infra_Good_DebianParsing(t *testing.T) { + facts := &Facts{} + + osRelease := "ID=debian\nVERSION_ID=\"12\"\n" + for _, line := range splitLines(osRelease) { + if hasPrefix(line, "ID=") { + facts.Distribution = trimQuotes(trimPrefix(line, "ID=")) + } + if hasPrefix(line, "VERSION_ID=") { + facts.Version = trimQuotes(trimPrefix(line, "VERSION_ID=")) + } + } + + assert.Equal(t, "debian", facts.Distribution) + assert.Equal(t, "12", facts.Version) +} + +func TestFacts_Infra_Good_HostnameFromCommand(t *testing.T) { + e := NewExecutor("/tmp") + e.facts["host1"] = &Facts{ + Hostname: "myserver", + FQDN: "myserver.example.com", + } + + assert.Equal(t, "myserver", e.templateString("{{ ansible_hostname }}", "host1", nil)) + assert.Equal(t, "myserver.example.com", e.templateString("{{ ansible_fqdn }}", "host1", nil)) +} + +func TestFacts_Infra_Good_ArchitectureResolution(t *testing.T) { + e := NewExecutor("/tmp") + e.facts["host1"] = &Facts{ + Architecture: "aarch64", + } + + result := e.templateString("{{ ansible_architecture }}", "host1", nil) + assert.Equal(t, "aarch64", result) +} + +func TestFacts_Infra_Good_KernelResolution(t *testing.T) { + e := NewExecutor("/tmp") + e.facts["host1"] = &Facts{ + Kernel: "5.15.0-91-generic", + } + + result := e.templateString("{{ ansible_kernel }}", "host1", nil) + assert.Equal(t, "5.15.0-91-generic", result) +} + +func TestFacts_Infra_Good_NoFactsForHost(t *testing.T) { + e := NewExecutor("/tmp") + // No facts gathered for host1 + result := e.templateString("{{ ansible_hostname }}", "host1", nil) + // Should remain unresolved + assert.Equal(t, "{{ ansible_hostname }}", result) +} + +func TestFacts_Infra_Good_LocalhostFacts(t *testing.T) { + // When connection is local, gatherFacts sets minimal facts + e := NewExecutor("/tmp") + e.facts["localhost"] = &Facts{ + Hostname: "localhost", + } + + result := e.templateString("{{ ansible_hostname }}", "localhost", nil) + assert.Equal(t, "localhost", result) +} + +// =========================================================================== +// 4. Idempotency +// =========================================================================== + +func TestIdempotency_Infra_Good_GroupAlreadyExists(t *testing.T) { + _, mock := newTestExecutorWithMock("host1") + + // Mock: getent group docker succeeds (group exists) — the || means groupadd is skipped + mock.expectCommand(`getent group docker`, "docker:x:999:\n", "", 0) + + task := &Task{ + Module: "group", + Args: map[string]any{ + "name": "docker", + "state": "present", + }, + } + + result, err := moduleGroupWithClient(nil, mock, task.Args) + require.NoError(t, err) + + // The module runs the command: getent group docker >/dev/null 2>&1 || groupadd docker + // Since getent succeeds (rc=0), groupadd is not executed by the shell. + // However, the module always reports changed=true because it does not + // check idempotency at the Go level. This tests the current behaviour. + assert.True(t, result.Changed) + assert.False(t, result.Failed) +} + +func TestIdempotency_Infra_Good_AuthorizedKeyAlreadyPresent(t *testing.T) { + _, mock := newTestExecutorWithMock("host1") + + testKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC7xfG..." + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA user@host" + + // Mock: getent passwd returns home dir + mock.expectCommand(`getent passwd deploy`, "/home/deploy\n", "", 0) + + // Mock: mkdir + chmod + chown for .ssh dir + mock.expectCommand(`mkdir -p`, "", "", 0) + + // Mock: grep finds the key (rc=0, key is present) + mock.expectCommand(`grep -qF`, "", "", 0) + + // Mock: chmod + chown for authorized_keys + mock.expectCommand(`chmod 600`, "", "", 0) + + result, err := moduleAuthorizedKeyWithClient(nil, mock, map[string]any{ + "user": "deploy", + "key": testKey, + "state": "present", + }) + require.NoError(t, err) + + // Module reports changed=true regardless (it doesn't check grep result at Go level) + // The grep || echo construct handles idempotency at the shell level + assert.NotNil(t, result) + assert.False(t, result.Failed) +} + +func TestIdempotency_Infra_Good_DockerComposeUpToDate(t *testing.T) { + _, mock := newTestExecutorWithMock("host1") + + // Mock: docker compose up -d returns "Up to date" in stdout + mock.expectCommand(`docker compose up -d`, "web1 Up to date\nnginx Up to date\n", "", 0) + + result, err := moduleDockerComposeWithClient(nil, mock, map[string]any{ + "project_src": "/opt/myapp", + "state": "present", + }) + require.NoError(t, err) + require.NotNil(t, result) + + // When stdout contains "Up to date", changed should be false + assert.False(t, result.Changed) + assert.False(t, result.Failed) +} + +func TestIdempotency_Infra_Good_DockerComposeChanged(t *testing.T) { + _, mock := newTestExecutorWithMock("host1") + + // Mock: docker compose up -d with actual changes + mock.expectCommand(`docker compose up -d`, "Creating web1 ... done\nCreating nginx ... done\n", "", 0) + + result, err := moduleDockerComposeWithClient(nil, mock, map[string]any{ + "project_src": "/opt/myapp", + "state": "present", + }) + require.NoError(t, err) + require.NotNil(t, result) + + // When stdout does NOT contain "Up to date", changed should be true + assert.True(t, result.Changed) + assert.False(t, result.Failed) +} + +func TestIdempotency_Infra_Good_DockerComposeUpToDateInStderr(t *testing.T) { + _, mock := newTestExecutorWithMock("host1") + + // Some versions of docker compose output status to stderr + mock.expectCommand(`docker compose up -d`, "", "web1 Up to date\n", 0) + + result, err := moduleDockerComposeWithClient(nil, mock, map[string]any{ + "project_src": "/opt/myapp", + "state": "present", + }) + require.NoError(t, err) + require.NotNil(t, result) + + // The docker compose module checks both stdout and stderr for "Up to date" + assert.False(t, result.Changed) +} + +func TestIdempotency_Infra_Good_GroupCreationWhenNew(t *testing.T) { + _, mock := newTestExecutorWithMock("host1") + + // Mock: getent fails (group does not exist), groupadd succeeds + mock.expectCommand(`getent group newgroup`, "", "no such group", 2) + // The overall command runs in shell: getent group newgroup >/dev/null 2>&1 || groupadd newgroup + // Since we match on the full command, the mock will return rc=0 default + mock.expectCommand(`getent group newgroup .* groupadd`, "", "", 0) + + result, err := moduleGroupWithClient(nil, mock, map[string]any{ + "name": "newgroup", + "state": "present", + }) + require.NoError(t, err) + + assert.True(t, result.Changed) + assert.False(t, result.Failed) +} + +func TestIdempotency_Infra_Good_ServiceStatChanged(t *testing.T) { + _, mock := newTestExecutorWithMock("host1") + + // Mock: stat reports the file exists + mock.addStat("/etc/config.conf", map[string]any{"exists": true, "isdir": false}) + + result, err := moduleStatWithClient(nil, mock, map[string]any{ + "path": "/etc/config.conf", + }) + require.NoError(t, err) + + // Stat module should always report changed=false + assert.False(t, result.Changed) + assert.NotNil(t, result.Data) + stat := result.Data["stat"].(map[string]any) + assert.True(t, stat["exists"].(bool)) +} + +func TestIdempotency_Infra_Good_StatFileNotFound(t *testing.T) { + _, mock := newTestExecutorWithMock("host1") + + // No stat info added — will return exists=false from mock + result, err := moduleStatWithClient(nil, mock, map[string]any{ + "path": "/nonexistent/file", + }) + require.NoError(t, err) + + assert.False(t, result.Changed) + assert.NotNil(t, result.Data) + stat := result.Data["stat"].(map[string]any) + assert.False(t, stat["exists"].(bool)) +} + +// =========================================================================== +// Additional cross-cutting edge cases +// =========================================================================== + +func TestResolveExpr_Infra_Good_HostVars(t *testing.T) { + e := NewExecutor("/tmp") + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "host1": { + AnsibleHost: "10.0.0.1", + Vars: map[string]any{ + "custom_var": "custom_value", + }, + }, + }, + }, + }) + + result := e.templateString("{{ custom_var }}", "host1", nil) + assert.Equal(t, "custom_value", result) +} + +func TestTemplateArgs_Infra_Good_InventoryHostname(t *testing.T) { + e := NewExecutor("/tmp") + + args := map[string]any{ + "hostname": "{{ inventory_hostname }}", + } + + result := e.templateArgs(args, "web1", nil) + assert.Equal(t, "web1", result["hostname"]) +} + +func TestEvalCondition_Infra_Good_UnknownDefaultsTrue(t *testing.T) { + e := NewExecutor("/tmp") + // Unknown conditions default to true (permissive) + assert.True(t, e.evalCondition("some_complex_expression == 'value'", "host1")) +} + +func TestGetRegisteredVar_Infra_Good_DottedAccess(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "my_cmd": {Stdout: "output_text", RC: 0}, + } + + // getRegisteredVar parses dotted names + result := e.getRegisteredVar("host1", "my_cmd.stdout") + // getRegisteredVar only looks up the base name (before the dot) + assert.NotNil(t, result) + assert.Equal(t, "output_text", result.Stdout) +} + +func TestGetRegisteredVar_Infra_Bad_NotRegistered(t *testing.T) { + e := NewExecutor("/tmp") + + result := e.getRegisteredVar("host1", "nonexistent") + assert.Nil(t, result) +} + +func TestGetRegisteredVar_Infra_Bad_WrongHost(t *testing.T) { + e := NewExecutor("/tmp") + e.results["host1"] = map[string]*TaskResult{ + "my_cmd": {Stdout: "output"}, + } + + // Different host has no results + result := e.getRegisteredVar("host2", "my_cmd") + assert.Nil(t, result) +} + +// =========================================================================== +// String helper utilities used by fact tests +// =========================================================================== + +func trimSpace(s string) string { + result := "" + for _, c := range s { + if c != '\n' && c != '\r' && c != ' ' && c != '\t' { + result += string(c) + } else if len(result) > 0 && result[len(result)-1] != ' ' { + result += " " + } + } + // Actually just use strings.TrimSpace + return stringsTrimSpace(s) +} + +func trimQuotes(s string) string { + if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' { + return s[1 : len(s)-1] + } + return s +} + +func trimPrefix(s, prefix string) string { + if len(s) >= len(prefix) && s[:len(prefix)] == prefix { + return s[len(prefix):] + } + return s +} + +func hasPrefix(s, prefix string) bool { + return len(s) >= len(prefix) && s[:len(prefix)] == prefix +} + +func splitLines(s string) []string { + var lines []string + current := "" + for _, c := range s { + if c == '\n' { + lines = append(lines, current) + current = "" + } else { + current += string(c) + } + } + if current != "" { + lines = append(lines, current) + } + return lines +} + +func stringsTrimSpace(s string) string { + start := 0 + end := len(s) + for start < end && (s[start] == ' ' || s[start] == '\t' || s[start] == '\n' || s[start] == '\r') { + start++ + } + for end > start && (s[end-1] == ' ' || s[end-1] == '\t' || s[end-1] == '\n' || s[end-1] == '\r') { + end-- + } + return s[start:end] +} diff --git a/modules_svc_test.go b/modules_svc_test.go new file mode 100644 index 0000000..8d642eb --- /dev/null +++ b/modules_svc_test.go @@ -0,0 +1,950 @@ +package ansible + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ============================================================ +// Step 1.3: service / systemd / apt / apt_key / apt_repository / package / pip module tests +// ============================================================ + +// --- service module --- + +func TestModuleService_Good_Start(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl start nginx`, "Started", "", 0) + + result, err := moduleServiceWithClient(e, mock, map[string]any{ + "name": "nginx", + "state": "started", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`systemctl start nginx`)) + assert.Equal(t, 1, mock.commandCount()) +} + +func TestModuleService_Good_Stop(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl stop nginx`, "", "", 0) + + result, err := moduleServiceWithClient(e, mock, map[string]any{ + "name": "nginx", + "state": "stopped", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`systemctl stop nginx`)) +} + +func TestModuleService_Good_Restart(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl restart docker`, "", "", 0) + + result, err := moduleServiceWithClient(e, mock, map[string]any{ + "name": "docker", + "state": "restarted", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`systemctl restart docker`)) +} + +func TestModuleService_Good_Reload(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl reload nginx`, "", "", 0) + + result, err := moduleServiceWithClient(e, mock, map[string]any{ + "name": "nginx", + "state": "reloaded", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`systemctl reload nginx`)) +} + +func TestModuleService_Good_Enable(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl enable nginx`, "", "", 0) + + result, err := moduleServiceWithClient(e, mock, map[string]any{ + "name": "nginx", + "enabled": true, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`systemctl enable nginx`)) +} + +func TestModuleService_Good_Disable(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl disable nginx`, "", "", 0) + + result, err := moduleServiceWithClient(e, mock, map[string]any{ + "name": "nginx", + "enabled": false, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`systemctl disable nginx`)) +} + +func TestModuleService_Good_StartAndEnable(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl start nginx`, "", "", 0) + mock.expectCommand(`systemctl enable nginx`, "", "", 0) + + result, err := moduleServiceWithClient(e, mock, map[string]any{ + "name": "nginx", + "state": "started", + "enabled": true, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.Equal(t, 2, mock.commandCount()) + assert.True(t, mock.hasExecuted(`systemctl start nginx`)) + assert.True(t, mock.hasExecuted(`systemctl enable nginx`)) +} + +func TestModuleService_Good_RestartAndDisable(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl restart sshd`, "", "", 0) + mock.expectCommand(`systemctl disable sshd`, "", "", 0) + + result, err := moduleServiceWithClient(e, mock, map[string]any{ + "name": "sshd", + "state": "restarted", + "enabled": false, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.Equal(t, 2, mock.commandCount()) + assert.True(t, mock.hasExecuted(`systemctl restart sshd`)) + assert.True(t, mock.hasExecuted(`systemctl disable sshd`)) +} + +func TestModuleService_Bad_MissingName(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleServiceWithClient(e, mock, map[string]any{ + "state": "started", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "name required") +} + +func TestModuleService_Good_NoStateNoEnabled(t *testing.T) { + // When neither state nor enabled is provided, no commands run + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleServiceWithClient(e, mock, map[string]any{ + "name": "nginx", + }) + + require.NoError(t, err) + assert.False(t, result.Changed) + assert.False(t, result.Failed) + assert.Equal(t, 0, mock.commandCount()) +} + +func TestModuleService_Good_CommandFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl start.*`, "", "Failed to start nginx.service", 1) + + result, err := moduleServiceWithClient(e, mock, map[string]any{ + "name": "nginx", + "state": "started", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Contains(t, result.Msg, "Failed to start nginx.service") + assert.Equal(t, 1, result.RC) +} + +func TestModuleService_Good_FirstCommandFailsSkipsRest(t *testing.T) { + // When state command fails, enable should not run + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl start`, "", "unit not found", 5) + + result, err := moduleServiceWithClient(e, mock, map[string]any{ + "name": "nonexistent", + "state": "started", + "enabled": true, + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + // Only the start command should have been attempted + assert.Equal(t, 1, mock.commandCount()) + assert.False(t, mock.hasExecuted(`systemctl enable`)) +} + +// --- systemd module --- + +func TestModuleSystemd_Good_DaemonReloadThenStart(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl daemon-reload`, "", "", 0) + mock.expectCommand(`systemctl start nginx`, "", "", 0) + + result, err := moduleSystemdWithClient(e, mock, map[string]any{ + "name": "nginx", + "state": "started", + "daemon_reload": true, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + + // daemon-reload must run first, then start + cmds := mock.executedCommands() + require.GreaterOrEqual(t, len(cmds), 2) + assert.Contains(t, cmds[0].Cmd, "daemon-reload") + assert.Contains(t, cmds[1].Cmd, "systemctl start nginx") +} + +func TestModuleSystemd_Good_DaemonReloadOnly(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl daemon-reload`, "", "", 0) + + result, err := moduleSystemdWithClient(e, mock, map[string]any{ + "name": "nginx", + "daemon_reload": true, + }) + + require.NoError(t, err) + // daemon-reload runs, but no state/enabled means no further commands + // Changed is false because moduleService returns Changed: len(cmds) > 0 + // and no cmds were built (no state, no enabled) + assert.False(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`systemctl daemon-reload`)) +} + +func TestModuleSystemd_Good_DelegationToService(t *testing.T) { + // Without daemon_reload, systemd delegates entirely to service + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl restart docker`, "", "", 0) + + result, err := moduleSystemdWithClient(e, mock, map[string]any{ + "name": "docker", + "state": "restarted", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`systemctl restart docker`)) + // No daemon-reload should have run + assert.False(t, mock.hasExecuted(`daemon-reload`)) +} + +func TestModuleSystemd_Good_DaemonReloadWithEnable(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl daemon-reload`, "", "", 0) + mock.expectCommand(`systemctl enable myapp`, "", "", 0) + + result, err := moduleSystemdWithClient(e, mock, map[string]any{ + "name": "myapp", + "enabled": true, + "daemon_reload": true, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`systemctl daemon-reload`)) + assert.True(t, mock.hasExecuted(`systemctl enable myapp`)) +} + +// --- apt module --- + +func TestModuleApt_Good_InstallPresent(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`apt-get install -y -qq nginx`, "installed", "", 0) + + result, err := moduleAptWithClient(e, mock, map[string]any{ + "name": "nginx", + "state": "present", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`DEBIAN_FRONTEND=noninteractive apt-get install -y -qq nginx`)) +} + +func TestModuleApt_Good_InstallInstalled(t *testing.T) { + // state=installed is an alias for present + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`apt-get install -y -qq curl`, "", "", 0) + + result, err := moduleAptWithClient(e, mock, map[string]any{ + "name": "curl", + "state": "installed", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`apt-get install -y -qq curl`)) +} + +func TestModuleApt_Good_RemoveAbsent(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`apt-get remove -y -qq nginx`, "", "", 0) + + result, err := moduleAptWithClient(e, mock, map[string]any{ + "name": "nginx", + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`DEBIAN_FRONTEND=noninteractive apt-get remove -y -qq nginx`)) +} + +func TestModuleApt_Good_RemoveRemoved(t *testing.T) { + // state=removed is an alias for absent + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`apt-get remove -y -qq nginx`, "", "", 0) + + result, err := moduleAptWithClient(e, mock, map[string]any{ + "name": "nginx", + "state": "removed", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`apt-get remove -y -qq nginx`)) +} + +func TestModuleApt_Good_UpgradeLatest(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`apt-get install -y -qq --only-upgrade nginx`, "", "", 0) + + result, err := moduleAptWithClient(e, mock, map[string]any{ + "name": "nginx", + "state": "latest", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`DEBIAN_FRONTEND=noninteractive apt-get install -y -qq --only-upgrade nginx`)) +} + +func TestModuleApt_Good_UpdateCacheBeforeInstall(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`apt-get update`, "", "", 0) + mock.expectCommand(`apt-get install -y -qq nginx`, "", "", 0) + + result, err := moduleAptWithClient(e, mock, map[string]any{ + "name": "nginx", + "state": "present", + "update_cache": true, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + + // apt-get update must run before install + cmds := mock.executedCommands() + require.GreaterOrEqual(t, len(cmds), 2) + assert.Contains(t, cmds[0].Cmd, "apt-get update") + assert.Contains(t, cmds[1].Cmd, "apt-get install") +} + +func TestModuleApt_Good_UpdateCacheOnly(t *testing.T) { + // update_cache with no name means update only, no install + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`apt-get update`, "", "", 0) + + result, err := moduleAptWithClient(e, mock, map[string]any{ + "update_cache": true, + }) + + require.NoError(t, err) + // No package to install → not changed (cmd is empty) + assert.False(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`apt-get update`)) +} + +func TestModuleApt_Good_CommandFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`apt-get install`, "", "E: Unable to locate package badpkg", 100) + + result, err := moduleAptWithClient(e, mock, map[string]any{ + "name": "badpkg", + "state": "present", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Contains(t, result.Msg, "Unable to locate package") + assert.Equal(t, 100, result.RC) +} + +func TestModuleApt_Good_DefaultStateIsPresent(t *testing.T) { + // If no state is given, default is "present" (install) + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`apt-get install -y -qq vim`, "", "", 0) + + result, err := moduleAptWithClient(e, mock, map[string]any{ + "name": "vim", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`apt-get install -y -qq vim`)) +} + +// --- apt_key module --- + +func TestModuleAptKey_Good_AddWithKeyring(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`curl -fsSL.*gpg --dearmor`, "", "", 0) + + result, err := moduleAptKeyWithClient(e, mock, map[string]any{ + "url": "https://packages.example.com/key.gpg", + "keyring": "/etc/apt/keyrings/example.gpg", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`curl -fsSL`)) + assert.True(t, mock.hasExecuted(`gpg --dearmor -o`)) + assert.True(t, mock.containsSubstring("/etc/apt/keyrings/example.gpg")) +} + +func TestModuleAptKey_Good_AddWithoutKeyring(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`curl -fsSL.*apt-key add -`, "", "", 0) + + result, err := moduleAptKeyWithClient(e, mock, map[string]any{ + "url": "https://packages.example.com/key.gpg", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`apt-key add -`)) +} + +func TestModuleAptKey_Good_RemoveKey(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleAptKeyWithClient(e, mock, map[string]any{ + "keyring": "/etc/apt/keyrings/old.gpg", + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`rm -f`)) + assert.True(t, mock.containsSubstring("/etc/apt/keyrings/old.gpg")) +} + +func TestModuleAptKey_Good_RemoveWithoutKeyring(t *testing.T) { + // Absent with no keyring — still succeeds, just no rm command + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleAptKeyWithClient(e, mock, map[string]any{ + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.Equal(t, 0, mock.commandCount()) +} + +func TestModuleAptKey_Bad_MissingURL(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleAptKeyWithClient(e, mock, map[string]any{ + "state": "present", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "url required") +} + +func TestModuleAptKey_Good_CommandFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`curl`, "", "curl: (22) 404 Not Found", 22) + + result, err := moduleAptKeyWithClient(e, mock, map[string]any{ + "url": "https://invalid.example.com/key.gpg", + "keyring": "/etc/apt/keyrings/bad.gpg", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Contains(t, result.Msg, "404 Not Found") +} + +// --- apt_repository module --- + +func TestModuleAptRepository_Good_AddRepository(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`echo.*sources\.list\.d`, "", "", 0) + mock.expectCommand(`apt-get update`, "", "", 0) + + result, err := moduleAptRepositoryWithClient(e, mock, map[string]any{ + "repo": "deb https://packages.example.com/apt stable main", + "filename": "example", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.containsSubstring("/etc/apt/sources.list.d/example.list")) +} + +func TestModuleAptRepository_Good_RemoveRepository(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + + result, err := moduleAptRepositoryWithClient(e, mock, map[string]any{ + "repo": "deb https://packages.example.com/apt stable main", + "filename": "example", + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`rm -f`)) + assert.True(t, mock.containsSubstring("example.list")) +} + +func TestModuleAptRepository_Good_AddWithUpdateCache(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`echo`, "", "", 0) + mock.expectCommand(`apt-get update`, "", "", 0) + + result, err := moduleAptRepositoryWithClient(e, mock, map[string]any{ + "repo": "deb https://ppa.example.com/repo main", + "filename": "ppa-example", + "update_cache": true, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + + // update_cache defaults to true, so apt-get update should run + assert.True(t, mock.hasExecuted(`apt-get update`)) +} + +func TestModuleAptRepository_Good_AddWithoutUpdateCache(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`echo`, "", "", 0) + + result, err := moduleAptRepositoryWithClient(e, mock, map[string]any{ + "repo": "deb https://ppa.example.com/repo main", + "filename": "no-update", + "update_cache": false, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + + // update_cache=false, so no apt-get update + assert.False(t, mock.hasExecuted(`apt-get update`)) +} + +func TestModuleAptRepository_Good_CustomFilename(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`echo`, "", "", 0) + mock.expectCommand(`apt-get update`, "", "", 0) + + result, err := moduleAptRepositoryWithClient(e, mock, map[string]any{ + "repo": "deb http://ppa.launchpad.net/test/ppa/ubuntu jammy main", + "filename": "custom-ppa", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.containsSubstring("/etc/apt/sources.list.d/custom-ppa.list")) +} + +func TestModuleAptRepository_Good_AutoGeneratedFilename(t *testing.T) { + // When no filename is given, it auto-generates from the repo string + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`echo`, "", "", 0) + mock.expectCommand(`apt-get update`, "", "", 0) + + result, err := moduleAptRepositoryWithClient(e, mock, map[string]any{ + "repo": "deb https://example.com/repo main", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + // Filename should be derived from repo: spaces→dashes, slashes→dashes, colons removed + assert.True(t, mock.containsSubstring("/etc/apt/sources.list.d/")) +} + +func TestModuleAptRepository_Bad_MissingRepo(t *testing.T) { + e, _ := newTestExecutorWithMock("host1") + mock := NewMockSSHClient() + + _, err := moduleAptRepositoryWithClient(e, mock, map[string]any{ + "filename": "test", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "repo required") +} + +func TestModuleAptRepository_Good_WriteFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`echo`, "", "permission denied", 1) + + result, err := moduleAptRepositoryWithClient(e, mock, map[string]any{ + "repo": "deb https://example.com/repo main", + "filename": "test", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Contains(t, result.Msg, "permission denied") +} + +// --- package module --- + +func TestModulePackage_Good_DetectAptAndDelegate(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + // First command: which apt-get returns the path + mock.expectCommand(`which apt-get`, "/usr/bin/apt-get", "", 0) + // Second command: the actual apt install + mock.expectCommand(`apt-get install -y -qq htop`, "", "", 0) + + result, err := modulePackageWithClient(e, mock, map[string]any{ + "name": "htop", + "state": "present", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`which apt-get`)) + assert.True(t, mock.hasExecuted(`apt-get install -y -qq htop`)) +} + +func TestModulePackage_Good_FallbackToApt(t *testing.T) { + // When which returns nothing (no package manager found), still falls back to apt + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`which apt-get`, "", "", 1) + mock.expectCommand(`apt-get install -y -qq vim`, "", "", 0) + + result, err := modulePackageWithClient(e, mock, map[string]any{ + "name": "vim", + "state": "present", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`apt-get install -y -qq vim`)) +} + +func TestModulePackage_Good_RemovePackage(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`which apt-get`, "/usr/bin/apt-get", "", 0) + mock.expectCommand(`apt-get remove -y -qq nano`, "", "", 0) + + result, err := modulePackageWithClient(e, mock, map[string]any{ + "name": "nano", + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`apt-get remove -y -qq nano`)) +} + +// --- pip module --- + +func TestModulePip_Good_InstallPresent(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`pip3 install flask`, "Successfully installed", "", 0) + + result, err := modulePipWithClient(e, mock, map[string]any{ + "name": "flask", + "state": "present", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`pip3 install flask`)) +} + +func TestModulePip_Good_UninstallAbsent(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`pip3 uninstall -y flask`, "Successfully uninstalled", "", 0) + + result, err := modulePipWithClient(e, mock, map[string]any{ + "name": "flask", + "state": "absent", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`pip3 uninstall -y flask`)) +} + +func TestModulePip_Good_UpgradeLatest(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`pip3 install --upgrade flask`, "Successfully installed", "", 0) + + result, err := modulePipWithClient(e, mock, map[string]any{ + "name": "flask", + "state": "latest", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`pip3 install --upgrade flask`)) +} + +func TestModulePip_Good_CustomExecutable(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`/opt/venv/bin/pip install requests`, "", "", 0) + + result, err := modulePipWithClient(e, mock, map[string]any{ + "name": "requests", + "state": "present", + "executable": "/opt/venv/bin/pip", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`/opt/venv/bin/pip install requests`)) +} + +func TestModulePip_Good_DefaultStateIsPresent(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`pip3 install django`, "", "", 0) + + result, err := modulePipWithClient(e, mock, map[string]any{ + "name": "django", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`pip3 install django`)) +} + +func TestModulePip_Good_CommandFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`pip3 install`, "", "ERROR: No matching distribution found", 1) + + result, err := modulePipWithClient(e, mock, map[string]any{ + "name": "nonexistent-pkg-xyz", + "state": "present", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Contains(t, result.Msg, "No matching distribution found") +} + +func TestModulePip_Good_InstalledAlias(t *testing.T) { + // state=installed is an alias for present + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`pip3 install boto3`, "", "", 0) + + result, err := modulePipWithClient(e, mock, map[string]any{ + "name": "boto3", + "state": "installed", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`pip3 install boto3`)) +} + +func TestModulePip_Good_RemovedAlias(t *testing.T) { + // state=removed is an alias for absent + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`pip3 uninstall -y boto3`, "", "", 0) + + result, err := modulePipWithClient(e, mock, map[string]any{ + "name": "boto3", + "state": "removed", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`pip3 uninstall -y boto3`)) +} + +// --- Cross-module dispatch tests --- + +func TestExecuteModuleWithMock_Good_DispatchService(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl restart nginx`, "", "", 0) + + task := &Task{ + Module: "service", + Args: map[string]any{ + "name": "nginx", + "state": "restarted", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`systemctl restart nginx`)) +} + +func TestExecuteModuleWithMock_Good_DispatchSystemd(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`systemctl daemon-reload`, "", "", 0) + mock.expectCommand(`systemctl start myapp`, "", "", 0) + + task := &Task{ + Module: "ansible.builtin.systemd", + Args: map[string]any{ + "name": "myapp", + "state": "started", + "daemon_reload": true, + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`systemctl daemon-reload`)) + assert.True(t, mock.hasExecuted(`systemctl start myapp`)) +} + +func TestExecuteModuleWithMock_Good_DispatchApt(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`apt-get install -y -qq nginx`, "", "", 0) + + task := &Task{ + Module: "apt", + Args: map[string]any{ + "name": "nginx", + "state": "present", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`apt-get install`)) +} + +func TestExecuteModuleWithMock_Good_DispatchAptKey(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`curl.*gpg`, "", "", 0) + + task := &Task{ + Module: "apt_key", + Args: map[string]any{ + "url": "https://example.com/key.gpg", + "keyring": "/etc/apt/keyrings/example.gpg", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) +} + +func TestExecuteModuleWithMock_Good_DispatchAptRepository(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`echo`, "", "", 0) + mock.expectCommand(`apt-get update`, "", "", 0) + + task := &Task{ + Module: "apt_repository", + Args: map[string]any{ + "repo": "deb https://example.com/repo main", + "filename": "example", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) +} + +func TestExecuteModuleWithMock_Good_DispatchPackage(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`which apt-get`, "/usr/bin/apt-get", "", 0) + mock.expectCommand(`apt-get install -y -qq git`, "", "", 0) + + task := &Task{ + Module: "package", + Args: map[string]any{ + "name": "git", + "state": "present", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) +} + +func TestExecuteModuleWithMock_Good_DispatchPip(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`pip3 install ansible`, "", "", 0) + + task := &Task{ + Module: "pip", + Args: map[string]any{ + "name": "ansible", + "state": "present", + }, + } + + result, err := executeModuleWithMock(e, mock, "host1", task) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.True(t, mock.hasExecuted(`pip3 install ansible`)) +} diff --git a/parser.go b/parser.go new file mode 100644 index 0000000..baecf3f --- /dev/null +++ b/parser.go @@ -0,0 +1,510 @@ +package ansible + +import ( + "fmt" + "iter" + "maps" + "os" + "path/filepath" + "slices" + "strings" + + "forge.lthn.ai/core/go-log" + "gopkg.in/yaml.v3" +) + +// Parser handles Ansible YAML parsing. +type Parser struct { + basePath string + vars map[string]any +} + +// NewParser creates a new Ansible parser. +func NewParser(basePath string) *Parser { + return &Parser{ + basePath: basePath, + vars: make(map[string]any), + } +} + +// ParsePlaybook parses an Ansible playbook file. +func (p *Parser) ParsePlaybook(path string) ([]Play, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read playbook: %w", err) + } + + var plays []Play + if err := yaml.Unmarshal(data, &plays); err != nil { + return nil, fmt.Errorf("parse playbook: %w", err) + } + + // Process each play + for i := range plays { + if err := p.processPlay(&plays[i]); err != nil { + return nil, fmt.Errorf("process play %d: %w", i, err) + } + } + + return plays, nil +} + +// ParsePlaybookIter returns an iterator for plays in an Ansible playbook file. +func (p *Parser) ParsePlaybookIter(path string) (iter.Seq[Play], error) { + plays, err := p.ParsePlaybook(path) + if err != nil { + return nil, err + } + return func(yield func(Play) bool) { + for _, play := range plays { + if !yield(play) { + return + } + } + }, nil +} + +// ParseInventory parses an Ansible inventory file. +func (p *Parser) ParseInventory(path string) (*Inventory, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read inventory: %w", err) + } + + var inv Inventory + if err := yaml.Unmarshal(data, &inv); err != nil { + return nil, fmt.Errorf("parse inventory: %w", err) + } + + return &inv, nil +} + +// ParseTasks parses a tasks file (used by include_tasks). +func (p *Parser) ParseTasks(path string) ([]Task, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read tasks: %w", err) + } + + var tasks []Task + if err := yaml.Unmarshal(data, &tasks); err != nil { + return nil, fmt.Errorf("parse tasks: %w", err) + } + + for i := range tasks { + if err := p.extractModule(&tasks[i]); err != nil { + return nil, fmt.Errorf("task %d: %w", i, err) + } + } + + return tasks, nil +} + +// ParseTasksIter returns an iterator for tasks in a tasks file. +func (p *Parser) ParseTasksIter(path string) (iter.Seq[Task], error) { + tasks, err := p.ParseTasks(path) + if err != nil { + return nil, err + } + return func(yield func(Task) bool) { + for _, task := range tasks { + if !yield(task) { + return + } + } + }, nil +} + +// ParseRole parses a role and returns its tasks. +func (p *Parser) ParseRole(name string, tasksFrom string) ([]Task, error) { + if tasksFrom == "" { + tasksFrom = "main.yml" + } + + // Search paths for roles (in order of precedence) + searchPaths := []string{ + // Relative to playbook + filepath.Join(p.basePath, "roles", name, "tasks", tasksFrom), + // Parent directory roles + filepath.Join(filepath.Dir(p.basePath), "roles", name, "tasks", tasksFrom), + // Sibling roles directory + filepath.Join(p.basePath, "..", "roles", name, "tasks", tasksFrom), + // playbooks/roles pattern + filepath.Join(p.basePath, "playbooks", "roles", name, "tasks", tasksFrom), + // Common DevOps structure + filepath.Join(filepath.Dir(filepath.Dir(p.basePath)), "roles", name, "tasks", tasksFrom), + } + + var tasksPath string + for _, sp := range searchPaths { + // Clean the path to resolve .. segments + sp = filepath.Clean(sp) + if _, err := os.Stat(sp); err == nil { + tasksPath = sp + break + } + } + + if tasksPath == "" { + return nil, log.E("parser.ParseRole", fmt.Sprintf("role %s not found in search paths: %v", name, searchPaths), nil) + } + + // Load role defaults + defaultsPath := filepath.Join(filepath.Dir(filepath.Dir(tasksPath)), "defaults", "main.yml") + if data, err := os.ReadFile(defaultsPath); err == nil { + var defaults map[string]any + if yaml.Unmarshal(data, &defaults) == nil { + for k, v := range defaults { + if _, exists := p.vars[k]; !exists { + p.vars[k] = v + } + } + } + } + + // Load role vars + varsPath := filepath.Join(filepath.Dir(filepath.Dir(tasksPath)), "vars", "main.yml") + if data, err := os.ReadFile(varsPath); err == nil { + var roleVars map[string]any + if yaml.Unmarshal(data, &roleVars) == nil { + for k, v := range roleVars { + p.vars[k] = v + } + } + } + + return p.ParseTasks(tasksPath) +} + +// processPlay processes a play and extracts modules from tasks. +func (p *Parser) processPlay(play *Play) error { + // Merge play vars + for k, v := range play.Vars { + p.vars[k] = v + } + + for i := range play.PreTasks { + if err := p.extractModule(&play.PreTasks[i]); err != nil { + return fmt.Errorf("pre_task %d: %w", i, err) + } + } + + for i := range play.Tasks { + if err := p.extractModule(&play.Tasks[i]); err != nil { + return fmt.Errorf("task %d: %w", i, err) + } + } + + for i := range play.PostTasks { + if err := p.extractModule(&play.PostTasks[i]); err != nil { + return fmt.Errorf("post_task %d: %w", i, err) + } + } + + for i := range play.Handlers { + if err := p.extractModule(&play.Handlers[i]); err != nil { + return fmt.Errorf("handler %d: %w", i, err) + } + } + + return nil +} + +// extractModule extracts the module name and args from a task. +func (p *Parser) extractModule(task *Task) error { + // First, unmarshal the raw YAML to get all keys + // This is a workaround since we need to find the module key dynamically + + // Handle block tasks + for i := range task.Block { + if err := p.extractModule(&task.Block[i]); err != nil { + return err + } + } + for i := range task.Rescue { + if err := p.extractModule(&task.Rescue[i]); err != nil { + return err + } + } + for i := range task.Always { + if err := p.extractModule(&task.Always[i]); err != nil { + return err + } + } + + return nil +} + +// UnmarshalYAML implements custom YAML unmarshaling for Task. +func (t *Task) UnmarshalYAML(node *yaml.Node) error { + // First decode known fields + type rawTask Task + var raw rawTask + + // Create a map to capture all fields + var m map[string]any + if err := node.Decode(&m); err != nil { + return err + } + + // Decode into struct + if err := node.Decode(&raw); err != nil { + return err + } + *t = Task(raw) + t.raw = m + + // Find the module key + knownKeys := map[string]bool{ + "name": true, "register": true, "when": true, "loop": true, + "loop_control": true, "vars": true, "environment": true, + "changed_when": true, "failed_when": true, "ignore_errors": true, + "no_log": true, "become": true, "become_user": true, + "delegate_to": true, "run_once": true, "tags": true, + "block": true, "rescue": true, "always": true, "notify": true, + "retries": true, "delay": true, "until": true, + "include_tasks": true, "import_tasks": true, + "include_role": true, "import_role": true, + "with_items": true, "with_dict": true, "with_file": true, + } + + for key, val := range m { + if knownKeys[key] { + continue + } + + // Check if this is a module + if isModule(key) { + t.Module = key + t.Args = make(map[string]any) + + switch v := val.(type) { + case string: + // Free-form args (e.g., shell: echo hello) + t.Args["_raw_params"] = v + case map[string]any: + t.Args = v + case nil: + // Module with no args + default: + t.Args["_raw_params"] = v + } + break + } + } + + // Handle with_items as loop + if items, ok := m["with_items"]; ok && t.Loop == nil { + t.Loop = items + } + + return nil +} + +// isModule checks if a key is a known module. +func isModule(key string) bool { + for _, m := range KnownModules { + if key == m { + return true + } + // Also check without ansible.builtin. prefix + if strings.HasPrefix(m, "ansible.builtin.") { + if key == strings.TrimPrefix(m, "ansible.builtin.") { + return true + } + } + } + // Accept any key with dots (likely a module) + return strings.Contains(key, ".") +} + +// NormalizeModule normalizes a module name to its canonical form. +func NormalizeModule(name string) string { + // Add ansible.builtin. prefix if missing + if !strings.Contains(name, ".") { + return "ansible.builtin." + name + } + return name +} + +// GetHosts returns hosts matching a pattern from inventory. +func GetHosts(inv *Inventory, pattern string) []string { + if pattern == "all" { + return getAllHosts(inv.All) + } + if pattern == "localhost" { + return []string{"localhost"} + } + + // Check if it's a group name + hosts := getGroupHosts(inv.All, pattern) + if len(hosts) > 0 { + return hosts + } + + // Check if it's a specific host + if hasHost(inv.All, pattern) { + return []string{pattern} + } + + // Handle patterns with : (intersection/union) + // For now, just return empty + return nil +} + +// GetHostsIter returns an iterator for hosts matching a pattern from inventory. +func GetHostsIter(inv *Inventory, pattern string) iter.Seq[string] { + hosts := GetHosts(inv, pattern) + return func(yield func(string) bool) { + for _, host := range hosts { + if !yield(host) { + return + } + } + } +} + +func getAllHosts(group *InventoryGroup) []string { + if group == nil { + return nil + } + + var hosts []string + for name := range group.Hosts { + hosts = append(hosts, name) + } + for _, child := range group.Children { + hosts = append(hosts, getAllHosts(child)...) + } + return hosts +} + +// AllHostsIter returns an iterator for all hosts in an inventory group. +func AllHostsIter(group *InventoryGroup) iter.Seq[string] { + return func(yield func(string) bool) { + if group == nil { + return + } + // Sort keys for deterministic iteration + keys := slices.Sorted(maps.Keys(group.Hosts)) + for _, name := range keys { + if !yield(name) { + return + } + } + + // Sort children keys for deterministic iteration + childKeys := slices.Sorted(maps.Keys(group.Children)) + for _, name := range childKeys { + child := group.Children[name] + for host := range AllHostsIter(child) { + if !yield(host) { + return + } + } + } + } +} + +func getGroupHosts(group *InventoryGroup, name string) []string { + if group == nil { + return nil + } + + // Check children for the group name + if child, ok := group.Children[name]; ok { + return getAllHosts(child) + } + + // Recurse + for _, child := range group.Children { + if hosts := getGroupHosts(child, name); len(hosts) > 0 { + return hosts + } + } + + return nil +} + +func hasHost(group *InventoryGroup, name string) bool { + if group == nil { + return false + } + + if _, ok := group.Hosts[name]; ok { + return true + } + + for _, child := range group.Children { + if hasHost(child, name) { + return true + } + } + + return false +} + +// GetHostVars returns variables for a specific host. +func GetHostVars(inv *Inventory, hostname string) map[string]any { + vars := make(map[string]any) + + // Collect vars from all levels + collectHostVars(inv.All, hostname, vars) + + return vars +} + +func collectHostVars(group *InventoryGroup, hostname string, vars map[string]any) bool { + if group == nil { + return false + } + + // Check if host is in this group + found := false + if host, ok := group.Hosts[hostname]; ok { + found = true + // Apply group vars first + for k, v := range group.Vars { + vars[k] = v + } + // Then host vars + if host != nil { + if host.AnsibleHost != "" { + vars["ansible_host"] = host.AnsibleHost + } + if host.AnsiblePort != 0 { + vars["ansible_port"] = host.AnsiblePort + } + if host.AnsibleUser != "" { + vars["ansible_user"] = host.AnsibleUser + } + if host.AnsiblePassword != "" { + vars["ansible_password"] = host.AnsiblePassword + } + if host.AnsibleSSHPrivateKeyFile != "" { + vars["ansible_ssh_private_key_file"] = host.AnsibleSSHPrivateKeyFile + } + if host.AnsibleConnection != "" { + vars["ansible_connection"] = host.AnsibleConnection + } + for k, v := range host.Vars { + vars[k] = v + } + } + } + + // Check children + for _, child := range group.Children { + if collectHostVars(child, hostname, vars) { + // Apply this group's vars (parent vars) + for k, v := range group.Vars { + if _, exists := vars[k]; !exists { + vars[k] = v + } + } + found = true + } + } + + return found +} diff --git a/parser_test.go b/parser_test.go new file mode 100644 index 0000000..8712e72 --- /dev/null +++ b/parser_test.go @@ -0,0 +1,777 @@ +package ansible + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- ParsePlaybook --- + +func TestParsePlaybook_Good_SimplePlay(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "playbook.yml") + + yaml := `--- +- name: Configure webserver + hosts: webservers + become: true + tasks: + - name: Install nginx + apt: + name: nginx + state: present +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + plays, err := p.ParsePlaybook(path) + + require.NoError(t, err) + require.Len(t, plays, 1) + assert.Equal(t, "Configure webserver", plays[0].Name) + assert.Equal(t, "webservers", plays[0].Hosts) + assert.True(t, plays[0].Become) + require.Len(t, plays[0].Tasks, 1) + assert.Equal(t, "Install nginx", plays[0].Tasks[0].Name) + assert.Equal(t, "apt", plays[0].Tasks[0].Module) + assert.Equal(t, "nginx", plays[0].Tasks[0].Args["name"]) + assert.Equal(t, "present", plays[0].Tasks[0].Args["state"]) +} + +func TestParsePlaybook_Good_MultiplePlays(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "playbook.yml") + + yaml := `--- +- name: Play one + hosts: all + tasks: + - name: Say hello + debug: + msg: "Hello" + +- name: Play two + hosts: localhost + connection: local + tasks: + - name: Say goodbye + debug: + msg: "Goodbye" +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + plays, err := p.ParsePlaybook(path) + + require.NoError(t, err) + require.Len(t, plays, 2) + assert.Equal(t, "Play one", plays[0].Name) + assert.Equal(t, "all", plays[0].Hosts) + assert.Equal(t, "Play two", plays[1].Name) + assert.Equal(t, "localhost", plays[1].Hosts) + assert.Equal(t, "local", plays[1].Connection) +} + +func TestParsePlaybook_Good_WithVars(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "playbook.yml") + + yaml := `--- +- name: With vars + hosts: all + vars: + http_port: 8080 + app_name: myapp + tasks: + - name: Print port + debug: + msg: "Port is {{ http_port }}" +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + plays, err := p.ParsePlaybook(path) + + require.NoError(t, err) + require.Len(t, plays, 1) + assert.Equal(t, 8080, plays[0].Vars["http_port"]) + assert.Equal(t, "myapp", plays[0].Vars["app_name"]) +} + +func TestParsePlaybook_Good_PrePostTasks(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "playbook.yml") + + yaml := `--- +- name: Full lifecycle + hosts: all + pre_tasks: + - name: Pre task + debug: + msg: "pre" + tasks: + - name: Main task + debug: + msg: "main" + post_tasks: + - name: Post task + debug: + msg: "post" +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + plays, err := p.ParsePlaybook(path) + + require.NoError(t, err) + require.Len(t, plays, 1) + assert.Len(t, plays[0].PreTasks, 1) + assert.Len(t, plays[0].Tasks, 1) + assert.Len(t, plays[0].PostTasks, 1) + assert.Equal(t, "Pre task", plays[0].PreTasks[0].Name) + assert.Equal(t, "Main task", plays[0].Tasks[0].Name) + assert.Equal(t, "Post task", plays[0].PostTasks[0].Name) +} + +func TestParsePlaybook_Good_Handlers(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "playbook.yml") + + yaml := `--- +- name: With handlers + hosts: all + tasks: + - name: Install package + apt: + name: nginx + notify: restart nginx + handlers: + - name: restart nginx + service: + name: nginx + state: restarted +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + plays, err := p.ParsePlaybook(path) + + require.NoError(t, err) + require.Len(t, plays, 1) + assert.Len(t, plays[0].Handlers, 1) + assert.Equal(t, "restart nginx", plays[0].Handlers[0].Name) + assert.Equal(t, "service", plays[0].Handlers[0].Module) +} + +func TestParsePlaybook_Good_ShellFreeForm(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "playbook.yml") + + yaml := `--- +- name: Shell tasks + hosts: all + tasks: + - name: Run a command + shell: echo hello world + - name: Run raw command + command: ls -la /tmp +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + plays, err := p.ParsePlaybook(path) + + require.NoError(t, err) + require.Len(t, plays[0].Tasks, 2) + assert.Equal(t, "shell", plays[0].Tasks[0].Module) + assert.Equal(t, "echo hello world", plays[0].Tasks[0].Args["_raw_params"]) + assert.Equal(t, "command", plays[0].Tasks[1].Module) + assert.Equal(t, "ls -la /tmp", plays[0].Tasks[1].Args["_raw_params"]) +} + +func TestParsePlaybook_Good_WithTags(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "playbook.yml") + + yaml := `--- +- name: Tagged play + hosts: all + tags: + - setup + tasks: + - name: Tagged task + debug: + msg: "tagged" + tags: + - debug + - always +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + plays, err := p.ParsePlaybook(path) + + require.NoError(t, err) + assert.Equal(t, []string{"setup"}, plays[0].Tags) + assert.Equal(t, []string{"debug", "always"}, plays[0].Tasks[0].Tags) +} + +func TestParsePlaybook_Good_BlockRescueAlways(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "playbook.yml") + + yaml := `--- +- name: With blocks + hosts: all + tasks: + - name: Protected block + block: + - name: Try this + shell: echo try + rescue: + - name: Handle error + debug: + msg: "rescued" + always: + - name: Always runs + debug: + msg: "always" +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + plays, err := p.ParsePlaybook(path) + + require.NoError(t, err) + task := plays[0].Tasks[0] + assert.Len(t, task.Block, 1) + assert.Len(t, task.Rescue, 1) + assert.Len(t, task.Always, 1) + assert.Equal(t, "Try this", task.Block[0].Name) + assert.Equal(t, "Handle error", task.Rescue[0].Name) + assert.Equal(t, "Always runs", task.Always[0].Name) +} + +func TestParsePlaybook_Good_WithLoop(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "playbook.yml") + + yaml := `--- +- name: Loop test + hosts: all + tasks: + - name: Install packages + apt: + name: "{{ item }}" + state: present + loop: + - vim + - curl + - git +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + plays, err := p.ParsePlaybook(path) + + require.NoError(t, err) + task := plays[0].Tasks[0] + assert.Equal(t, "apt", task.Module) + items, ok := task.Loop.([]any) + require.True(t, ok) + assert.Len(t, items, 3) +} + +func TestParsePlaybook_Good_RoleRefs(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "playbook.yml") + + yaml := `--- +- name: With roles + hosts: all + roles: + - common + - role: webserver + vars: + http_port: 80 + tags: + - web +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + plays, err := p.ParsePlaybook(path) + + require.NoError(t, err) + require.Len(t, plays[0].Roles, 2) + assert.Equal(t, "common", plays[0].Roles[0].Role) + assert.Equal(t, "webserver", plays[0].Roles[1].Role) + assert.Equal(t, 80, plays[0].Roles[1].Vars["http_port"]) + assert.Equal(t, []string{"web"}, plays[0].Roles[1].Tags) +} + +func TestParsePlaybook_Good_FullyQualifiedModules(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "playbook.yml") + + yaml := `--- +- name: FQCN modules + hosts: all + tasks: + - name: Copy file + ansible.builtin.copy: + src: /tmp/foo + dest: /tmp/bar + - name: Run shell + ansible.builtin.shell: echo hello +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + plays, err := p.ParsePlaybook(path) + + require.NoError(t, err) + assert.Equal(t, "ansible.builtin.copy", plays[0].Tasks[0].Module) + assert.Equal(t, "/tmp/foo", plays[0].Tasks[0].Args["src"]) + assert.Equal(t, "ansible.builtin.shell", plays[0].Tasks[1].Module) + assert.Equal(t, "echo hello", plays[0].Tasks[1].Args["_raw_params"]) +} + +func TestParsePlaybook_Good_RegisterAndWhen(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "playbook.yml") + + yaml := `--- +- name: Conditional play + hosts: all + tasks: + - name: Check file + stat: + path: /etc/nginx/nginx.conf + register: nginx_conf + - name: Show result + debug: + msg: "File exists" + when: nginx_conf.stat.exists +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + plays, err := p.ParsePlaybook(path) + + require.NoError(t, err) + assert.Equal(t, "nginx_conf", plays[0].Tasks[0].Register) + assert.NotNil(t, plays[0].Tasks[1].When) +} + +func TestParsePlaybook_Good_EmptyPlaybook(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "playbook.yml") + + require.NoError(t, os.WriteFile(path, []byte("---\n[]"), 0644)) + + p := NewParser(dir) + plays, err := p.ParsePlaybook(path) + + require.NoError(t, err) + assert.Empty(t, plays) +} + +func TestParsePlaybook_Bad_InvalidYAML(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "bad.yml") + + require.NoError(t, os.WriteFile(path, []byte("{{invalid yaml}}"), 0644)) + + p := NewParser(dir) + _, err := p.ParsePlaybook(path) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "parse playbook") +} + +func TestParsePlaybook_Bad_FileNotFound(t *testing.T) { + p := NewParser(t.TempDir()) + _, err := p.ParsePlaybook("/nonexistent/playbook.yml") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "read playbook") +} + +func TestParsePlaybook_Good_GatherFactsDisabled(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "playbook.yml") + + yaml := `--- +- name: No facts + hosts: all + gather_facts: false + tasks: [] +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + plays, err := p.ParsePlaybook(path) + + require.NoError(t, err) + require.NotNil(t, plays[0].GatherFacts) + assert.False(t, *plays[0].GatherFacts) +} + +// --- ParseInventory --- + +func TestParseInventory_Good_SimpleInventory(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "inventory.yml") + + yaml := `--- +all: + hosts: + web1: + ansible_host: 192.168.1.10 + web2: + ansible_host: 192.168.1.11 +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + inv, err := p.ParseInventory(path) + + require.NoError(t, err) + require.NotNil(t, inv.All) + assert.Len(t, inv.All.Hosts, 2) + assert.Equal(t, "192.168.1.10", inv.All.Hosts["web1"].AnsibleHost) + assert.Equal(t, "192.168.1.11", inv.All.Hosts["web2"].AnsibleHost) +} + +func TestParseInventory_Good_WithGroups(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "inventory.yml") + + yaml := `--- +all: + children: + webservers: + hosts: + web1: + ansible_host: 10.0.0.1 + web2: + ansible_host: 10.0.0.2 + databases: + hosts: + db1: + ansible_host: 10.0.1.1 +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + inv, err := p.ParseInventory(path) + + require.NoError(t, err) + require.NotNil(t, inv.All.Children["webservers"]) + assert.Len(t, inv.All.Children["webservers"].Hosts, 2) + require.NotNil(t, inv.All.Children["databases"]) + assert.Len(t, inv.All.Children["databases"].Hosts, 1) +} + +func TestParseInventory_Good_WithVars(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "inventory.yml") + + yaml := `--- +all: + vars: + ansible_user: admin + children: + production: + vars: + env: prod + hosts: + prod1: + ansible_host: 10.0.0.1 + ansible_port: 2222 +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + inv, err := p.ParseInventory(path) + + require.NoError(t, err) + assert.Equal(t, "admin", inv.All.Vars["ansible_user"]) + assert.Equal(t, "prod", inv.All.Children["production"].Vars["env"]) + assert.Equal(t, 2222, inv.All.Children["production"].Hosts["prod1"].AnsiblePort) +} + +func TestParseInventory_Bad_InvalidYAML(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "bad.yml") + + require.NoError(t, os.WriteFile(path, []byte("{{{bad"), 0644)) + + p := NewParser(dir) + _, err := p.ParseInventory(path) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "parse inventory") +} + +func TestParseInventory_Bad_FileNotFound(t *testing.T) { + p := NewParser(t.TempDir()) + _, err := p.ParseInventory("/nonexistent/inventory.yml") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "read inventory") +} + +// --- ParseTasks --- + +func TestParseTasks_Good_TaskFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "tasks.yml") + + yaml := `--- +- name: First task + shell: echo first +- name: Second task + copy: + src: /tmp/a + dest: /tmp/b +` + require.NoError(t, os.WriteFile(path, []byte(yaml), 0644)) + + p := NewParser(dir) + tasks, err := p.ParseTasks(path) + + require.NoError(t, err) + require.Len(t, tasks, 2) + assert.Equal(t, "shell", tasks[0].Module) + assert.Equal(t, "echo first", tasks[0].Args["_raw_params"]) + assert.Equal(t, "copy", tasks[1].Module) + assert.Equal(t, "/tmp/a", tasks[1].Args["src"]) +} + +func TestParseTasks_Bad_InvalidYAML(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "bad.yml") + + require.NoError(t, os.WriteFile(path, []byte("not: [valid: tasks"), 0644)) + + p := NewParser(dir) + _, err := p.ParseTasks(path) + + assert.Error(t, err) +} + +// --- GetHosts --- + +func TestGetHosts_Good_AllPattern(t *testing.T) { + inv := &Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "host1": {}, + "host2": {}, + }, + }, + } + + hosts := GetHosts(inv, "all") + assert.Len(t, hosts, 2) + assert.Contains(t, hosts, "host1") + assert.Contains(t, hosts, "host2") +} + +func TestGetHosts_Good_LocalhostPattern(t *testing.T) { + inv := &Inventory{All: &InventoryGroup{}} + hosts := GetHosts(inv, "localhost") + assert.Equal(t, []string{"localhost"}, hosts) +} + +func TestGetHosts_Good_GroupPattern(t *testing.T) { + inv := &Inventory{ + All: &InventoryGroup{ + Children: map[string]*InventoryGroup{ + "web": { + Hosts: map[string]*Host{ + "web1": {}, + "web2": {}, + }, + }, + "db": { + Hosts: map[string]*Host{ + "db1": {}, + }, + }, + }, + }, + } + + hosts := GetHosts(inv, "web") + assert.Len(t, hosts, 2) + assert.Contains(t, hosts, "web1") + assert.Contains(t, hosts, "web2") +} + +func TestGetHosts_Good_SpecificHost(t *testing.T) { + inv := &Inventory{ + All: &InventoryGroup{ + Children: map[string]*InventoryGroup{ + "servers": { + Hosts: map[string]*Host{ + "myhost": {}, + }, + }, + }, + }, + } + + hosts := GetHosts(inv, "myhost") + assert.Equal(t, []string{"myhost"}, hosts) +} + +func TestGetHosts_Good_AllIncludesChildren(t *testing.T) { + inv := &Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{"top": {}}, + Children: map[string]*InventoryGroup{ + "group1": { + Hosts: map[string]*Host{"child1": {}}, + }, + }, + }, + } + + hosts := GetHosts(inv, "all") + assert.Len(t, hosts, 2) + assert.Contains(t, hosts, "top") + assert.Contains(t, hosts, "child1") +} + +func TestGetHosts_Bad_NoMatch(t *testing.T) { + inv := &Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{"host1": {}}, + }, + } + + hosts := GetHosts(inv, "nonexistent") + assert.Empty(t, hosts) +} + +func TestGetHosts_Bad_NilGroup(t *testing.T) { + inv := &Inventory{All: nil} + hosts := GetHosts(inv, "all") + assert.Empty(t, hosts) +} + +// --- GetHostVars --- + +func TestGetHostVars_Good_DirectHost(t *testing.T) { + inv := &Inventory{ + All: &InventoryGroup{ + Vars: map[string]any{"global_var": "global"}, + Hosts: map[string]*Host{ + "myhost": { + AnsibleHost: "10.0.0.1", + AnsiblePort: 2222, + AnsibleUser: "deploy", + }, + }, + }, + } + + vars := GetHostVars(inv, "myhost") + assert.Equal(t, "10.0.0.1", vars["ansible_host"]) + assert.Equal(t, 2222, vars["ansible_port"]) + assert.Equal(t, "deploy", vars["ansible_user"]) + assert.Equal(t, "global", vars["global_var"]) +} + +func TestGetHostVars_Good_InheritedGroupVars(t *testing.T) { + inv := &Inventory{ + All: &InventoryGroup{ + Vars: map[string]any{"level": "all"}, + Children: map[string]*InventoryGroup{ + "production": { + Vars: map[string]any{"env": "prod", "level": "group"}, + Hosts: map[string]*Host{ + "prod1": { + AnsibleHost: "10.0.0.1", + }, + }, + }, + }, + }, + } + + vars := GetHostVars(inv, "prod1") + assert.Equal(t, "10.0.0.1", vars["ansible_host"]) + assert.Equal(t, "prod", vars["env"]) +} + +func TestGetHostVars_Good_HostNotFound(t *testing.T) { + inv := &Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{"other": {}}, + }, + } + + vars := GetHostVars(inv, "nonexistent") + assert.Empty(t, vars) +} + +// --- isModule --- + +func TestIsModule_Good_KnownModules(t *testing.T) { + assert.True(t, isModule("shell")) + assert.True(t, isModule("command")) + assert.True(t, isModule("copy")) + assert.True(t, isModule("file")) + assert.True(t, isModule("apt")) + assert.True(t, isModule("service")) + assert.True(t, isModule("systemd")) + assert.True(t, isModule("debug")) + assert.True(t, isModule("set_fact")) +} + +func TestIsModule_Good_FQCN(t *testing.T) { + assert.True(t, isModule("ansible.builtin.shell")) + assert.True(t, isModule("ansible.builtin.copy")) + assert.True(t, isModule("ansible.builtin.apt")) +} + +func TestIsModule_Good_DottedUnknown(t *testing.T) { + // Any key with dots is considered a module + assert.True(t, isModule("community.general.ufw")) + assert.True(t, isModule("ansible.posix.authorized_key")) +} + +func TestIsModule_Bad_NotAModule(t *testing.T) { + assert.False(t, isModule("some_random_key")) + assert.False(t, isModule("foobar")) +} + +// --- NormalizeModule --- + +func TestNormalizeModule_Good(t *testing.T) { + assert.Equal(t, "ansible.builtin.shell", NormalizeModule("shell")) + assert.Equal(t, "ansible.builtin.copy", NormalizeModule("copy")) + assert.Equal(t, "ansible.builtin.apt", NormalizeModule("apt")) +} + +func TestNormalizeModule_Good_AlreadyFQCN(t *testing.T) { + assert.Equal(t, "ansible.builtin.shell", NormalizeModule("ansible.builtin.shell")) + assert.Equal(t, "community.general.ufw", NormalizeModule("community.general.ufw")) +} + +// --- NewParser --- + +func TestNewParser_Good(t *testing.T) { + p := NewParser("/some/path") + assert.NotNil(t, p) + assert.Equal(t, "/some/path", p.basePath) + assert.NotNil(t, p.vars) +} diff --git a/ssh.go b/ssh.go new file mode 100644 index 0000000..f12d38e --- /dev/null +++ b/ssh.go @@ -0,0 +1,451 @@ +package ansible + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "forge.lthn.ai/core/go-log" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" +) + +// SSHClient handles SSH connections to remote hosts. +type SSHClient struct { + host string + port int + user string + password string + keyFile string + client *ssh.Client + mu sync.Mutex + become bool + becomeUser string + becomePass string + timeout time.Duration +} + +// SSHConfig holds SSH connection configuration. +type SSHConfig struct { + Host string + Port int + User string + Password string + KeyFile string + Become bool + BecomeUser string + BecomePass string + Timeout time.Duration +} + +// NewSSHClient creates a new SSH client. +func NewSSHClient(cfg SSHConfig) (*SSHClient, error) { + if cfg.Port == 0 { + cfg.Port = 22 + } + if cfg.User == "" { + cfg.User = "root" + } + if cfg.Timeout == 0 { + cfg.Timeout = 30 * time.Second + } + + client := &SSHClient{ + host: cfg.Host, + port: cfg.Port, + user: cfg.User, + password: cfg.Password, + keyFile: cfg.KeyFile, + become: cfg.Become, + becomeUser: cfg.BecomeUser, + becomePass: cfg.BecomePass, + timeout: cfg.Timeout, + } + + return client, nil +} + +// Connect establishes the SSH connection. +func (c *SSHClient) Connect(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.client != nil { + return nil + } + + var authMethods []ssh.AuthMethod + + // Try key-based auth first + if c.keyFile != "" { + keyPath := c.keyFile + if strings.HasPrefix(keyPath, "~") { + home, _ := os.UserHomeDir() + keyPath = filepath.Join(home, keyPath[1:]) + } + + if key, err := os.ReadFile(keyPath); err == nil { + if signer, err := ssh.ParsePrivateKey(key); err == nil { + authMethods = append(authMethods, ssh.PublicKeys(signer)) + } + } + } + + // Try default SSH keys + if len(authMethods) == 0 { + home, _ := os.UserHomeDir() + defaultKeys := []string{ + filepath.Join(home, ".ssh", "id_ed25519"), + filepath.Join(home, ".ssh", "id_rsa"), + } + for _, keyPath := range defaultKeys { + if key, err := os.ReadFile(keyPath); err == nil { + if signer, err := ssh.ParsePrivateKey(key); err == nil { + authMethods = append(authMethods, ssh.PublicKeys(signer)) + break + } + } + } + } + + // Fall back to password auth + if c.password != "" { + authMethods = append(authMethods, ssh.Password(c.password)) + authMethods = append(authMethods, ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) { + answers := make([]string, len(questions)) + for i := range questions { + answers[i] = c.password + } + return answers, nil + })) + } + + if len(authMethods) == 0 { + return log.E("ssh.Connect", "no authentication method available", nil) + } + + // Host key verification + var hostKeyCallback ssh.HostKeyCallback + + home, err := os.UserHomeDir() + if err != nil { + return log.E("ssh.Connect", "failed to get user home dir", err) + } + knownHostsPath := filepath.Join(home, ".ssh", "known_hosts") + + // Ensure known_hosts file exists + if _, err := os.Stat(knownHostsPath); os.IsNotExist(err) { + if err := os.MkdirAll(filepath.Dir(knownHostsPath), 0700); err != nil { + return log.E("ssh.Connect", "failed to create .ssh dir", err) + } + if err := os.WriteFile(knownHostsPath, nil, 0600); err != nil { + return log.E("ssh.Connect", "failed to create known_hosts file", err) + } + } + + cb, err := knownhosts.New(knownHostsPath) + if err != nil { + return log.E("ssh.Connect", "failed to load known_hosts", err) + } + hostKeyCallback = cb + + config := &ssh.ClientConfig{ + User: c.user, + Auth: authMethods, + HostKeyCallback: hostKeyCallback, + Timeout: c.timeout, + } + + addr := fmt.Sprintf("%s:%d", c.host, c.port) + + // Connect with context timeout + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", addr) + if err != nil { + return log.E("ssh.Connect", fmt.Sprintf("dial %s", addr), err) + } + + sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + // conn is closed by NewClientConn on error + return log.E("ssh.Connect", fmt.Sprintf("ssh connect %s", addr), err) + } + + c.client = ssh.NewClient(sshConn, chans, reqs) + return nil +} + +// Close closes the SSH connection. +func (c *SSHClient) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.client != nil { + err := c.client.Close() + c.client = nil + return err + } + return nil +} + +// Run executes a command on the remote host. +func (c *SSHClient) Run(ctx context.Context, cmd string) (stdout, stderr string, exitCode int, err error) { + if err := c.Connect(ctx); err != nil { + return "", "", -1, err + } + + session, err := c.client.NewSession() + if err != nil { + return "", "", -1, log.E("ssh.Run", "new session", err) + } + defer func() { _ = session.Close() }() + + var stdoutBuf, stderrBuf bytes.Buffer + session.Stdout = &stdoutBuf + session.Stderr = &stderrBuf + + // Apply become if needed + if c.become { + becomeUser := c.becomeUser + if becomeUser == "" { + becomeUser = "root" + } + // Escape single quotes in the command + escapedCmd := strings.ReplaceAll(cmd, "'", "'\\''") + if c.becomePass != "" { + // Use sudo with password via stdin (-S flag) + // We launch a goroutine to write the password to stdin + cmd = fmt.Sprintf("sudo -S -u %s bash -c '%s'", becomeUser, escapedCmd) + stdin, err := session.StdinPipe() + if err != nil { + return "", "", -1, log.E("ssh.Run", "stdin pipe", err) + } + go func() { + defer func() { _ = stdin.Close() }() + _, _ = io.WriteString(stdin, c.becomePass+"\n") + }() + } else if c.password != "" { + // Try using connection password for sudo + cmd = fmt.Sprintf("sudo -S -u %s bash -c '%s'", becomeUser, escapedCmd) + stdin, err := session.StdinPipe() + if err != nil { + return "", "", -1, log.E("ssh.Run", "stdin pipe", err) + } + go func() { + defer func() { _ = stdin.Close() }() + _, _ = io.WriteString(stdin, c.password+"\n") + }() + } else { + // Try passwordless sudo + cmd = fmt.Sprintf("sudo -n -u %s bash -c '%s'", becomeUser, escapedCmd) + } + } + + // Run with context + done := make(chan error, 1) + go func() { + done <- session.Run(cmd) + }() + + select { + case <-ctx.Done(): + _ = session.Signal(ssh.SIGKILL) + return "", "", -1, ctx.Err() + case err := <-done: + exitCode = 0 + if err != nil { + if exitErr, ok := err.(*ssh.ExitError); ok { + exitCode = exitErr.ExitStatus() + } else { + return stdoutBuf.String(), stderrBuf.String(), -1, err + } + } + return stdoutBuf.String(), stderrBuf.String(), exitCode, nil + } +} + +// RunScript runs a script on the remote host. +func (c *SSHClient) RunScript(ctx context.Context, script string) (stdout, stderr string, exitCode int, err error) { + // Escape the script for heredoc + cmd := fmt.Sprintf("bash <<'ANSIBLE_SCRIPT_EOF'\n%s\nANSIBLE_SCRIPT_EOF", script) + return c.Run(ctx, cmd) +} + +// Upload copies a file to the remote host. +func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string, mode os.FileMode) error { + if err := c.Connect(ctx); err != nil { + return err + } + + // Read content + content, err := io.ReadAll(local) + if err != nil { + return log.E("ssh.Upload", "read content", err) + } + + // Create parent directory + dir := filepath.Dir(remote) + dirCmd := fmt.Sprintf("mkdir -p %q", dir) + if c.become { + dirCmd = fmt.Sprintf("sudo mkdir -p %q", dir) + } + if _, _, _, err := c.Run(ctx, dirCmd); err != nil { + return log.E("ssh.Upload", "create parent dir", err) + } + + // Use cat to write the file (simpler than SCP) + writeCmd := fmt.Sprintf("cat > %q && chmod %o %q", remote, mode, remote) + + // If become is needed, we construct a command that reads password then content from stdin + // But we need to be careful with handling stdin for sudo + cat. + // We'll use a session with piped stdin. + + session2, err := c.client.NewSession() + if err != nil { + return log.E("ssh.Upload", "new session for write", err) + } + defer func() { _ = session2.Close() }() + + stdin, err := session2.StdinPipe() + if err != nil { + return log.E("ssh.Upload", "stdin pipe", err) + } + + var stderrBuf bytes.Buffer + session2.Stderr = &stderrBuf + + if c.become { + becomeUser := c.becomeUser + if becomeUser == "" { + becomeUser = "root" + } + + pass := c.becomePass + if pass == "" { + pass = c.password + } + + if pass != "" { + // Use sudo -S with password from stdin + writeCmd = fmt.Sprintf("sudo -S -u %s bash -c 'cat > %q && chmod %o %q'", + becomeUser, remote, mode, remote) + } else { + // Use passwordless sudo (sudo -n) to avoid consuming file content as password + writeCmd = fmt.Sprintf("sudo -n -u %s bash -c 'cat > %q && chmod %o %q'", + becomeUser, remote, mode, remote) + } + + if err := session2.Start(writeCmd); err != nil { + return log.E("ssh.Upload", "start write", err) + } + + go func() { + defer func() { _ = stdin.Close() }() + if pass != "" { + _, _ = io.WriteString(stdin, pass+"\n") + } + _, _ = stdin.Write(content) + }() + } else { + // Normal write + if err := session2.Start(writeCmd); err != nil { + return log.E("ssh.Upload", "start write", err) + } + + go func() { + defer func() { _ = stdin.Close() }() + _, _ = stdin.Write(content) + }() + } + + if err := session2.Wait(); err != nil { + return log.E("ssh.Upload", fmt.Sprintf("write failed (stderr: %s)", stderrBuf.String()), err) + } + + return nil +} + +// Download copies a file from the remote host. +func (c *SSHClient) Download(ctx context.Context, remote string) ([]byte, error) { + if err := c.Connect(ctx); err != nil { + return nil, err + } + + cmd := fmt.Sprintf("cat %q", remote) + + stdout, stderr, exitCode, err := c.Run(ctx, cmd) + if err != nil { + return nil, err + } + if exitCode != 0 { + return nil, log.E("ssh.Download", fmt.Sprintf("cat failed: %s", stderr), nil) + } + + return []byte(stdout), nil +} + +// FileExists checks if a file exists on the remote host. +func (c *SSHClient) FileExists(ctx context.Context, path string) (bool, error) { + cmd := fmt.Sprintf("test -e %q && echo yes || echo no", path) + stdout, _, exitCode, err := c.Run(ctx, cmd) + if err != nil { + return false, err + } + if exitCode != 0 { + // test command failed but didn't error - file doesn't exist + return false, nil + } + return strings.TrimSpace(stdout) == "yes", nil +} + +// Stat returns file info from the remote host. +func (c *SSHClient) Stat(ctx context.Context, path string) (map[string]any, error) { + // Simple approach - get basic file info + cmd := fmt.Sprintf(` +if [ -e %q ]; then + if [ -d %q ]; then + echo "exists=true isdir=true" + else + echo "exists=true isdir=false" + fi +else + echo "exists=false" +fi +`, path, path) + + stdout, _, _, err := c.Run(ctx, cmd) + if err != nil { + return nil, err + } + + result := make(map[string]any) + parts := strings.Fields(strings.TrimSpace(stdout)) + for _, part := range parts { + kv := strings.SplitN(part, "=", 2) + if len(kv) == 2 { + result[kv[0]] = kv[1] == "true" + } + } + + return result, nil +} + +// SetBecome enables privilege escalation. +func (c *SSHClient) SetBecome(become bool, user, password string) { + c.mu.Lock() + defer c.mu.Unlock() + c.become = become + if user != "" { + c.becomeUser = user + } + if password != "" { + c.becomePass = password + } +} diff --git a/ssh_test.go b/ssh_test.go new file mode 100644 index 0000000..17179b0 --- /dev/null +++ b/ssh_test.go @@ -0,0 +1,36 @@ +package ansible + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewSSHClient(t *testing.T) { + cfg := SSHConfig{ + Host: "localhost", + Port: 2222, + User: "root", + } + + client, err := NewSSHClient(cfg) + assert.NoError(t, err) + assert.NotNil(t, client) + assert.Equal(t, "localhost", client.host) + assert.Equal(t, 2222, client.port) + assert.Equal(t, "root", client.user) + assert.Equal(t, 30*time.Second, client.timeout) +} + +func TestSSHConfig_Defaults(t *testing.T) { + cfg := SSHConfig{ + Host: "localhost", + } + + client, err := NewSSHClient(cfg) + assert.NoError(t, err) + assert.Equal(t, 22, client.port) + assert.Equal(t, "root", client.user) + assert.Equal(t, 30*time.Second, client.timeout) +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..5a6939f --- /dev/null +++ b/types.go @@ -0,0 +1,258 @@ +package ansible + +import ( + "time" +) + +// Playbook represents an Ansible playbook. +type Playbook struct { + Plays []Play `yaml:",inline"` +} + +// Play represents a single play in a playbook. +type Play struct { + Name string `yaml:"name"` + Hosts string `yaml:"hosts"` + Connection string `yaml:"connection,omitempty"` + Become bool `yaml:"become,omitempty"` + BecomeUser string `yaml:"become_user,omitempty"` + GatherFacts *bool `yaml:"gather_facts,omitempty"` + Vars map[string]any `yaml:"vars,omitempty"` + PreTasks []Task `yaml:"pre_tasks,omitempty"` + Tasks []Task `yaml:"tasks,omitempty"` + PostTasks []Task `yaml:"post_tasks,omitempty"` + Roles []RoleRef `yaml:"roles,omitempty"` + Handlers []Task `yaml:"handlers,omitempty"` + Tags []string `yaml:"tags,omitempty"` + Environment map[string]string `yaml:"environment,omitempty"` + Serial any `yaml:"serial,omitempty"` // int or string + MaxFailPercent int `yaml:"max_fail_percentage,omitempty"` +} + +// RoleRef represents a role reference in a play. +type RoleRef struct { + Role string `yaml:"role,omitempty"` + Name string `yaml:"name,omitempty"` // Alternative to role + TasksFrom string `yaml:"tasks_from,omitempty"` + Vars map[string]any `yaml:"vars,omitempty"` + When any `yaml:"when,omitempty"` + Tags []string `yaml:"tags,omitempty"` +} + +// UnmarshalYAML handles both string and struct role refs. +func (r *RoleRef) UnmarshalYAML(unmarshal func(any) error) error { + // Try string first + var s string + if err := unmarshal(&s); err == nil { + r.Role = s + return nil + } + + // Try struct + type rawRoleRef RoleRef + var raw rawRoleRef + if err := unmarshal(&raw); err != nil { + return err + } + *r = RoleRef(raw) + if r.Role == "" && r.Name != "" { + r.Role = r.Name + } + return nil +} + +// Task represents an Ansible task. +type Task struct { + Name string `yaml:"name,omitempty"` + Module string `yaml:"-"` // Derived from the module key + Args map[string]any `yaml:"-"` // Module arguments + Register string `yaml:"register,omitempty"` + When any `yaml:"when,omitempty"` // string or []string + Loop any `yaml:"loop,omitempty"` // string or []any + LoopControl *LoopControl `yaml:"loop_control,omitempty"` + Vars map[string]any `yaml:"vars,omitempty"` + Environment map[string]string `yaml:"environment,omitempty"` + ChangedWhen any `yaml:"changed_when,omitempty"` + FailedWhen any `yaml:"failed_when,omitempty"` + IgnoreErrors bool `yaml:"ignore_errors,omitempty"` + NoLog bool `yaml:"no_log,omitempty"` + Become *bool `yaml:"become,omitempty"` + BecomeUser string `yaml:"become_user,omitempty"` + Delegate string `yaml:"delegate_to,omitempty"` + RunOnce bool `yaml:"run_once,omitempty"` + Tags []string `yaml:"tags,omitempty"` + Block []Task `yaml:"block,omitempty"` + Rescue []Task `yaml:"rescue,omitempty"` + Always []Task `yaml:"always,omitempty"` + Notify any `yaml:"notify,omitempty"` // string or []string + Retries int `yaml:"retries,omitempty"` + Delay int `yaml:"delay,omitempty"` + Until string `yaml:"until,omitempty"` + + // Include/import directives + IncludeTasks string `yaml:"include_tasks,omitempty"` + ImportTasks string `yaml:"import_tasks,omitempty"` + IncludeRole *struct { + Name string `yaml:"name"` + TasksFrom string `yaml:"tasks_from,omitempty"` + Vars map[string]any `yaml:"vars,omitempty"` + } `yaml:"include_role,omitempty"` + ImportRole *struct { + Name string `yaml:"name"` + TasksFrom string `yaml:"tasks_from,omitempty"` + Vars map[string]any `yaml:"vars,omitempty"` + } `yaml:"import_role,omitempty"` + + // Raw YAML for module extraction + raw map[string]any +} + +// LoopControl controls loop behavior. +type LoopControl struct { + LoopVar string `yaml:"loop_var,omitempty"` + IndexVar string `yaml:"index_var,omitempty"` + Label string `yaml:"label,omitempty"` + Pause int `yaml:"pause,omitempty"` + Extended bool `yaml:"extended,omitempty"` +} + +// TaskResult holds the result of executing a task. +type TaskResult struct { + Changed bool `json:"changed"` + Failed bool `json:"failed"` + Skipped bool `json:"skipped"` + Msg string `json:"msg,omitempty"` + Stdout string `json:"stdout,omitempty"` + Stderr string `json:"stderr,omitempty"` + RC int `json:"rc,omitempty"` + Results []TaskResult `json:"results,omitempty"` // For loops + Data map[string]any `json:"data,omitempty"` // Module-specific data + Duration time.Duration `json:"duration,omitempty"` +} + +// Inventory represents Ansible inventory. +type Inventory struct { + All *InventoryGroup `yaml:"all"` +} + +// InventoryGroup represents a group in inventory. +type InventoryGroup struct { + Hosts map[string]*Host `yaml:"hosts,omitempty"` + Children map[string]*InventoryGroup `yaml:"children,omitempty"` + Vars map[string]any `yaml:"vars,omitempty"` +} + +// Host represents a host in inventory. +type Host struct { + AnsibleHost string `yaml:"ansible_host,omitempty"` + AnsiblePort int `yaml:"ansible_port,omitempty"` + AnsibleUser string `yaml:"ansible_user,omitempty"` + AnsiblePassword string `yaml:"ansible_password,omitempty"` + AnsibleSSHPrivateKeyFile string `yaml:"ansible_ssh_private_key_file,omitempty"` + AnsibleConnection string `yaml:"ansible_connection,omitempty"` + AnsibleBecomePassword string `yaml:"ansible_become_password,omitempty"` + + // Custom vars + Vars map[string]any `yaml:",inline"` +} + +// Facts holds gathered facts about a host. +type Facts struct { + Hostname string `json:"ansible_hostname"` + FQDN string `json:"ansible_fqdn"` + OS string `json:"ansible_os_family"` + Distribution string `json:"ansible_distribution"` + Version string `json:"ansible_distribution_version"` + Architecture string `json:"ansible_architecture"` + Kernel string `json:"ansible_kernel"` + Memory int64 `json:"ansible_memtotal_mb"` + CPUs int `json:"ansible_processor_vcpus"` + IPv4 string `json:"ansible_default_ipv4_address"` +} + +// Known Ansible modules +var KnownModules = []string{ + // Builtin + "ansible.builtin.shell", + "ansible.builtin.command", + "ansible.builtin.raw", + "ansible.builtin.script", + "ansible.builtin.copy", + "ansible.builtin.template", + "ansible.builtin.file", + "ansible.builtin.lineinfile", + "ansible.builtin.blockinfile", + "ansible.builtin.stat", + "ansible.builtin.slurp", + "ansible.builtin.fetch", + "ansible.builtin.get_url", + "ansible.builtin.uri", + "ansible.builtin.apt", + "ansible.builtin.apt_key", + "ansible.builtin.apt_repository", + "ansible.builtin.yum", + "ansible.builtin.dnf", + "ansible.builtin.package", + "ansible.builtin.pip", + "ansible.builtin.service", + "ansible.builtin.systemd", + "ansible.builtin.user", + "ansible.builtin.group", + "ansible.builtin.cron", + "ansible.builtin.git", + "ansible.builtin.unarchive", + "ansible.builtin.archive", + "ansible.builtin.debug", + "ansible.builtin.fail", + "ansible.builtin.assert", + "ansible.builtin.pause", + "ansible.builtin.wait_for", + "ansible.builtin.set_fact", + "ansible.builtin.include_vars", + "ansible.builtin.add_host", + "ansible.builtin.group_by", + "ansible.builtin.meta", + "ansible.builtin.setup", + + // Short forms (legacy) + "shell", + "command", + "raw", + "script", + "copy", + "template", + "file", + "lineinfile", + "blockinfile", + "stat", + "slurp", + "fetch", + "get_url", + "uri", + "apt", + "apt_key", + "apt_repository", + "yum", + "dnf", + "package", + "pip", + "service", + "systemd", + "user", + "group", + "cron", + "git", + "unarchive", + "archive", + "debug", + "fail", + "assert", + "pause", + "wait_for", + "set_fact", + "include_vars", + "add_host", + "group_by", + "meta", + "setup", +} diff --git a/types_test.go b/types_test.go new file mode 100644 index 0000000..d11fe43 --- /dev/null +++ b/types_test.go @@ -0,0 +1,402 @@ +package ansible + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +// --- RoleRef UnmarshalYAML --- + +func TestRoleRef_UnmarshalYAML_Good_StringForm(t *testing.T) { + input := `common` + var ref RoleRef + err := yaml.Unmarshal([]byte(input), &ref) + + require.NoError(t, err) + assert.Equal(t, "common", ref.Role) +} + +func TestRoleRef_UnmarshalYAML_Good_StructForm(t *testing.T) { + input := ` +role: webserver +vars: + http_port: 80 +tags: + - web +` + var ref RoleRef + err := yaml.Unmarshal([]byte(input), &ref) + + require.NoError(t, err) + assert.Equal(t, "webserver", ref.Role) + assert.Equal(t, 80, ref.Vars["http_port"]) + assert.Equal(t, []string{"web"}, ref.Tags) +} + +func TestRoleRef_UnmarshalYAML_Good_NameField(t *testing.T) { + // Some playbooks use "name:" instead of "role:" + input := ` +name: myapp +tasks_from: install.yml +` + var ref RoleRef + err := yaml.Unmarshal([]byte(input), &ref) + + require.NoError(t, err) + assert.Equal(t, "myapp", ref.Role) // Name is copied to Role + assert.Equal(t, "install.yml", ref.TasksFrom) +} + +func TestRoleRef_UnmarshalYAML_Good_WithWhen(t *testing.T) { + input := ` +role: conditional_role +when: ansible_os_family == "Debian" +` + var ref RoleRef + err := yaml.Unmarshal([]byte(input), &ref) + + require.NoError(t, err) + assert.Equal(t, "conditional_role", ref.Role) + assert.NotNil(t, ref.When) +} + +// --- Task UnmarshalYAML --- + +func TestTask_UnmarshalYAML_Good_ModuleWithArgs(t *testing.T) { + input := ` +name: Install nginx +apt: + name: nginx + state: present +` + var task Task + err := yaml.Unmarshal([]byte(input), &task) + + require.NoError(t, err) + assert.Equal(t, "Install nginx", task.Name) + assert.Equal(t, "apt", task.Module) + assert.Equal(t, "nginx", task.Args["name"]) + assert.Equal(t, "present", task.Args["state"]) +} + +func TestTask_UnmarshalYAML_Good_FreeFormModule(t *testing.T) { + input := ` +name: Run command +shell: echo hello world +` + var task Task + err := yaml.Unmarshal([]byte(input), &task) + + require.NoError(t, err) + assert.Equal(t, "shell", task.Module) + assert.Equal(t, "echo hello world", task.Args["_raw_params"]) +} + +func TestTask_UnmarshalYAML_Good_ModuleNoArgs(t *testing.T) { + input := ` +name: Gather facts +setup: +` + var task Task + err := yaml.Unmarshal([]byte(input), &task) + + require.NoError(t, err) + assert.Equal(t, "setup", task.Module) + assert.NotNil(t, task.Args) +} + +func TestTask_UnmarshalYAML_Good_WithRegister(t *testing.T) { + input := ` +name: Check file +stat: + path: /etc/hosts +register: stat_result +` + var task Task + err := yaml.Unmarshal([]byte(input), &task) + + require.NoError(t, err) + assert.Equal(t, "stat_result", task.Register) + assert.Equal(t, "stat", task.Module) +} + +func TestTask_UnmarshalYAML_Good_WithWhen(t *testing.T) { + input := ` +name: Conditional task +debug: + msg: "hello" +when: some_var is defined +` + var task Task + err := yaml.Unmarshal([]byte(input), &task) + + require.NoError(t, err) + assert.NotNil(t, task.When) +} + +func TestTask_UnmarshalYAML_Good_WithLoop(t *testing.T) { + input := ` +name: Install packages +apt: + name: "{{ item }}" +loop: + - vim + - git + - curl +` + var task Task + err := yaml.Unmarshal([]byte(input), &task) + + require.NoError(t, err) + items, ok := task.Loop.([]any) + require.True(t, ok) + assert.Len(t, items, 3) +} + +func TestTask_UnmarshalYAML_Good_WithItems(t *testing.T) { + // with_items should be converted to loop + input := ` +name: Old-style loop +apt: + name: "{{ item }}" +with_items: + - vim + - git +` + var task Task + err := yaml.Unmarshal([]byte(input), &task) + + require.NoError(t, err) + // with_items should have been stored in Loop + items, ok := task.Loop.([]any) + require.True(t, ok) + assert.Len(t, items, 2) +} + +func TestTask_UnmarshalYAML_Good_WithNotify(t *testing.T) { + input := ` +name: Install package +apt: + name: nginx +notify: restart nginx +` + var task Task + err := yaml.Unmarshal([]byte(input), &task) + + require.NoError(t, err) + assert.Equal(t, "restart nginx", task.Notify) +} + +func TestTask_UnmarshalYAML_Good_WithNotifyList(t *testing.T) { + input := ` +name: Install package +apt: + name: nginx +notify: + - restart nginx + - reload config +` + var task Task + err := yaml.Unmarshal([]byte(input), &task) + + require.NoError(t, err) + notifyList, ok := task.Notify.([]any) + require.True(t, ok) + assert.Len(t, notifyList, 2) +} + +func TestTask_UnmarshalYAML_Good_IncludeTasks(t *testing.T) { + input := ` +name: Include tasks +include_tasks: other-tasks.yml +` + var task Task + err := yaml.Unmarshal([]byte(input), &task) + + require.NoError(t, err) + assert.Equal(t, "other-tasks.yml", task.IncludeTasks) +} + +func TestTask_UnmarshalYAML_Good_IncludeRole(t *testing.T) { + input := ` +name: Include role +include_role: + name: common + tasks_from: setup.yml +` + var task Task + err := yaml.Unmarshal([]byte(input), &task) + + require.NoError(t, err) + require.NotNil(t, task.IncludeRole) + assert.Equal(t, "common", task.IncludeRole.Name) + assert.Equal(t, "setup.yml", task.IncludeRole.TasksFrom) +} + +func TestTask_UnmarshalYAML_Good_BecomeFields(t *testing.T) { + input := ` +name: Privileged task +shell: systemctl restart nginx +become: true +become_user: root +` + var task Task + err := yaml.Unmarshal([]byte(input), &task) + + require.NoError(t, err) + require.NotNil(t, task.Become) + assert.True(t, *task.Become) + assert.Equal(t, "root", task.BecomeUser) +} + +func TestTask_UnmarshalYAML_Good_IgnoreErrors(t *testing.T) { + input := ` +name: Might fail +shell: some risky command +ignore_errors: true +` + var task Task + err := yaml.Unmarshal([]byte(input), &task) + + require.NoError(t, err) + assert.True(t, task.IgnoreErrors) +} + +// --- Inventory data structure --- + +func TestInventory_UnmarshalYAML_Good_Complex(t *testing.T) { + input := ` +all: + vars: + ansible_user: admin + ansible_ssh_private_key_file: ~/.ssh/id_ed25519 + hosts: + bastion: + ansible_host: 1.2.3.4 + ansible_port: 4819 + children: + webservers: + hosts: + web1: + ansible_host: 10.0.0.1 + web2: + ansible_host: 10.0.0.2 + vars: + http_port: 80 + databases: + hosts: + db1: + ansible_host: 10.0.1.1 + ansible_connection: ssh +` + var inv Inventory + err := yaml.Unmarshal([]byte(input), &inv) + + require.NoError(t, err) + require.NotNil(t, inv.All) + + // Check top-level vars + assert.Equal(t, "admin", inv.All.Vars["ansible_user"]) + + // Check top-level hosts + require.NotNil(t, inv.All.Hosts["bastion"]) + assert.Equal(t, "1.2.3.4", inv.All.Hosts["bastion"].AnsibleHost) + assert.Equal(t, 4819, inv.All.Hosts["bastion"].AnsiblePort) + + // Check children + require.NotNil(t, inv.All.Children["webservers"]) + assert.Len(t, inv.All.Children["webservers"].Hosts, 2) + assert.Equal(t, 80, inv.All.Children["webservers"].Vars["http_port"]) + + require.NotNil(t, inv.All.Children["databases"]) + assert.Equal(t, "ssh", inv.All.Children["databases"].Hosts["db1"].AnsibleConnection) +} + +// --- Facts --- + +func TestFacts_Struct(t *testing.T) { + facts := Facts{ + Hostname: "web1", + FQDN: "web1.example.com", + OS: "Debian", + Distribution: "ubuntu", + Version: "24.04", + Architecture: "x86_64", + Kernel: "6.8.0", + Memory: 16384, + CPUs: 4, + IPv4: "10.0.0.1", + } + + assert.Equal(t, "web1", facts.Hostname) + assert.Equal(t, "web1.example.com", facts.FQDN) + assert.Equal(t, "ubuntu", facts.Distribution) + assert.Equal(t, "x86_64", facts.Architecture) + assert.Equal(t, int64(16384), facts.Memory) + assert.Equal(t, 4, facts.CPUs) +} + +// --- TaskResult --- + +func TestTaskResult_Struct(t *testing.T) { + result := TaskResult{ + Changed: true, + Failed: false, + Skipped: false, + Msg: "task completed", + Stdout: "output", + Stderr: "", + RC: 0, + } + + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.Equal(t, "task completed", result.Msg) + assert.Equal(t, 0, result.RC) +} + +func TestTaskResult_WithLoopResults(t *testing.T) { + result := TaskResult{ + Changed: true, + Results: []TaskResult{ + {Changed: true, RC: 0}, + {Changed: false, RC: 0}, + {Changed: true, RC: 0}, + }, + } + + assert.Len(t, result.Results, 3) + assert.True(t, result.Results[0].Changed) + assert.False(t, result.Results[1].Changed) +} + +// --- KnownModules --- + +func TestKnownModules_ContainsExpected(t *testing.T) { + // Verify both FQCN and short forms are present + fqcnModules := []string{ + "ansible.builtin.shell", + "ansible.builtin.command", + "ansible.builtin.copy", + "ansible.builtin.file", + "ansible.builtin.apt", + "ansible.builtin.service", + "ansible.builtin.systemd", + "ansible.builtin.debug", + "ansible.builtin.set_fact", + } + for _, mod := range fqcnModules { + assert.Contains(t, KnownModules, mod, "expected FQCN module %s", mod) + } + + shortModules := []string{ + "shell", "command", "copy", "file", "apt", "service", + "systemd", "debug", "set_fact", "template", "user", "group", + } + for _, mod := range shortModules { + assert.Contains(t, KnownModules, mod, "expected short-form module %s", mod) + } +}