From 22e689bbfa02a4b2029bfdb76652f9a0223bf680 Mon Sep 17 00:00:00 2001 From: Virgil Date: Fri, 3 Apr 2026 11:49:21 +0000 Subject: [PATCH] fix(ansible): fail fast on reboot initiation errors --- modules.go | 19 +++++++++++++++++-- modules_adv_test.go | 15 +++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/modules.go b/modules.go index 26f08e0..1a1d263 100644 --- a/modules.go +++ b/modules.go @@ -3253,12 +3253,27 @@ func (e *Executor) moduleReboot(ctx context.Context, client sshExecutorClient, a testCommand := getStringArg(args, "test_command", "whoami") msg := getStringArg(args, "msg", "Reboot initiated by Ansible") + runReboot := func(cmd string) (*TaskResult, error) { + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + msg := stderr + if msg == "" && err != nil { + msg = err.Error() + } + return &TaskResult{Failed: true, Msg: msg, Stdout: stdout, Stderr: stderr, RC: rc}, nil + } + return nil, nil + } if preRebootDelay > 0 { cmd := sprintf("sleep %d && shutdown -r now '%s' &", preRebootDelay, msg) - _, _, _, _ = client.Run(ctx, cmd) + if result, err := runReboot(cmd); err != nil || result != nil { + return result, err + } } else { - _, _, _, _ = client.Run(ctx, sprintf("shutdown -r now '%s' &", msg)) + if result, err := runReboot(sprintf("shutdown -r now '%s' &", msg)); err != nil || result != nil { + return result, err + } } if postRebootDelay > 0 { diff --git a/modules_adv_test.go b/modules_adv_test.go index 2cc0008..b622401 100644 --- a/modules_adv_test.go +++ b/modules_adv_test.go @@ -1999,3 +1999,18 @@ func TestModulesAdv_ModuleReboot_Bad_TimesOutWaitingForTestCommand(t *testing.T) assert.Equal(t, "host unreachable", result.Stderr) assert.Equal(t, 1, result.RC) } + +func TestModulesAdv_ModuleReboot_Bad_ReportsInitialShutdownFailure(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`shutdown -r now 'Reboot initiated by Ansible' &`, "", "permission denied", 1) + + result, err := e.moduleReboot(context.Background(), mock, map[string]any{}) + + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, result.Failed) + assert.Equal(t, "permission denied", result.Msg) + assert.Equal(t, 1, result.RC) + assert.Equal(t, 1, mock.commandCount()) + assert.False(t, mock.hasExecuted(`whoami`)) +}