From a2b7cfe2281fcd4eb50a90176b6cacf908c6bf51 Mon Sep 17 00:00:00 2001 From: Virgil Date: Fri, 3 Apr 2026 14:40:29 +0000 Subject: [PATCH] feat(ansible): add wait_for_connection module Co-Authored-By: Virgil --- mock_ssh_test.go | 3 ++ modules.go | 77 +++++++++++++++++++++++++++++++++++++++++++++ modules_adv_test.go | 38 ++++++++++++++++++++++ types.go | 2 ++ 4 files changed, 120 insertions(+) diff --git a/mock_ssh_test.go b/mock_ssh_test.go index a3dde88..49ba79b 100644 --- a/mock_ssh_test.go +++ b/mock_ssh_test.go @@ -481,6 +481,9 @@ func executeModuleWithMock(e *Executor, mock *MockSSHClient, host string, task * case "ansible.builtin.git": return moduleGitWithClient(e, mock, args) + case "ansible.builtin.wait_for_connection", "wait_for_connection": + return e.moduleWaitForConnection(context.Background(), mock, args) + // Archive case "ansible.builtin.unarchive": return moduleUnarchiveWithClient(e, mock, args) diff --git a/modules.go b/modules.go index 212e1d4..81ce800 100644 --- a/modules.go +++ b/modules.go @@ -156,6 +156,8 @@ func (e *Executor) executeModule(ctx context.Context, host string, client sshExe return e.modulePause(ctx, args) case "ansible.builtin.wait_for": return e.moduleWaitFor(ctx, client, args) + case "ansible.builtin.wait_for_connection": + return e.moduleWaitForConnection(ctx, client, args) case "ansible.builtin.git": return e.moduleGit(ctx, client, args) case "ansible.builtin.unarchive": @@ -2500,6 +2502,81 @@ func (e *Executor) moduleWaitFor(ctx context.Context, client sshExecutorClient, return &TaskResult{Changed: false}, nil } +func (e *Executor) moduleWaitForConnection(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { + timeout := getIntArg(args, "timeout", 300) + delay := getIntArg(args, "delay", 0) + sleep := getIntArg(args, "sleep", 1) + timeoutMsg := getStringArg(args, "msg", "wait_for_connection timed out") + + if delay > 0 { + timer := time.NewTimer(time.Duration(delay) * time.Second) + select { + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + case <-timer.C: + } + } + + runCheck := func() (*TaskResult, bool) { + stdout, stderr, rc, err := client.Run(ctx, "true") + if err == nil && rc == 0 { + return &TaskResult{Changed: false}, true + } + if timeout <= 0 { + if err != nil { + return &TaskResult{Failed: true, Msg: err.Error(), Stdout: stdout, Stderr: stderr, RC: rc}, true + } + return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, Stderr: stderr, RC: rc}, true + } + return &TaskResult{Stdout: stdout, Stderr: stderr, RC: rc}, false + } + + if timeout <= 0 { + result, done := runCheck() + if done { + return result, nil + } + return &TaskResult{Failed: true, Msg: timeoutMsg}, nil + } + + deadline := time.NewTimer(time.Duration(timeout) * time.Second) + defer deadline.Stop() + + sleepDuration := time.Duration(sleep) * time.Second + if sleepDuration < 0 { + sleepDuration = 0 + } + + for { + result, done := runCheck() + if done { + return result, nil + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-deadline.C: + return &TaskResult{Failed: true, Msg: timeoutMsg, Stdout: result.Stdout, Stderr: result.Stderr, RC: result.RC}, nil + default: + } + + if sleepDuration > 0 { + timer := time.NewTimer(sleepDuration) + select { + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + case <-deadline.C: + timer.Stop() + return &TaskResult{Failed: true, Msg: timeoutMsg, Stdout: result.Stdout, Stderr: result.Stderr, RC: result.RC}, nil + case <-timer.C: + } + } + } +} + func (e *Executor) moduleGit(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { repo := getStringArg(args, "repo", "") dest := getStringArg(args, "dest", "") diff --git a/modules_adv_test.go b/modules_adv_test.go index b70350c..5058e5c 100644 --- a/modules_adv_test.go +++ b/modules_adv_test.go @@ -1264,6 +1264,44 @@ func TestModulesAdv_ModuleWaitFor_Good_AcceptsStringNumericArgs(t *testing.T) { assert.True(t, mock.hasExecuted(`until ! nc -z 127.0.0.1 8080`)) } +// --- wait_for_connection module --- + +func TestModulesAdv_ModuleWaitForConnection_Good_ReturnsWhenHostIsReachable(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`^true$`, "", "", 0) + + result, err := executeModuleWithMock(e, mock, "host1", &Task{ + Module: "wait_for_connection", + Args: map[string]any{ + "timeout": 0, + }, + }) + + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.Failed) + assert.False(t, result.Changed) + assert.True(t, mock.hasExecuted(`^true$`)) +} + +func TestModulesAdv_ModuleWaitForConnection_Bad_ImmediateFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommandError(`^true$`, assert.AnError) + + result, err := executeModuleWithMock(e, mock, "host1", &Task{ + Module: "ansible.builtin.wait_for_connection", + Args: map[string]any{ + "timeout": 0, + }, + }) + + require.NoError(t, err) + assert.NotNil(t, result) + assert.True(t, result.Failed) + assert.Contains(t, result.Msg, assert.AnError.Error()) + assert.True(t, mock.hasExecuted(`^true$`)) +} + // --- include_vars module --- func TestModulesAdv_ModuleIncludeVars_Good_LoadSingleFile(t *testing.T) { diff --git a/types.go b/types.go index 12d0b4f..25d578e 100644 --- a/types.go +++ b/types.go @@ -427,6 +427,7 @@ var KnownModules = []string{ "ansible.builtin.ping", "ansible.builtin.pause", "ansible.builtin.wait_for", + "ansible.builtin.wait_for_connection", "ansible.builtin.set_fact", "ansible.builtin.include_vars", "ansible.builtin.add_host", @@ -481,6 +482,7 @@ var KnownModules = []string{ "ping", "pause", "wait_for", + "wait_for_connection", "set_fact", "include_vars", "add_host",