diff --git a/modules.go b/modules.go index 500e816..1517339 100644 --- a/modules.go +++ b/modules.go @@ -8,6 +8,7 @@ import ( "path/filepath" "sort" "strconv" + "time" coreio "dappco.re/go/core/io" coreerr "dappco.re/go/core/log" @@ -1129,53 +1130,38 @@ func (e *Executor) moduleGroupBy(host string, args map[string]any) (*TaskResult, } func (e *Executor) modulePause(ctx context.Context, args map[string]any) (*TaskResult, error) { - seconds := 0 + duration := time.Duration(0) if s, ok := args["seconds"].(int); ok { - seconds = s + duration += time.Duration(s) * time.Second } if s, ok := args["seconds"].(string); ok { - seconds, _ = strconv.Atoi(s) + if seconds, err := strconv.Atoi(s); err == nil { + duration += time.Duration(seconds) * time.Second + } + } + if m, ok := args["minutes"].(int); ok { + duration += time.Duration(m) * time.Minute + } + if s, ok := args["minutes"].(string); ok { + if minutes, err := strconv.Atoi(s); err == nil { + duration += time.Duration(minutes) * time.Minute + } } - if seconds > 0 { + if duration > 0 { + timer := time.NewTimer(duration) + defer timer.Stop() + select { case <-ctx.Done(): return nil, ctx.Err() - case <-ctxSleep(ctx, seconds): + case <-timer.C: } } return &TaskResult{Changed: false}, nil } -func ctxSleep(ctx context.Context, seconds int) <-chan struct{} { - ch := make(chan struct{}) - go func() { - select { - case <-ctx.Done(): - case <-sleepChan(seconds): - } - close(ch) - }() - return ch -} - -func sleepChan(seconds int) <-chan struct{} { - ch := make(chan struct{}) - go func() { - for range seconds { - select { - case <-ch: - return - default: - // Sleep 1 second at a time - } - } - close(ch) - }() - return ch -} - func normalizeStringList(value any) []string { switch v := value.(type) { case nil: diff --git a/modules_adv_test.go b/modules_adv_test.go index 0d3bab5..6391178 100644 --- a/modules_adv_test.go +++ b/modules_adv_test.go @@ -1,7 +1,9 @@ package ansible import ( + "context" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -702,6 +704,23 @@ func TestModulesAdv_ModuleUnarchive_Bad_LocalFileNotFound(t *testing.T) { assert.Contains(t, err.Error(), "read src") } +// --- pause module --- + +func TestModulesAdv_ModulePause_Good_WaitsForSeconds(t *testing.T) { + e := NewExecutor("/tmp") + + start := time.Now() + result, err := e.modulePause(context.Background(), map[string]any{ + "seconds": 1, + }) + elapsed := time.Since(start) + + require.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.Changed) + assert.GreaterOrEqual(t, elapsed, 900*time.Millisecond) +} + // --- include_vars module --- func TestModulesAdv_ModuleIncludeVars_Good_LoadSingleFile(t *testing.T) {