fix: harden DI container — lifecycle safety, Go 1.26 modernisation

- Reject nil service instances returned by factories and fail fast when a service name cannot be discovered (empty PkgPath)
- PerformAsync uses sync.WaitGroup.Go() with shutdown guard (atomic.Bool)
- ServiceShutdown respects context deadline, no goroutine leak on cancel
- IPC handler signature mismatch now returns error instead of silent skip
- Runtime.ServiceStartup/ServiceShutdown return error for Wails v3 compat
- Replace manual sort/clone patterns with slices.Sorted, slices.Clone,
  slices.Backward, maps.Keys
- Add async_test.go for PerformAsync coverage

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Snider 2026-03-09 09:11:22 +00:00
parent d08ecb1c0d
commit e2a68fc283
7 changed files with 202 additions and 23 deletions

139
pkg/core/async_test.go Normal file
View file

@ -0,0 +1,139 @@
package core
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestCore_PerformAsync_Good verifies the happy path: an async task runs,
// its handler result is broadcast via ActionTaskCompleted, and a non-empty
// task ID is returned to the caller.
func TestCore_PerformAsync_Good(t *testing.T) {
	c, _ := New()

	// Observe the completion broadcast emitted by the async worker.
	var done atomic.Bool
	var gotResult any
	c.RegisterAction(func(c *Core, msg Message) error {
		if completedMsg, isCompleted := msg.(ActionTaskCompleted); isCompleted {
			gotResult = completedMsg.Result
			done.Store(true)
		}
		return nil
	})

	// A task handler that succeeds immediately.
	c.RegisterTask(func(c *Core, task Task) (any, bool, error) {
		return "async-result", true, nil
	})

	taskID := c.PerformAsync(TestTask{})
	assert.NotEmpty(t, taskID)

	// gotResult is written before done.Store(true); observing the atomic flag
	// therefore makes the subsequent read safe under the Go memory model.
	assert.Eventually(t, done.Load, 1*time.Second, 10*time.Millisecond)
	assert.Equal(t, "async-result", gotResult)
}
// TestCore_PerformAsync_Shutdown verifies that async submissions are rejected
// once the core has been shut down.
func TestCore_PerformAsync_Shutdown(t *testing.T) {
	c, _ := New()

	// Shut down first; afterwards PerformAsync must refuse new work.
	_ = c.ServiceShutdown(context.Background())

	taskID := c.PerformAsync(TestTask{})
	assert.Empty(t, taskID, "PerformAsync should return empty string if already shut down")
}
// TestCore_Progress_Good verifies that Progress broadcasts an
// ActionTaskProgress message carrying the given fraction and message.
func TestCore_Progress_Good(t *testing.T) {
	c, _ := New()

	var gotProgress float64
	var gotMessage string
	c.RegisterAction(func(c *Core, msg Message) error {
		progressMsg, isProgress := msg.(ActionTaskProgress)
		if isProgress {
			gotProgress = progressMsg.Progress
			gotMessage = progressMsg.Message
		}
		return nil
	})

	c.Progress("task-1", 0.5, "halfway", TestTask{})

	assert.Equal(t, 0.5, gotProgress)
	assert.Equal(t, "halfway", gotMessage)
}
// TestCore_WithService_UnnamedType verifies that a service whose type has no
// package path (e.g. *string) is rejected during name discovery.
func TestCore_WithService_UnnamedType(t *testing.T) {
	// A pointer to a builtin has an empty PkgPath, so discovery must fail.
	newPrimitive := func(c *Core) (any, error) {
		value := "primitive"
		return &value, nil
	}

	_, err := New(WithService(newPrimitive))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "service name could not be discovered")
}
// TestRuntime_ServiceStartup_ErrorPropagation verifies that a service startup
// failure surfaces through Runtime.ServiceStartup's error return.
func TestRuntime_ServiceStartup_ErrorPropagation(t *testing.T) {
	rt, _ := NewRuntime(nil)

	// A service whose OnStartup always fails.
	failing := &MockStartable{err: errors.New("startup failed")}
	_ = rt.Core.RegisterService("error-svc", failing)

	err := rt.ServiceStartup(context.Background(), nil)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "startup failed")
}
// TestCore_ServiceStartup_ContextCancellation verifies that a context
// cancelled before startup prevents any service from being started and that
// the cancellation error is reported.
func TestCore_ServiceStartup_ContextCancellation(t *testing.T) {
	c, _ := New()

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // already cancelled before startup begins

	svc := &MockStartable{}
	_ = c.RegisterService("s1", svc)

	err := c.ServiceStartup(ctx, nil)
	assert.Error(t, err)
	assert.ErrorIs(t, err, context.Canceled)
	assert.False(t, svc.started, "Service should not have started if context was cancelled before loop")
}
// TestCore_ServiceShutdown_ContextCancellation verifies that a context
// cancelled before shutdown skips the per-service OnShutdown calls and
// reports the cancellation error.
func TestCore_ServiceShutdown_ContextCancellation(t *testing.T) {
	c, _ := New()

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // already cancelled before shutdown begins

	svc := &MockStoppable{}
	_ = c.RegisterService("s1", svc)

	err := c.ServiceShutdown(ctx)
	assert.Error(t, err)
	assert.ErrorIs(t, err, context.Canceled)
	assert.False(t, svc.stopped, "Service should not have stopped if context was cancelled before loop")
}
// TaskWithIDImpl is a minimal task that records the ID injected by
// PerformAsync; the tests use it to verify ID injection.
type TaskWithIDImpl struct {
	taskID string
}

// SetTaskID stores the identifier assigned by the core.
func (t *TaskWithIDImpl) SetTaskID(id string) { t.taskID = id }

// GetTaskID reports the identifier previously stored via SetTaskID.
func (t *TaskWithIDImpl) GetTaskID() string { return t.taskID }
// TestCore_PerformAsync_InjectsID verifies that PerformAsync injects the
// generated task ID into tasks implementing SetTaskID/GetTaskID.
func TestCore_PerformAsync_InjectsID(t *testing.T) {
	c, _ := New()
	c.RegisterTask(func(c *Core, t Task) (any, bool, error) { return nil, true, nil })

	task := &TaskWithIDImpl{}
	returnedID := c.PerformAsync(task)

	assert.Equal(t, returnedID, task.GetTaskID())
}

View file

@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"reflect"
"slices"
"strings"
"sync"
)
@ -65,6 +66,9 @@ func WithService(factory func(*Core) (any, error)) Option {
if err != nil {
return fmt.Errorf("core: failed to create service: %w", err)
}
if serviceInstance == nil {
return fmt.Errorf("core: service factory returned nil instance")
}
// --- Service Name Discovery ---
typeOfService := reflect.TypeOf(serviceInstance)
@ -74,6 +78,9 @@ func WithService(factory func(*Core) (any, error)) Option {
pkgPath := typeOfService.PkgPath()
parts := strings.Split(pkgPath, "/")
name := strings.ToLower(parts[len(parts)-1])
if name == "" {
return fmt.Errorf("core: service name could not be discovered for type %T (PkgPath is empty)", serviceInstance)
}
// --- IPC Handler Discovery ---
instanceValue := reflect.ValueOf(serviceInstance)
@ -81,6 +88,8 @@ func WithService(factory func(*Core) (any, error)) Option {
if handlerMethod.IsValid() {
if handler, ok := handlerMethod.Interface().(func(*Core, Message) error); ok {
c.RegisterAction(handler)
} else {
return fmt.Errorf("core: service %q has HandleIPCEvents but wrong signature; expected func(*Core, Message) error", name)
}
}
@ -141,6 +150,9 @@ func (c *Core) ServiceStartup(ctx context.Context, options any) error {
var agg error
for _, s := range startables {
if err := ctx.Err(); err != nil {
return errors.Join(agg, err)
}
if err := s.OnStartup(ctx); err != nil {
agg = errors.Join(agg, err)
}
@ -156,18 +168,36 @@ func (c *Core) ServiceStartup(ctx context.Context, options any) error {
// ServiceShutdown is the entry point for the Core service's shutdown lifecycle.
// It is called by the GUI runtime when the application shuts down.
func (c *Core) ServiceShutdown(ctx context.Context) error {
c.shutdown.Store(true)
var agg error
if err := c.ACTION(ActionServiceShutdown{}); err != nil {
agg = errors.Join(agg, err)
}
stoppables := c.svc.getStoppables()
for i := len(stoppables) - 1; i >= 0; i-- {
if err := stoppables[i].OnShutdown(ctx); err != nil {
for _, s := range slices.Backward(stoppables) {
if err := ctx.Err(); err != nil {
agg = errors.Join(agg, err)
break // don't return — must still wait for background tasks below
}
if err := s.OnShutdown(ctx); err != nil {
agg = errors.Join(agg, err)
}
}
// Wait for background tasks (PerformAsync), respecting context deadline.
done := make(chan struct{})
go func() {
c.wg.Wait()
close(done)
}()
select {
case <-done:
case <-ctx.Done():
agg = errors.Join(agg, ctx.Err())
}
return agg
}
@ -209,6 +239,10 @@ func (c *Core) PERFORM(t Task) (any, bool, error) {
// It returns a unique task ID that can be used to track the task's progress.
// The result of the task will be broadcasted via an ActionTaskCompleted message.
func (c *Core) PerformAsync(t Task) string {
if c.shutdown.Load() {
return ""
}
taskID := fmt.Sprintf("task-%d", c.taskIDCounter.Add(1))
// If the task supports it, inject the ID
@ -222,7 +256,7 @@ func (c *Core) PerformAsync(t Task) string {
Task: t,
})
go func() {
c.wg.Go(func() {
result, handled, err := c.PERFORM(t)
if !handled && err == nil {
err = fmt.Errorf("no handler found for task type %T", t)
@ -235,7 +269,7 @@ func (c *Core) PerformAsync(t Task) string {
Result: result,
Error: err,
})
}()
})
return taskID
}

View file

@ -5,6 +5,7 @@ import (
"embed"
goio "io"
"slices"
"sync"
"sync/atomic"
)
@ -85,6 +86,8 @@ type Core struct {
bus *messageBus
taskIDCounter atomic.Uint64
wg sync.WaitGroup
shutdown atomic.Bool
}
// Config provides access to application configuration.

View file

@ -2,6 +2,7 @@ package core
import (
"errors"
"slices"
"sync"
)
@ -28,7 +29,7 @@ func newMessageBus(c *Core) *messageBus {
// action dispatches a message to all registered IPC handlers.
func (b *messageBus) action(msg Message) error {
b.ipcMu.RLock()
handlers := append([]func(*Core, Message) error(nil), b.ipcHandlers...)
handlers := slices.Clone(b.ipcHandlers)
b.ipcMu.RUnlock()
var agg error
@ -57,7 +58,7 @@ func (b *messageBus) registerActions(handlers ...func(*Core, Message) error) {
// query dispatches a query to handlers until one responds.
func (b *messageBus) query(q Query) (any, bool, error) {
b.queryMu.RLock()
handlers := append([]QueryHandler(nil), b.queryHandlers...)
handlers := slices.Clone(b.queryHandlers)
b.queryMu.RUnlock()
for _, h := range handlers {
@ -72,7 +73,7 @@ func (b *messageBus) query(q Query) (any, bool, error) {
// queryAll dispatches a query to all handlers and collects all responses.
func (b *messageBus) queryAll(q Query) ([]any, error) {
b.queryMu.RLock()
handlers := append([]QueryHandler(nil), b.queryHandlers...)
handlers := slices.Clone(b.queryHandlers)
b.queryMu.RUnlock()
var results []any
@ -99,7 +100,7 @@ func (b *messageBus) registerQuery(handler QueryHandler) {
// perform dispatches a task to handlers until one executes it.
func (b *messageBus) perform(t Task) (any, bool, error) {
b.taskMu.RLock()
handlers := append([]TaskHandler(nil), b.taskHandlers...)
handlers := slices.Clone(b.taskHandlers)
b.taskMu.RUnlock()
for _, h := range handlers {

View file

@ -3,7 +3,8 @@ package core
import (
"context"
"fmt"
"sort"
"maps"
"slices"
)
// ServiceRuntime is a helper struct embedded in services to provide access to the core application.
@ -58,14 +59,13 @@ func NewWithFactories(app any, factories map[string]ServiceFactory) (*Runtime, e
WithApp(app),
}
names := make([]string, 0, len(factories))
for name := range factories {
names = append(names, name)
}
sort.Strings(names)
names := slices.Sorted(maps.Keys(factories))
for _, name := range names {
factory := factories[name]
if factory == nil {
return nil, fmt.Errorf("failed to create service %s: factory is nil", name)
}
svc, err := factory()
if err != nil {
return nil, fmt.Errorf("failed to create service %s: %w", name, err)
@ -99,14 +99,15 @@ func (r *Runtime) ServiceName() string {
// ServiceStartup is called by the GUI runtime at application startup.
// This is where the Core's startup lifecycle is initiated.
func (r *Runtime) ServiceStartup(ctx context.Context, options any) {
_ = r.Core.ServiceStartup(ctx, options)
func (r *Runtime) ServiceStartup(ctx context.Context, options any) error {
return r.Core.ServiceStartup(ctx, options)
}
// ServiceShutdown is called by the GUI runtime at application shutdown.
// This is where the Core's shutdown lifecycle is initiated.
func (r *Runtime) ServiceShutdown(ctx context.Context) {
func (r *Runtime) ServiceShutdown(ctx context.Context) error {
if r.Core != nil {
_ = r.Core.ServiceShutdown(ctx)
return r.Core.ServiceShutdown(ctx)
}
return nil
}

View file

@ -88,9 +88,9 @@ func TestNewWithFactories_Ugly(t *testing.T) {
factories := map[string]ServiceFactory{
"test": nil,
}
assert.Panics(t, func() {
_, _ = NewWithFactories(nil, factories)
})
_, err := NewWithFactories(nil, factories)
assert.Error(t, err)
assert.Contains(t, err.Error(), "factory is nil")
}
func TestRuntime_Lifecycle_Good(t *testing.T) {

View file

@ -3,6 +3,7 @@ package core
import (
"errors"
"fmt"
"slices"
"sync"
)
@ -81,7 +82,7 @@ func (m *serviceManager) applyLock() {
// getStartables returns a snapshot copy of the startables slice.
// Copying under the read lock lets callers iterate without holding the lock,
// so registrations may proceed concurrently with a startup pass.
//
// NOTE(review): reconstructed post-change body — the rendered diff left both
// the old append-copy line and the new slices.Clone line, redeclaring `out`.
func (m *serviceManager) getStartables() []Startable {
	m.mu.RLock()
	out := slices.Clone(m.startables)
	m.mu.RUnlock()
	return out
}
@ -89,7 +90,7 @@ func (m *serviceManager) getStartables() []Startable {
// getStoppables returns a snapshot copy of the stoppables slice.
// Copying under the read lock lets callers iterate without holding the lock,
// so registrations may proceed concurrently with a shutdown pass.
//
// NOTE(review): reconstructed post-change body — the rendered diff left both
// the old append-copy line and the new slices.Clone line, redeclaring `out`.
func (m *serviceManager) getStoppables() []Stoppable {
	m.mu.RLock()
	out := slices.Clone(m.stoppables)
	m.mu.RUnlock()
	return out
}