diff --git a/modules.go b/modules.go index 81ce800..2c1bda8 100644 --- a/modules.go +++ b/modules.go @@ -2506,6 +2506,7 @@ func (e *Executor) moduleWaitForConnection(ctx context.Context, client sshExecut timeout := getIntArg(args, "timeout", 300) delay := getIntArg(args, "delay", 0) sleep := getIntArg(args, "sleep", 1) + connectTimeout := getIntArg(args, "connect_timeout", 5) timeoutMsg := getStringArg(args, "msg", "wait_for_connection timed out") if delay > 0 { @@ -2519,7 +2520,14 @@ func (e *Executor) moduleWaitForConnection(ctx context.Context, client sshExecut } runCheck := func() (*TaskResult, bool) { - stdout, stderr, rc, err := client.Run(ctx, "true") + runCtx := ctx + if connectTimeout > 0 { + var cancel context.CancelFunc + runCtx, cancel = context.WithTimeout(ctx, time.Duration(connectTimeout)*time.Second) + defer cancel() + } + + stdout, stderr, rc, err := client.Run(runCtx, "true") if err == nil && rc == 0 { return &TaskResult{Changed: false}, true } diff --git a/modules_adv_test.go b/modules_adv_test.go index 5058e5c..4b52086 100644 --- a/modules_adv_test.go +++ b/modules_adv_test.go @@ -1266,6 +1266,17 @@ func TestModulesAdv_ModuleWaitFor_Good_AcceptsStringNumericArgs(t *testing.T) { // --- wait_for_connection module --- +type deadlineAwareMockClient struct { + *MockSSHClient +} + +func (c *deadlineAwareMockClient) Run(ctx context.Context, cmd string) (string, string, int, error) { + if _, ok := ctx.Deadline(); !ok { + return "", "", 1, context.DeadlineExceeded + } + return c.MockSSHClient.Run(ctx, cmd) +} + func TestModulesAdv_ModuleWaitForConnection_Good_ReturnsWhenHostIsReachable(t *testing.T) { e, mock := newTestExecutorWithMock("host1") mock.expectCommand(`^true$`, "", "", 0) @@ -1284,6 +1295,22 @@ func TestModulesAdv_ModuleWaitForConnection_Good_ReturnsWhenHostIsReachable(t *t assert.True(t, mock.hasExecuted(`^true$`)) } +func TestModulesAdv_ModuleWaitForConnection_Good_UsesConnectTimeout(t *testing.T) { + e := NewExecutor("/tmp") + client := &deadlineAwareMockClient{MockSSHClient: NewMockSSHClient()} + + result, err := e.moduleWaitForConnection(context.Background(), client, map[string]any{ + "timeout": 0, + "connect_timeout": 1, + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.False(t, result.Failed) + assert.False(t, result.Changed) + assert.True(t, client.hasExecuted(`^true$`)) +} + func TestModulesAdv_ModuleWaitForConnection_Bad_ImmediateFailure(t *testing.T) { e, mock := newTestExecutorWithMock("host1") mock.expectCommandError(`^true$`, assert.AnError)