From eb3b9cca07683e09df248ee9f579809fe57fb338 Mon Sep 17 00:00:00 2001 From: Virgil Date: Wed, 1 Apr 2026 20:00:45 +0000 Subject: [PATCH] feat(ansible): support delegate_to task execution Co-authored-by: Virgil --- executor.go | 37 +++++++++++++++++----- executor_test.go | 24 +++++++++++++++ mock_ssh_test.go | 19 ++++++------ modules.go | 80 +++++++++++++++++++++++------------------------- ssh.go | 7 +++++ 5 files changed, 110 insertions(+), 57 deletions(-) diff --git a/executor.go b/executor.go index 875b6f2..bfcf77f 100644 --- a/executor.go +++ b/executor.go @@ -3,6 +3,8 @@ package ansible import ( "context" "errors" + "io" + "io/fs" "regexp" "slices" "strconv" @@ -17,6 +19,19 @@ import ( var errEndPlay = errors.New("end play") +// sshExecutorClient is the client contract used by the executor. +type sshExecutorClient interface { + Run(ctx context.Context, cmd string) (stdout, stderr string, exitCode int, err error) + RunScript(ctx context.Context, script string) (stdout, stderr string, exitCode int, err error) + Upload(ctx context.Context, local io.Reader, remote string, mode fs.FileMode) error + Download(ctx context.Context, remote string) ([]byte, error) + FileExists(ctx context.Context, path string) (bool, error) + Stat(ctx context.Context, path string) (map[string]any, error) + BecomeState() (become bool, user, password string) + SetBecome(become bool, user, password string) + Close() error +} + // Executor runs Ansible playbooks. // // Example: @@ -30,7 +45,7 @@ type Executor struct { results map[string]map[string]*TaskResult // host -> register_name -> result handlers map[string][]Task notified map[string]bool - clients map[string]*SSHClient + clients map[string]sshExecutorClient mu sync.RWMutex // Callbacks @@ -61,7 +76,7 @@ func NewExecutor(basePath string) *Executor { results: make(map[string]map[string]*TaskResult), handlers: make(map[string][]Task), notified: make(map[string]bool), - clients: make(map[string]*SSHClient), + clients: make(map[string]sshExecutorClient), } } @@ -472,9 +487,17 @@ func (e *Executor) runTaskOnHost(ctx context.Context, host string, hosts []strin } // Get SSH client - client, err := e.getClient(host, play) + executionHost := host + if task.Delegate != "" { + executionHost = e.templateString(task.Delegate, host, task) + if executionHost == "" { + executionHost = host + } + } + + client, err := e.getClient(executionHost, play) if err != nil { - return coreerr.E("Executor.runTaskOnHost", sprintf("get client for %s", host), err) + return coreerr.E("Executor.runTaskOnHost", sprintf("get client for %s", executionHost), err) } // Handle loops @@ -592,7 +615,7 @@ func shouldRetryTask(task *Task, host string, e *Executor, result *TaskResult) b } // runLoop handles task loops. -func (e *Executor) runLoop(ctx context.Context, host string, client *SSHClient, task *Task, play *Play) error { +func (e *Executor) runLoop(ctx context.Context, host string, client sshExecutorClient, task *Task, play *Play) error { items := e.resolveLoop(task.Loop, host) loopVar := "item" @@ -833,7 +856,7 @@ func (e *Executor) getHosts(pattern string) []string { } // getClient returns or creates an SSH client for a host. -func (e *Executor) getClient(host string, play *Play) (*SSHClient, error) { +func (e *Executor) getClient(host string, play *Play) (sshExecutorClient, error) { e.mu.Lock() defer e.mu.Unlock() @@ -1555,7 +1578,7 @@ func (e *Executor) Close() { for _, client := range e.clients { _ = client.Close() } - e.clients = make(map[string]*SSHClient) + e.clients = make(map[string]sshExecutorClient) } // TemplateFile processes a template file. diff --git a/executor_test.go b/executor_test.go index 449a916..1209384 100644 --- a/executor_test.go +++ b/executor_test.go @@ -207,6 +207,30 @@ func TestExecutor_RunTaskOnHosts_Good_RunOnceSharesRegisteredResult(t *testing.T assert.Equal(t, "hello", e.results["host2"]["debug_result"].Msg) } +func TestExecutor_RunTaskOnHost_Good_DelegateToUsesDelegatedClient(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + e.SetVar("delegate_host", "delegate1") + e.inventory.All.Hosts["delegate1"] = &Host{AnsibleHost: "127.0.0.2"} + e.clients["delegate1"] = mock + mock.expectCommand(`echo delegated`, "delegated", "", 0) + + task := &Task{ + Name: "Delegate command", + Module: "command", + Args: map[string]any{"cmd": "echo delegated"}, + Delegate: "{{ delegate_host }}", + Register: "delegated_result", + } + + err := e.runTaskOnHost(context.Background(), "host1", []string{"host1"}, task, &Play{}) + require.NoError(t, err) + + require.NotNil(t, e.results["host1"]["delegated_result"]) + assert.Equal(t, "delegated", e.results["host1"]["delegated_result"].Stdout) + assert.True(t, mock.hasExecuted(`echo delegated`)) + assert.Equal(t, 1, mock.commandCount()) +} + func TestExecutor_RunPlay_Good_SerialBatchesHosts(t *testing.T) { e := NewExecutor("/tmp") e.SetInventoryDirect(&Inventory{ diff --git a/mock_ssh_test.go b/mock_ssh_test.go index afcb955..49cd08b 100644 --- a/mock_ssh_test.go +++ b/mock_ssh_test.go @@ -269,6 +269,13 @@ func (m *MockSSHClient) Close() error { return nil } +// BecomeState returns the current privilege escalation settings. +func (m *MockSSHClient) BecomeState() (bool, string, string) { + m.mu.Lock() + defer m.mu.Unlock() + return m.become, m.becomeUser, m.becomePass +} + // --- Assertion helpers --- // executedCommands returns a copy of the execution log. @@ -367,19 +374,12 @@ func (m *MockSSHClient) reset() { // --- Test helper: create executor with mock client --- // newTestExecutorWithMock creates an Executor pre-wired with a MockSSHClient -// for the given host. The executor has a minimal inventory so that -// executeModule can be called directly. +// for the given host. The executor has a minimal inventory so that tasks can +// be executed through the normal host/client lookup path. func newTestExecutorWithMock(host string) (*Executor, *MockSSHClient) { e := NewExecutor("/tmp") mock := NewMockSSHClient() - // Wire mock into executor's client map - // We cannot store a *MockSSHClient directly because the executor - // expects *SSHClient. Instead, we provide a helper that calls - // modules the same way the executor does but with the mock. - // Since modules call methods on *SSHClient directly and the mock - // has identical method signatures, we use a shim approach. - // Set up minimal inventory so host resolution works e.SetInventoryDirect(&Inventory{ All: &InventoryGroup{ @@ -388,6 +388,7 @@ func newTestExecutorWithMock(host string) (*Executor, *MockSSHClient) { }, }, }) + e.clients[host] = mock return e, mock } diff --git a/modules.go b/modules.go index 10c1011..6a63494 100644 --- a/modules.go +++ b/modules.go @@ -20,15 +20,13 @@ type sshFactsRunner interface { } // executeModule dispatches to the appropriate module handler. -func (e *Executor) executeModule(ctx context.Context, host string, client *SSHClient, task *Task, play *Play) (*TaskResult, error) { +func (e *Executor) executeModule(ctx context.Context, host string, client sshExecutorClient, task *Task, play *Play) (*TaskResult, error) { module := NormalizeModule(task.Module) // Apply task-level become if task.Become != nil && *task.Become { // Save old state to restore - oldBecome := client.become - oldUser := client.becomeUser - oldPass := client.becomePass + oldBecome, oldUser, oldPass := client.BecomeState() client.SetBecome(true, task.BecomeUser, "") @@ -191,7 +189,7 @@ func (e *Executor) templateArgs(args map[string]any, host string, task *Task) ma // --- Command Modules --- -func (e *Executor) moduleShell(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleShell(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { cmd := getStringArg(args, "_raw_params", "") if cmd == "" { cmd = getStringArg(args, "cmd", "") @@ -219,7 +217,7 @@ func (e *Executor) moduleShell(ctx context.Context, client *SSHClient, args map[ }, nil } -func (e *Executor) moduleCommand(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleCommand(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { cmd := getStringArg(args, "_raw_params", "") if cmd == "" { cmd = getStringArg(args, "cmd", "") @@ -247,7 +245,7 @@ func (e *Executor) moduleCommand(ctx context.Context, client *SSHClient, args ma }, nil } -func (e *Executor) moduleRaw(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleRaw(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { cmd := getStringArg(args, "_raw_params", "") if cmd == "" { return nil, coreerr.E("Executor.moduleRaw", "no command specified", nil) @@ -266,7 +264,7 @@ func (e *Executor) moduleRaw(ctx context.Context, client *SSHClient, args map[st }, nil } -func (e *Executor) moduleScript(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleScript(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { script := getStringArg(args, "_raw_params", "") if script == "" { return nil, coreerr.E("Executor.moduleScript", "no script specified", nil) @@ -294,7 +292,7 @@ func (e *Executor) moduleScript(ctx context.Context, client *SSHClient, args map // --- File Modules --- -func (e *Executor) moduleCopy(ctx context.Context, client *SSHClient, args map[string]any, host string, task *Task) (*TaskResult, error) { +func (e *Executor) moduleCopy(ctx context.Context, client sshExecutorClient, args map[string]any, host string, task *Task) (*TaskResult, error) { dest := getStringArg(args, "dest", "") if dest == "" { return nil, coreerr.E("Executor.moduleCopy", "dest required", nil) @@ -337,7 +335,7 @@ func (e *Executor) moduleCopy(ctx context.Context, client *SSHClient, args map[s return &TaskResult{Changed: true, Msg: sprintf("copied to %s", dest)}, nil } -func (e *Executor) moduleTemplate(ctx context.Context, client *SSHClient, args map[string]any, host string, task *Task) (*TaskResult, error) { +func (e *Executor) moduleTemplate(ctx context.Context, client sshExecutorClient, args map[string]any, host string, task *Task) (*TaskResult, error) { src := getStringArg(args, "src", "") dest := getStringArg(args, "dest", "") if src == "" || dest == "" { @@ -365,7 +363,7 @@ func (e *Executor) moduleTemplate(ctx context.Context, client *SSHClient, args m return &TaskResult{Changed: true, Msg: sprintf("templated to %s", dest)}, nil } -func (e *Executor) moduleFile(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleFile(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { path := getStringArg(args, "path", "") if path == "" { path = getStringArg(args, "dest", "") @@ -433,7 +431,7 @@ func (e *Executor) moduleFile(ctx context.Context, client *SSHClient, args map[s return &TaskResult{Changed: true}, nil } -func (e *Executor) moduleLineinfile(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleLineinfile(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { path := getStringArg(args, "path", "") if path == "" { path = getStringArg(args, "dest", "") @@ -491,7 +489,7 @@ func (e *Executor) moduleLineinfile(ctx context.Context, client *SSHClient, args return &TaskResult{Changed: true}, nil } -func (e *Executor) moduleStat(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleStat(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { path := getStringArg(args, "path", "") if path == "" { return nil, coreerr.E("Executor.moduleStat", "path required", nil) @@ -508,7 +506,7 @@ func (e *Executor) moduleStat(ctx context.Context, client *SSHClient, args map[s }, nil } -func (e *Executor) moduleSlurp(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleSlurp(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { path := getStringArg(args, "path", "") if path == "" { path = getStringArg(args, "src", "") @@ -530,7 +528,7 @@ func (e *Executor) moduleSlurp(ctx context.Context, client *SSHClient, args map[ }, nil } -func (e *Executor) moduleFetch(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleFetch(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { src := getStringArg(args, "src", "") dest := getStringArg(args, "dest", "") if src == "" || dest == "" { @@ -554,7 +552,7 @@ func (e *Executor) moduleFetch(ctx context.Context, client *SSHClient, args map[ return &TaskResult{Changed: true, Msg: sprintf("fetched %s to %s", src, dest)}, nil } -func (e *Executor) moduleGetURL(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleGetURL(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { url := getStringArg(args, "url", "") dest := getStringArg(args, "dest", "") if url == "" || dest == "" { @@ -578,7 +576,7 @@ func (e *Executor) moduleGetURL(ctx context.Context, client *SSHClient, args map // --- Package Modules --- -func (e *Executor) moduleApt(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleApt(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { name := getStringArg(args, "name", "") state := getStringArg(args, "state", "present") updateCache := getBoolArg(args, "update_cache", false) @@ -612,7 +610,7 @@ func (e *Executor) moduleApt(ctx context.Context, client *SSHClient, args map[st return &TaskResult{Changed: true}, nil } -func (e *Executor) moduleAptKey(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleAptKey(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { url := getStringArg(args, "url", "") keyring := getStringArg(args, "keyring", "") state := getStringArg(args, "state", "present") @@ -643,7 +641,7 @@ func (e *Executor) moduleAptKey(ctx context.Context, client *SSHClient, args map return &TaskResult{Changed: true}, nil } -func (e *Executor) moduleAptRepository(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleAptRepository(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { repo := getStringArg(args, "repo", "") filename := getStringArg(args, "filename", "") state := getStringArg(args, "state", "present") @@ -680,7 +678,7 @@ func (e *Executor) moduleAptRepository(ctx context.Context, client *SSHClient, a return &TaskResult{Changed: true}, nil } -func (e *Executor) modulePackage(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) modulePackage(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { // Detect package manager and delegate stdout, _, _, _ := client.Run(ctx, "which apt-get yum dnf 2>/dev/null | head -1") stdout = corexTrimSpace(stdout) @@ -695,15 +693,15 @@ func (e *Executor) modulePackage(ctx context.Context, client *SSHClient, args ma } } -func (e *Executor) moduleYum(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleYum(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { return e.moduleRPM(ctx, client, args, "yum") } -func (e *Executor) moduleDnf(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleDnf(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { return e.moduleRPM(ctx, client, args, "dnf") } -func (e *Executor) moduleRPM(ctx context.Context, client *SSHClient, args map[string]any, manager string) (*TaskResult, error) { +func (e *Executor) moduleRPM(ctx context.Context, client sshExecutorClient, args map[string]any, manager string) (*TaskResult, error) { name := getStringArg(args, "name", "") state := getStringArg(args, "state", "present") updateCache := getBoolArg(args, "update_cache", false) @@ -744,7 +742,7 @@ func (e *Executor) moduleRPM(ctx context.Context, client *SSHClient, args map[st return &TaskResult{Changed: true}, nil } -func (e *Executor) modulePip(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) modulePip(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { name := getStringArg(args, "name", "") state := getStringArg(args, "state", "present") executable := getStringArg(args, "executable", "pip3") @@ -769,7 +767,7 @@ func (e *Executor) modulePip(ctx context.Context, client *SSHClient, args map[st // --- Service Modules --- -func (e *Executor) moduleService(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleService(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { name := getStringArg(args, "name", "") state := getStringArg(args, "state", "") enabled := args["enabled"] @@ -811,7 +809,7 @@ func (e *Executor) moduleService(ctx context.Context, client *SSHClient, args ma return &TaskResult{Changed: len(cmds) > 0}, nil } -func (e *Executor) moduleSystemd(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleSystemd(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { // systemd is similar to service if getBoolArg(args, "daemon_reload", false) { _, _, _, _ = client.Run(ctx, "systemctl daemon-reload") @@ -822,7 +820,7 @@ func (e *Executor) moduleSystemd(ctx context.Context, client *SSHClient, args ma // --- User/Group Modules --- -func (e *Executor) moduleUser(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleUser(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { name := getStringArg(args, "name", "") state := getStringArg(args, "state", "present") @@ -879,7 +877,7 @@ func (e *Executor) moduleUser(ctx context.Context, client *SSHClient, args map[s return &TaskResult{Changed: true}, nil } -func (e *Executor) moduleGroup(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleGroup(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { name := getStringArg(args, "name", "") state := getStringArg(args, "state", "present") @@ -914,7 +912,7 @@ func (e *Executor) moduleGroup(ctx context.Context, client *SSHClient, args map[ // --- HTTP Module --- -func (e *Executor) moduleURI(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleURI(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { url := getStringArg(args, "url", "") method := getStringArg(args, "method", "GET") @@ -1297,7 +1295,7 @@ func findInventoryHost(group *InventoryGroup, name string) *Host { return nil } -func (e *Executor) moduleWaitFor(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleWaitFor(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { port := 0 if p, ok := args["port"].(int); ok { port = p @@ -1321,7 +1319,7 @@ func (e *Executor) moduleWaitFor(ctx context.Context, client *SSHClient, args ma return &TaskResult{Changed: false}, nil } -func (e *Executor) moduleGit(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleGit(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { repo := getStringArg(args, "repo", "") dest := getStringArg(args, "dest", "") version := getStringArg(args, "version", "HEAD") @@ -1350,7 +1348,7 @@ func (e *Executor) moduleGit(ctx context.Context, client *SSHClient, args map[st return &TaskResult{Changed: true}, nil } -func (e *Executor) moduleUnarchive(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleUnarchive(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { src := getStringArg(args, "src", "") dest := getStringArg(args, "dest", "") remote := getBoolArg(args, "remote_src", false) @@ -1401,7 +1399,7 @@ func (e *Executor) moduleUnarchive(ctx context.Context, client *SSHClient, args return &TaskResult{Changed: true}, nil } -func (e *Executor) moduleArchive(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleArchive(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { dest := getStringArg(args, "dest", "") format := lower(getStringArg(args, "format", "")) paths := archivePaths(args) @@ -1517,7 +1515,7 @@ func getBoolArg(args map[string]any, key string, def bool) bool { // --- Additional Modules --- -func (e *Executor) moduleHostname(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleHostname(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { name := getStringArg(args, "name", "") if name == "" { return nil, coreerr.E("Executor.moduleHostname", "name required", nil) @@ -1536,7 +1534,7 @@ func (e *Executor) moduleHostname(ctx context.Context, client *SSHClient, args m return &TaskResult{Changed: true}, nil } -func (e *Executor) moduleSysctl(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleSysctl(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { name := getStringArg(args, "name", "") value := getStringArg(args, "value", "") state := getStringArg(args, "state", "present") @@ -1569,7 +1567,7 @@ func (e *Executor) moduleSysctl(ctx context.Context, client *SSHClient, args map return &TaskResult{Changed: true}, nil } -func (e *Executor) moduleCron(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleCron(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { name := getStringArg(args, "name", "") job := getStringArg(args, "job", "") state := getStringArg(args, "state", "present") @@ -1606,7 +1604,7 @@ func (e *Executor) moduleCron(ctx context.Context, client *SSHClient, args map[s return &TaskResult{Changed: true}, nil } -func (e *Executor) moduleBlockinfile(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleBlockinfile(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { path := getStringArg(args, "path", "") if path == "" { path = getStringArg(args, "dest", "") @@ -1897,7 +1895,7 @@ func osFamilyFromReleaseID(id string) string { } } -func (e *Executor) moduleReboot(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleReboot(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { preRebootDelay := 0 if d, ok := args["pre_reboot_delay"].(int); ok { preRebootDelay = d @@ -1915,7 +1913,7 @@ func (e *Executor) moduleReboot(ctx context.Context, client *SSHClient, args map return &TaskResult{Changed: true, Msg: "Reboot initiated"}, nil } -func (e *Executor) moduleUFW(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleUFW(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { rule := getStringArg(args, "rule", "") port := getStringArg(args, "port", "") proto := getStringArg(args, "proto", "tcp") @@ -1966,7 +1964,7 @@ func (e *Executor) moduleUFW(ctx context.Context, client *SSHClient, args map[st return &TaskResult{Changed: true}, nil } -func (e *Executor) moduleAuthorizedKey(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleAuthorizedKey(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { user := getStringArg(args, "user", "") key := getStringArg(args, "key", "") state := getStringArg(args, "state", "present") @@ -2017,7 +2015,7 @@ func (e *Executor) moduleAuthorizedKey(ctx context.Context, client *SSHClient, a return &TaskResult{Changed: true}, nil } -func (e *Executor) moduleDockerCompose(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { +func (e *Executor) moduleDockerCompose(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { projectSrc := getStringArg(args, "project_src", "") state := getStringArg(args, "state", "present") diff --git a/ssh.go b/ssh.go index 26ef65a..3813093 100644 --- a/ssh.go +++ b/ssh.go @@ -212,6 +212,13 @@ func (c *SSHClient) Close() error { return nil } +// BecomeState returns the current privilege escalation settings. +func (c *SSHClient) BecomeState() (bool, string, string) { + c.mu.Lock() + defer c.mu.Unlock() + return c.become, c.becomeUser, c.becomePass +} + // Run executes a command on the remote host. // // Example: