diff --git a/daemon.go b/daemon.go index 5adbb26..6199fc6 100644 --- a/daemon.go +++ b/daemon.go @@ -144,6 +144,10 @@ func (d *Daemon) Start() error { // // if err := daemon.Run(ctx); err != nil { return err } func (d *Daemon) Run(ctx context.Context) error { + if ctx == nil { + return coreerr.E("Daemon.Run", "daemon context is required", ErrDaemonContextRequired) + } + d.mu.Lock() if !d.running { d.mu.Unlock() @@ -243,3 +247,6 @@ func (d *Daemon) HealthAddr() string { } return "" } + +// ErrDaemonContextRequired is returned when Run is called without a context. +var ErrDaemonContextRequired = coreerr.E("", "daemon context is required", nil) diff --git a/daemon_test.go b/daemon_test.go index 2c12333..57c2cc6 100644 --- a/daemon_test.go +++ b/daemon_test.go @@ -221,6 +221,14 @@ func TestDaemon_RunWithoutStartFails(t *testing.T) { assert.Contains(t, err.Error(), "not started") } +func TestDaemon_RunNilContextFails(t *testing.T) { + d := NewDaemon(DaemonOptions{}) + + err := d.Run(nil) + require.Error(t, err) + assert.ErrorIs(t, err, ErrDaemonContextRequired) +} + func TestDaemon_SetReady(t *testing.T) { d := NewDaemon(DaemonOptions{ HealthAddr: "127.0.0.1:0", diff --git a/runner.go b/runner.go index 710abc9..e7b045a 100644 --- a/runner.go +++ b/runner.go @@ -20,6 +20,9 @@ var ErrRunnerNoService = coreerr.E("", "runner service is nil", nil) // ErrRunnerInvalidSpecName is returned when a RunSpec name is empty or duplicated. var ErrRunnerInvalidSpecName = coreerr.E("", "runner spec names must be non-empty and unique", nil) +// ErrRunnerContextRequired is returned when a runner method is called without a context. +var ErrRunnerContextRequired = coreerr.E("", "runner context is required", nil) + // NewRunner creates a runner for the given service. // // Example: @@ -98,6 +101,9 @@ func (r *Runner) RunAll(ctx context.Context, specs []RunSpec) (*RunAllResult, er if err := r.ensureService(); err != nil { return nil, err } + if err := ensureRunnerContext(ctx); err != nil { + return nil, err + } if err := validateSpecs(specs); err != nil { return nil, err } @@ -288,6 +294,9 @@ func (r *Runner) RunSequential(ctx context.Context, specs []RunSpec) (*RunAllRes if err := r.ensureService(); err != nil { return nil, err } + if err := ensureRunnerContext(ctx); err != nil { + return nil, err + } if err := validateSpecs(specs); err != nil { return nil, err } @@ -339,6 +348,9 @@ func (r *Runner) RunParallel(ctx context.Context, specs []RunSpec) (*RunAllResul if err := r.ensureService(); err != nil { return nil, err } + if err := ensureRunnerContext(ctx); err != nil { + return nil, err + } if err := validateSpecs(specs); err != nil { return nil, err } @@ -391,6 +403,13 @@ func validateSpecs(specs []RunSpec) error { return nil } +func ensureRunnerContext(ctx context.Context) error { + if ctx == nil { + return coreerr.E("Runner.ensureRunnerContext", "runner context is required", ErrRunnerContextRequired) + } + return nil +} + func skippedRunResult(op string, spec RunSpec, err error) RunResult { result := RunResult{ Name: spec.Name, diff --git a/runner_test.go b/runner_test.go index 43705a8..94dbf50 100644 --- a/runner_test.go +++ b/runner_test.go @@ -294,6 +294,22 @@ func TestRunner_NilService(t *testing.T) { assert.ErrorIs(t, err, ErrRunnerNoService) } +func TestRunner_NilContext(t *testing.T) { + runner := newTestRunner(t) + + _, err := runner.RunAll(nil, nil) + require.Error(t, err) + assert.ErrorIs(t, err, ErrRunnerContextRequired) + + _, err = runner.RunSequential(nil, nil) + require.Error(t, err) + assert.ErrorIs(t, err, ErrRunnerContextRequired) + + _, err = runner.RunParallel(nil, nil) + require.Error(t, err) + assert.ErrorIs(t, err, ErrRunnerContextRequired) +} + func TestRunner_InvalidSpecNames(t *testing.T) { runner := newTestRunner(t)