diff --git a/executor.go b/executor.go index 913d747..7a20b7c 100644 --- a/executor.go +++ b/executor.go @@ -178,12 +178,8 @@ func (e *Executor) runPlay(ctx context.Context, play *Play) error { } // Run notified handlers - for _, handler := range play.Handlers { - if e.notified[handler.Name] { - if err := e.runTaskOnHosts(ctx, hosts, &handler, play); err != nil { - return err - } - } + if err := e.runNotifiedHandlers(ctx, hosts, play); err != nil { + return err } return nil @@ -262,7 +258,7 @@ func (e *Executor) runTaskOnHosts(ctx context.Context, hosts []string, task *Tas } for _, host := range hosts { - if err := e.runTaskOnHost(ctx, host, task, play); err != nil { + if err := e.runTaskOnHost(ctx, host, hosts, task, play); err != nil { if !task.IgnoreErrors { return err } @@ -312,7 +308,7 @@ func (e *Executor) copyRegisteredResultToHosts(hosts []string, sourceHost, regis } // runTaskOnHost runs a task on a single host. -func (e *Executor) runTaskOnHost(ctx context.Context, host string, task *Task, play *Play) error { +func (e *Executor) runTaskOnHost(ctx context.Context, host string, hosts []string, task *Task, play *Play) error { start := time.Now() if e.OnTaskStart != nil { @@ -384,6 +380,12 @@ func (e *Executor) runTaskOnHost(ctx context.Context, host string, task *Task, p e.OnTaskEnd(host, task, result) } + if NormalizeModule(task.Module) == "ansible.builtin.meta" { + if err := e.handleMetaAction(ctx, hosts, play, result); err != nil { + return err + } + } + if result.Failed && !task.IgnoreErrors { return coreerr.E("Executor.runTaskOnHost", "task failed: "+result.Msg, nil) } @@ -1368,6 +1370,52 @@ func (e *Executor) handleNotify(notify any) { } } +// runNotifiedHandlers executes any handlers that have been notified and then +// clears the notification state for those handlers. +func (e *Executor) runNotifiedHandlers(ctx context.Context, hosts []string, play *Play) error { + if play == nil || len(play.Handlers) == 0 { + return nil + } + + pending := make(map[string]bool) + for name, notified := range e.notified { + if notified { + pending[name] = true + e.notified[name] = false + } + } + + if len(pending) == 0 { + return nil + } + + for _, handler := range play.Handlers { + if pending[handler.Name] { + if err := e.runTaskOnHosts(ctx, hosts, &handler, play); err != nil { + return err + } + } + } + + return nil +} + +// handleMetaAction applies module meta side effects after the task result has +// been recorded and callbacks have fired. +func (e *Executor) handleMetaAction(ctx context.Context, hosts []string, play *Play, result *TaskResult) error { + if result == nil || result.Data == nil { + return nil + } + + action, _ := result.Data["action"].(string) + switch action { + case "flush_handlers": + return e.runNotifiedHandlers(ctx, hosts, play) + default: + return nil + } +} + // Close closes all SSH connections. // // Example: diff --git a/executor_extra_test.go b/executor_extra_test.go index 22aadad..905db7e 100644 --- a/executor_extra_test.go +++ b/executor_extra_test.go @@ -308,6 +308,8 @@ func TestExecutorExtra_ModuleMeta_Good(t *testing.T) { require.NoError(t, err) assert.False(t, result.Changed) + require.NotNil(t, result.Data) + assert.Equal(t, "flush_handlers", result.Data["action"]) } // ============================================================ diff --git a/executor_test.go b/executor_test.go index 8bdb0d2..b575346 100644 --- a/executor_test.go +++ b/executor_test.go @@ -271,7 +271,7 @@ func TestExecutor_RunTaskOnHost_Good_CheckModeSkipsMutatingTask(t *testing.T) { ended = result } - err := e.runTaskOnHost(context.Background(), "host1", task, &Play{}) + err := e.runTaskOnHost(context.Background(), "host1", []string{"host1"}, task, &Play{}) require.NoError(t, err) require.NotNil(t, ended) @@ -289,6 +289,54 @@ func TestExecutor_NormalizeConditions_Good_String(t *testing.T) { assert.Equal(t, []string{"my_var is defined"}, result) } +// --- meta flush handlers --- + +func TestExecutor_RunTaskOnHosts_Good_MetaFlushesHandlers(t *testing.T) { + e := NewExecutor("/tmp") + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "host1": {}, + }, + }, + }) + e.clients["host1"] = &SSHClient{} + + var executed []string + e.OnTaskEnd = func(_ string, task *Task, _ *TaskResult) { + executed = append(executed, task.Name) + } + + play := &Play{ + Handlers: []Task{ + { + Name: "restart app", + Module: "debug", + Args: map[string]any{"msg": "handler"}, + }, + }, + } + + notifyTask := &Task{ + Name: "change config", + Module: "set_fact", + Args: map[string]any{"restart_required": true}, + Notify: "restart app", + } + require.NoError(t, e.runTaskOnHosts(context.Background(), []string{"host1"}, notifyTask, play)) + assert.True(t, e.notified["restart app"]) + + metaTask := &Task{ + Name: "flush handlers", + Module: "meta", + Args: map[string]any{"_raw_params": "flush_handlers"}, + } + require.NoError(t, e.runTaskOnHosts(context.Background(), []string{"host1"}, metaTask, play)) + + assert.False(t, e.notified["restart app"]) + assert.Equal(t, []string{"change config", "flush handlers", "restart app"}, executed) +} + func TestExecutor_NormalizeConditions_Good_StringSlice(t *testing.T) { result := normalizeConditions([]string{"cond1", "cond2"}) assert.Equal(t, []string{"cond1", "cond2"}, result) diff --git a/modules.go b/modules.go index 348e18d..10c1011 100644 --- a/modules.go +++ b/modules.go @@ -1761,8 +1761,19 @@ func mergeVars(dst, src map[string]any, mergeMaps bool) { func (e *Executor) moduleMeta(args map[string]any) (*TaskResult, error) { // meta module controls play execution - // Most actions are no-ops for us - return &TaskResult{Changed: false}, nil + // Most actions are no-ops for us, but we preserve the requested action so + // the executor can apply side effects such as handler flushing. + action := getStringArg(args, "_raw_params", "") + if action == "" { + action = getStringArg(args, "free_form", "") + } + + result := &TaskResult{Changed: false} + if action != "" { + result.Data = map[string]any{"action": action} + } + + return result, nil } func (e *Executor) moduleSetup(ctx context.Context, host string, client sshFactsRunner) (*TaskResult, error) {