diff --git a/pkg/core/async_test.go b/pkg/core/async_test.go new file mode 100644 index 0000000..f29ff9e --- /dev/null +++ b/pkg/core/async_test.go @@ -0,0 +1,139 @@ +package core + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCore_PerformAsync_Good(t *testing.T) { + c, _ := New() + + var completed atomic.Bool + var resultReceived any + + c.RegisterAction(func(c *Core, msg Message) error { + if tc, ok := msg.(ActionTaskCompleted); ok { + resultReceived = tc.Result + completed.Store(true) + } + return nil + }) + + c.RegisterTask(func(c *Core, task Task) (any, bool, error) { + return "async-result", true, nil + }) + + taskID := c.PerformAsync(TestTask{}) + assert.NotEmpty(t, taskID) + + // Wait for completion + assert.Eventually(t, func() bool { + return completed.Load() + }, 1*time.Second, 10*time.Millisecond) + + assert.Equal(t, "async-result", resultReceived) +} + +func TestCore_PerformAsync_Shutdown(t *testing.T) { + c, _ := New() + _ = c.ServiceShutdown(context.Background()) + + taskID := c.PerformAsync(TestTask{}) + assert.Empty(t, taskID, "PerformAsync should return empty string if already shut down") +} + +func TestCore_Progress_Good(t *testing.T) { + c, _ := New() + + var progressReceived float64 + var messageReceived string + + c.RegisterAction(func(c *Core, msg Message) error { + if tp, ok := msg.(ActionTaskProgress); ok { + progressReceived = tp.Progress + messageReceived = tp.Message + } + return nil + }) + + c.Progress("task-1", 0.5, "halfway", TestTask{}) + + assert.Equal(t, 0.5, progressReceived) + assert.Equal(t, "halfway", messageReceived) +} + +func TestCore_WithService_UnnamedType(t *testing.T) { + // Primitive types have no package path + factory := func(c *Core) (any, error) { + s := "primitive" + return &s, nil + } + + _, err := New(WithService(factory)) + require.Error(t, err) + assert.Contains(t, err.Error(), "service name could not be discovered") +} + +func TestRuntime_ServiceStartup_ErrorPropagation(t *testing.T) { + rt, _ := NewRuntime(nil) + + // Register a service that fails startup + errSvc := &MockStartable{err: errors.New("startup failed")} + _ = rt.Core.RegisterService("error-svc", errSvc) + + err := rt.ServiceStartup(context.Background(), nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "startup failed") +} + +func TestCore_ServiceStartup_ContextCancellation(t *testing.T) { + c, _ := New() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + s1 := &MockStartable{} + _ = c.RegisterService("s1", s1) + + err := c.ServiceStartup(ctx, nil) + assert.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) + assert.False(t, s1.started, "Service should not have started if context was cancelled before loop") +} + +func TestCore_ServiceShutdown_ContextCancellation(t *testing.T) { + c, _ := New() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + s1 := &MockStoppable{} + _ = c.RegisterService("s1", s1) + + err := c.ServiceShutdown(ctx) + assert.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) + assert.False(t, s1.stopped, "Service should not have stopped if context was cancelled before loop") +} + +type TaskWithIDImpl struct { + id string +} +func (t *TaskWithIDImpl) SetTaskID(id string) { t.id = id } +func (t *TaskWithIDImpl) GetTaskID() string { return t.id } + +func TestCore_PerformAsync_InjectsID(t *testing.T) { + c, _ := New() + c.RegisterTask(func(c *Core, t Task) (any, bool, error) { return nil, true, nil }) + + task := &TaskWithIDImpl{} + taskID := c.PerformAsync(task) + + assert.Equal(t, taskID, task.GetTaskID()) +} diff --git a/pkg/core/core.go b/pkg/core/core.go index a91d93c..eb7c64b 100644 --- a/pkg/core/core.go +++ b/pkg/core/core.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "reflect" + "slices" "strings" "sync" ) @@ -65,6 +66,9 @@ func WithService(factory func(*Core) (any, error)) Option { if err != nil { return fmt.Errorf("core: failed to create service: %w", err) } + if serviceInstance == nil { + return fmt.Errorf("core: service factory returned nil instance") + } // --- Service Name Discovery --- typeOfService := reflect.TypeOf(serviceInstance) @@ -74,6 +78,9 @@ func WithService(factory func(*Core) (any, error)) Option { pkgPath := typeOfService.PkgPath() parts := strings.Split(pkgPath, "/") name := strings.ToLower(parts[len(parts)-1]) + if name == "" { + return fmt.Errorf("core: service name could not be discovered for type %T (PkgPath is empty)", serviceInstance) + } // --- IPC Handler Discovery --- instanceValue := reflect.ValueOf(serviceInstance) @@ -81,6 +88,8 @@ func WithService(factory func(*Core) (any, error)) Option { if handlerMethod.IsValid() { if handler, ok := handlerMethod.Interface().(func(*Core, Message) error); ok { c.RegisterAction(handler) + } else { + return fmt.Errorf("core: service %q has HandleIPCEvents but wrong signature; expected func(*Core, Message) error", name) } } @@ -141,6 +150,9 @@ func (c *Core) ServiceStartup(ctx context.Context, options any) error { var agg error for _, s := range startables { + if err := ctx.Err(); err != nil { + return errors.Join(agg, err) + } if err := s.OnStartup(ctx); err != nil { agg = errors.Join(agg, err) } @@ -156,18 +168,36 @@ func (c *Core) ServiceStartup(ctx context.Context, options any) error { // ServiceShutdown is the entry point for the Core service's shutdown lifecycle. // It is called by the GUI runtime when the application shuts down. func (c *Core) ServiceShutdown(ctx context.Context) error { + c.shutdown.Store(true) + var agg error if err := c.ACTION(ActionServiceShutdown{}); err != nil { agg = errors.Join(agg, err) } stoppables := c.svc.getStoppables() - for i := len(stoppables) - 1; i >= 0; i-- { - if err := stoppables[i].OnShutdown(ctx); err != nil { + for _, s := range slices.Backward(stoppables) { + if err := ctx.Err(); err != nil { + agg = errors.Join(agg, err) + break // don't return — must still wait for background tasks below + } + if err := s.OnShutdown(ctx); err != nil { agg = errors.Join(agg, err) } } + // Wait for background tasks (PerformAsync), respecting context deadline. + done := make(chan struct{}) + go func() { + c.wg.Wait() + close(done) + }() + select { + case <-done: + case <-ctx.Done(): + agg = errors.Join(agg, ctx.Err()) + } + return agg } @@ -209,6 +239,10 @@ func (c *Core) PERFORM(t Task) (any, bool, error) { // It returns a unique task ID that can be used to track the task's progress. // The result of the task will be broadcasted via an ActionTaskCompleted message. func (c *Core) PerformAsync(t Task) string { + if c.shutdown.Load() { + return "" + } + taskID := fmt.Sprintf("task-%d", c.taskIDCounter.Add(1)) // If the task supports it, inject the ID @@ -222,7 +256,7 @@ func (c *Core) PerformAsync(t Task) string { Task: t, }) - go func() { + c.wg.Go(func() { result, handled, err := c.PERFORM(t) if !handled && err == nil { err = fmt.Errorf("no handler found for task type %T", t) @@ -235,7 +269,7 @@ func (c *Core) PerformAsync(t Task) string { Result: result, Error: err, }) - }() + }) return taskID } diff --git a/pkg/core/interfaces.go b/pkg/core/interfaces.go index ee74b47..036b4b2 100644 --- a/pkg/core/interfaces.go +++ b/pkg/core/interfaces.go @@ -5,6 +5,7 @@ import ( "embed" goio "io" "slices" + "sync" "sync/atomic" ) @@ -85,6 +86,8 @@ type Core struct { bus *messageBus taskIDCounter atomic.Uint64 + wg sync.WaitGroup + shutdown atomic.Bool } // Config provides access to application configuration. diff --git a/pkg/core/message_bus.go b/pkg/core/message_bus.go index 457ced2..4f81e77 100644 --- a/pkg/core/message_bus.go +++ b/pkg/core/message_bus.go @@ -2,6 +2,7 @@ package core import ( "errors" + "slices" "sync" ) @@ -28,7 +29,7 @@ func newMessageBus(c *Core) *messageBus { // action dispatches a message to all registered IPC handlers. func (b *messageBus) action(msg Message) error { b.ipcMu.RLock() - handlers := append([]func(*Core, Message) error(nil), b.ipcHandlers...) + handlers := slices.Clone(b.ipcHandlers) b.ipcMu.RUnlock() var agg error @@ -57,7 +58,7 @@ func (b *messageBus) registerActions(handlers ...func(*Core, Message) error) { // query dispatches a query to handlers until one responds. func (b *messageBus) query(q Query) (any, bool, error) { b.queryMu.RLock() - handlers := append([]QueryHandler(nil), b.queryHandlers...) + handlers := slices.Clone(b.queryHandlers) b.queryMu.RUnlock() for _, h := range handlers { @@ -72,7 +73,7 @@ func (b *messageBus) query(q Query) (any, bool, error) { // queryAll dispatches a query to all handlers and collects all responses. func (b *messageBus) queryAll(q Query) ([]any, error) { b.queryMu.RLock() - handlers := append([]QueryHandler(nil), b.queryHandlers...) + handlers := slices.Clone(b.queryHandlers) b.queryMu.RUnlock() var results []any @@ -99,7 +100,7 @@ func (b *messageBus) registerQuery(handler QueryHandler) { // perform dispatches a task to handlers until one executes it. func (b *messageBus) perform(t Task) (any, bool, error) { b.taskMu.RLock() - handlers := append([]TaskHandler(nil), b.taskHandlers...) + handlers := slices.Clone(b.taskHandlers) b.taskMu.RUnlock() for _, h := range handlers { diff --git a/pkg/core/runtime_pkg.go b/pkg/core/runtime_pkg.go index 0cb941d..7071e9c 100644 --- a/pkg/core/runtime_pkg.go +++ b/pkg/core/runtime_pkg.go @@ -3,7 +3,8 @@ package core import ( "context" "fmt" - "sort" + "maps" + "slices" ) // ServiceRuntime is a helper struct embedded in services to provide access to the core application. @@ -58,14 +59,13 @@ func NewWithFactories(app any, factories map[string]ServiceFactory) (*Runtime, e WithApp(app), } - names := make([]string, 0, len(factories)) - for name := range factories { - names = append(names, name) - } - sort.Strings(names) + names := slices.Sorted(maps.Keys(factories)) for _, name := range names { factory := factories[name] + if factory == nil { + return nil, fmt.Errorf("failed to create service %s: factory is nil", name) + } svc, err := factory() if err != nil { return nil, fmt.Errorf("failed to create service %s: %w", name, err) @@ -99,14 +99,15 @@ func (r *Runtime) ServiceName() string { // ServiceStartup is called by the GUI runtime at application startup. // This is where the Core's startup lifecycle is initiated. -func (r *Runtime) ServiceStartup(ctx context.Context, options any) { - _ = r.Core.ServiceStartup(ctx, options) +func (r *Runtime) ServiceStartup(ctx context.Context, options any) error { + return r.Core.ServiceStartup(ctx, options) } // ServiceShutdown is called by the GUI runtime at application shutdown. // This is where the Core's shutdown lifecycle is initiated. -func (r *Runtime) ServiceShutdown(ctx context.Context) { +func (r *Runtime) ServiceShutdown(ctx context.Context) error { if r.Core != nil { - _ = r.Core.ServiceShutdown(ctx) + return r.Core.ServiceShutdown(ctx) } + return nil } diff --git a/pkg/core/runtime_pkg_test.go b/pkg/core/runtime_pkg_test.go index 175b569..bc9b388 100644 --- a/pkg/core/runtime_pkg_test.go +++ b/pkg/core/runtime_pkg_test.go @@ -88,9 +88,9 @@ func TestNewWithFactories_Ugly(t *testing.T) { factories := map[string]ServiceFactory{ "test": nil, } - assert.Panics(t, func() { - _, _ = NewWithFactories(nil, factories) - }) + _, err := NewWithFactories(nil, factories) + assert.Error(t, err) + assert.Contains(t, err.Error(), "factory is nil") } func TestRuntime_Lifecycle_Good(t *testing.T) { diff --git a/pkg/core/service_manager.go b/pkg/core/service_manager.go index 9c4b0bd..0105cf7 100644 --- a/pkg/core/service_manager.go +++ b/pkg/core/service_manager.go @@ -3,6 +3,7 @@ package core import ( "errors" "fmt" + "slices" "sync" ) @@ -81,7 +82,7 @@ func (m *serviceManager) applyLock() { // getStartables returns a snapshot copy of the startables slice. func (m *serviceManager) getStartables() []Startable { m.mu.RLock() - out := append([]Startable(nil), m.startables...) + out := slices.Clone(m.startables) m.mu.RUnlock() return out } @@ -89,7 +90,7 @@ func (m *serviceManager) getStartables() []Startable { // getStoppables returns a snapshot copy of the stoppables slice. func (m *serviceManager) getStoppables() []Stoppable { m.mu.RLock() - out := append([]Stoppable(nil), m.stoppables...) + out := slices.Clone(m.stoppables) m.mu.RUnlock() return out }