diff --git a/executor.go b/executor.go index 7f9b0f7..ad36793 100644 --- a/executor.go +++ b/executor.go @@ -529,7 +529,7 @@ func (e *Executor) runTaskOnHost(ctx context.Context, host string, hosts []strin } if NormalizeModule(task.Module) == "ansible.builtin.meta" { - if err := e.handleMetaAction(ctx, hosts, play, result); err != nil { + if err := e.handleMetaAction(ctx, host, hosts, play, result); err != nil { return err } } @@ -1629,7 +1629,7 @@ func (e *Executor) runNotifiedHandlers(ctx context.Context, hosts []string, play // handleMetaAction applies module meta side effects after the task result has // been recorded and callbacks have fired. -func (e *Executor) handleMetaAction(ctx context.Context, hosts []string, play *Play, result *TaskResult) error { +func (e *Executor) handleMetaAction(ctx context.Context, host string, hosts []string, play *Play, result *TaskResult) error { if result == nil || result.Data == nil { return nil } @@ -1643,6 +1643,9 @@ func (e *Executor) handleMetaAction(ctx context.Context, hosts []string, play *P return nil case "end_play": return errEndPlay + case "reset_connection": + e.resetConnection(host) + return nil default: return nil } @@ -1658,6 +1661,24 @@ func (e *Executor) clearFacts(hosts []string) { } } +// resetConnection closes and removes the cached SSH client for a host. +func (e *Executor) resetConnection(host string) { + if host == "" { + return + } + + e.mu.Lock() + client, ok := e.clients[host] + if ok { + delete(e.clients, host) + } + e.mu.Unlock() + + if ok { + _ = client.Close() + } +} + // Close closes all SSH connections. // // Example: diff --git a/executor_extra_test.go b/executor_extra_test.go index 4283f3c..fa37ddc 100644 --- a/executor_extra_test.go +++ b/executor_extra_test.go @@ -324,13 +324,23 @@ func TestExecutorExtra_ModuleMeta_Good_ClearFacts(t *testing.T) { assert.Equal(t, "clear_facts", result.Data["action"]) } +func TestExecutorExtra_ModuleMeta_Good_ResetConnection(t *testing.T) { + e := NewExecutor("/tmp") + result, err := e.moduleMeta(map[string]any{"_raw_params": "reset_connection"}) + + require.NoError(t, err) + assert.False(t, result.Changed) + require.NotNil(t, result.Data) + assert.Equal(t, "reset_connection", result.Data["action"]) +} + func TestExecutorExtra_HandleMetaAction_Good_ClearFacts(t *testing.T) { e := NewExecutor("/tmp") e.facts["host1"] = &Facts{Hostname: "web01"} e.facts["host2"] = &Facts{Hostname: "web02"} result := &TaskResult{Data: map[string]any{"action": "clear_facts"}} - require.NoError(t, e.handleMetaAction(context.Background(), []string{"host1"}, nil, result)) + require.NoError(t, e.handleMetaAction(context.Background(), "host1", []string{"host1"}, nil, result)) _, ok := e.facts["host1"] assert.False(t, ok) @@ -338,6 +348,22 @@ func TestExecutorExtra_HandleMetaAction_Good_ClearFacts(t *testing.T) { assert.Equal(t, "web02", e.facts["host2"].Hostname) } +func TestExecutorExtra_HandleMetaAction_Good_ResetConnection(t *testing.T) { + e := NewExecutor("/tmp") + mock := NewMockSSHClient() + e.clients["host1"] = mock + e.clients["host2"] = NewMockSSHClient() + + result := &TaskResult{Data: map[string]any{"action": "reset_connection"}} + require.NoError(t, e.handleMetaAction(context.Background(), "host1", []string{"host1", "host2"}, nil, result)) + + _, ok := e.clients["host1"] + assert.False(t, ok) + _, ok = e.clients["host2"] + assert.True(t, ok) + assert.True(t, mock.closed) +} + // ============================================================ // Tests for handleLookup (0% coverage) // ============================================================ diff --git a/mock_ssh_test.go b/mock_ssh_test.go index 49cd08b..3261303 100644 --- a/mock_ssh_test.go +++ b/mock_ssh_test.go @@ -36,6 +36,9 @@ type MockSSHClient struct { becomeUser string becomePass string + // Lifecycle tracking + closed bool + // Execution log: every command that was executed executed []executedCommand @@ -266,6 +269,9 @@ func (m *MockSSHClient) SetBecome(become bool, user, password string) { // // _ = mock.Close() func (m *MockSSHClient) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + m.closed = true return nil }