From 9ba6cdc5a437b6adad240c8be9df6018e3ee33ae Mon Sep 17 00:00:00 2001 From: Virgil Date: Wed, 1 Apr 2026 21:15:11 +0000 Subject: [PATCH] feat(ansible): refresh inventory on meta action --- executor.go | 31 +++++++++++++++++++++++++++++++ executor_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/executor.go b/executor.go index 3385782..79580f4 100644 --- a/executor.go +++ b/executor.go @@ -57,6 +57,7 @@ func (c *environmentSSHClient) RunScript(ctx context.Context, script string) (st type Executor struct { parser *Parser inventory *Inventory + inventoryPath string vars map[string]any facts map[string]*Facts results map[string]map[string]*TaskResult // host -> register_name -> result @@ -110,7 +111,10 @@ func (e *Executor) SetInventory(path string) error { if err != nil { return err } + e.mu.Lock() + e.inventoryPath = path e.inventory = inv + e.mu.Unlock() return nil } @@ -120,7 +124,10 @@ func (e *Executor) SetInventory(path string) error { // // exec.SetInventoryDirect(&Inventory{All: &InventoryGroup{}}) func (e *Executor) SetInventoryDirect(inv *Inventory) { + e.mu.Lock() + e.inventoryPath = "" e.inventory = inv + e.mu.Unlock() } // SetVar sets a variable. @@ -1880,6 +1887,8 @@ func (e *Executor) handleMetaAction(ctx context.Context, host string, hosts []st case "clear_host_errors": e.clearHostErrors() return nil + case "refresh_inventory": + return e.refreshInventory() case "end_play": return errEndPlay case "end_host": @@ -1912,6 +1921,28 @@ func (e *Executor) clearHostErrors() { e.batchFailedHosts = make(map[string]bool) } +// refreshInventory reloads inventory from the last configured source. +func (e *Executor) refreshInventory() error { + e.mu.RLock() + path := e.inventoryPath + e.mu.RUnlock() + + if path == "" { + return nil + } + + inv, err := e.parser.ParseInventory(path) + if err != nil { + return coreerr.E("Executor.refreshInventory", "reload inventory", err) + } + + e.mu.Lock() + e.inventory = inv + e.clients = make(map[string]sshExecutorClient) + e.mu.Unlock() + return nil +} + // markHostEnded records that a host should be skipped for the rest of the play. func (e *Executor) markHostEnded(host string) { if host == "" { diff --git a/executor_test.go b/executor_test.go index 09241c1..7523464 100644 --- a/executor_test.go +++ b/executor_test.go @@ -2,6 +2,8 @@ package ansible import ( "context" + "os" + "path/filepath" "testing" "time" @@ -703,6 +705,33 @@ func TestExecutor_HandleMetaAction_Good_ClearHostErrors(t *testing.T) { assert.Empty(t, e.batchFailedHosts) } +func TestExecutor_HandleMetaAction_Good_RefreshInventory(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "inventory.yml") + + initial := []byte("all:\n hosts:\n web1:\n ansible_host: 10.0.0.1\n") + updated := []byte("all:\n hosts:\n web1:\n ansible_host: 10.0.0.2\n") + + require.NoError(t, os.WriteFile(path, initial, 0644)) + + e := NewExecutor("/tmp") + require.NoError(t, e.SetInventory(path)) + e.clients["web1"] = &SSHClient{} + + require.NoError(t, os.WriteFile(path, updated, 0644)) + + result := &TaskResult{ + Data: map[string]any{"action": "refresh_inventory"}, + } + + require.NoError(t, e.handleMetaAction(context.Background(), "web1", []string{"web1"}, &Play{}, result)) + require.NotNil(t, e.inventory) + require.NotNil(t, e.inventory.All) + require.Contains(t, e.inventory.All.Hosts, "web1") + assert.Equal(t, "10.0.0.2", e.inventory.All.Hosts["web1"].AnsibleHost) + assert.Empty(t, e.clients) +} + func TestExecutor_RunPlay_Good_MetaEndPlayStopsRemainingTasks(t *testing.T) { e := NewExecutor("/tmp") e.SetInventoryDirect(&Inventory{