fix: harden DI container — lifecycle safety, Go 1.26 modernisation
- Prevent nil service registration and empty name discovery - PerformAsync uses sync.WaitGroup.Go() with shutdown guard (atomic.Bool) - ServiceShutdown respects context deadline, no goroutine leak on cancel - IPC handler signature mismatch now returns error instead of silent skip - Runtime.ServiceStartup/ServiceShutdown return error for Wails v3 compat - Replace manual sort/clone patterns with slices.Sorted, slices.Clone, slices.Backward, maps.Keys - Add async_test.go for PerformAsync coverage Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d08ecb1c0d
commit
e2a68fc283
7 changed files with 202 additions and 23 deletions
139
pkg/core/async_test.go
Normal file
139
pkg/core/async_test.go
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCore_PerformAsync_Good(t *testing.T) {
|
||||
c, _ := New()
|
||||
|
||||
var completed atomic.Bool
|
||||
var resultReceived any
|
||||
|
||||
c.RegisterAction(func(c *Core, msg Message) error {
|
||||
if tc, ok := msg.(ActionTaskCompleted); ok {
|
||||
resultReceived = tc.Result
|
||||
completed.Store(true)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
c.RegisterTask(func(c *Core, task Task) (any, bool, error) {
|
||||
return "async-result", true, nil
|
||||
})
|
||||
|
||||
taskID := c.PerformAsync(TestTask{})
|
||||
assert.NotEmpty(t, taskID)
|
||||
|
||||
// Wait for completion
|
||||
assert.Eventually(t, func() bool {
|
||||
return completed.Load()
|
||||
}, 1*time.Second, 10*time.Millisecond)
|
||||
|
||||
assert.Equal(t, "async-result", resultReceived)
|
||||
}
|
||||
|
||||
func TestCore_PerformAsync_Shutdown(t *testing.T) {
|
||||
c, _ := New()
|
||||
_ = c.ServiceShutdown(context.Background())
|
||||
|
||||
taskID := c.PerformAsync(TestTask{})
|
||||
assert.Empty(t, taskID, "PerformAsync should return empty string if already shut down")
|
||||
}
|
||||
|
||||
func TestCore_Progress_Good(t *testing.T) {
|
||||
c, _ := New()
|
||||
|
||||
var progressReceived float64
|
||||
var messageReceived string
|
||||
|
||||
c.RegisterAction(func(c *Core, msg Message) error {
|
||||
if tp, ok := msg.(ActionTaskProgress); ok {
|
||||
progressReceived = tp.Progress
|
||||
messageReceived = tp.Message
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
c.Progress("task-1", 0.5, "halfway", TestTask{})
|
||||
|
||||
assert.Equal(t, 0.5, progressReceived)
|
||||
assert.Equal(t, "halfway", messageReceived)
|
||||
}
|
||||
|
||||
func TestCore_WithService_UnnamedType(t *testing.T) {
|
||||
// Primitive types have no package path
|
||||
factory := func(c *Core) (any, error) {
|
||||
s := "primitive"
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
_, err := New(WithService(factory))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "service name could not be discovered")
|
||||
}
|
||||
|
||||
func TestRuntime_ServiceStartup_ErrorPropagation(t *testing.T) {
|
||||
rt, _ := NewRuntime(nil)
|
||||
|
||||
// Register a service that fails startup
|
||||
errSvc := &MockStartable{err: errors.New("startup failed")}
|
||||
_ = rt.Core.RegisterService("error-svc", errSvc)
|
||||
|
||||
err := rt.ServiceStartup(context.Background(), nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "startup failed")
|
||||
}
|
||||
|
||||
func TestCore_ServiceStartup_ContextCancellation(t *testing.T) {
|
||||
c, _ := New()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
s1 := &MockStartable{}
|
||||
_ = c.RegisterService("s1", s1)
|
||||
|
||||
err := c.ServiceStartup(ctx, nil)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
assert.False(t, s1.started, "Service should not have started if context was cancelled before loop")
|
||||
}
|
||||
|
||||
func TestCore_ServiceShutdown_ContextCancellation(t *testing.T) {
|
||||
c, _ := New()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
s1 := &MockStoppable{}
|
||||
_ = c.RegisterService("s1", s1)
|
||||
|
||||
err := c.ServiceShutdown(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
assert.False(t, s1.stopped, "Service should not have stopped if context was cancelled before loop")
|
||||
}
|
||||
|
||||
type TaskWithIDImpl struct {
|
||||
id string
|
||||
}
|
||||
func (t *TaskWithIDImpl) SetTaskID(id string) { t.id = id }
|
||||
func (t *TaskWithIDImpl) GetTaskID() string { return t.id }
|
||||
|
||||
func TestCore_PerformAsync_InjectsID(t *testing.T) {
|
||||
c, _ := New()
|
||||
c.RegisterTask(func(c *Core, t Task) (any, bool, error) { return nil, true, nil })
|
||||
|
||||
task := &TaskWithIDImpl{}
|
||||
taskID := c.PerformAsync(task)
|
||||
|
||||
assert.Equal(t, taskID, task.GetTaskID())
|
||||
}
|
||||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
|
@ -65,6 +66,9 @@ func WithService(factory func(*Core) (any, error)) Option {
|
|||
if err != nil {
|
||||
return fmt.Errorf("core: failed to create service: %w", err)
|
||||
}
|
||||
if serviceInstance == nil {
|
||||
return fmt.Errorf("core: service factory returned nil instance")
|
||||
}
|
||||
|
||||
// --- Service Name Discovery ---
|
||||
typeOfService := reflect.TypeOf(serviceInstance)
|
||||
|
|
@ -74,6 +78,9 @@ func WithService(factory func(*Core) (any, error)) Option {
|
|||
pkgPath := typeOfService.PkgPath()
|
||||
parts := strings.Split(pkgPath, "/")
|
||||
name := strings.ToLower(parts[len(parts)-1])
|
||||
if name == "" {
|
||||
return fmt.Errorf("core: service name could not be discovered for type %T (PkgPath is empty)", serviceInstance)
|
||||
}
|
||||
|
||||
// --- IPC Handler Discovery ---
|
||||
instanceValue := reflect.ValueOf(serviceInstance)
|
||||
|
|
@ -81,6 +88,8 @@ func WithService(factory func(*Core) (any, error)) Option {
|
|||
if handlerMethod.IsValid() {
|
||||
if handler, ok := handlerMethod.Interface().(func(*Core, Message) error); ok {
|
||||
c.RegisterAction(handler)
|
||||
} else {
|
||||
return fmt.Errorf("core: service %q has HandleIPCEvents but wrong signature; expected func(*Core, Message) error", name)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -141,6 +150,9 @@ func (c *Core) ServiceStartup(ctx context.Context, options any) error {
|
|||
|
||||
var agg error
|
||||
for _, s := range startables {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return errors.Join(agg, err)
|
||||
}
|
||||
if err := s.OnStartup(ctx); err != nil {
|
||||
agg = errors.Join(agg, err)
|
||||
}
|
||||
|
|
@ -156,18 +168,36 @@ func (c *Core) ServiceStartup(ctx context.Context, options any) error {
|
|||
// ServiceShutdown is the entry point for the Core service's shutdown lifecycle.
|
||||
// It is called by the GUI runtime when the application shuts down.
|
||||
func (c *Core) ServiceShutdown(ctx context.Context) error {
|
||||
c.shutdown.Store(true)
|
||||
|
||||
var agg error
|
||||
if err := c.ACTION(ActionServiceShutdown{}); err != nil {
|
||||
agg = errors.Join(agg, err)
|
||||
}
|
||||
|
||||
stoppables := c.svc.getStoppables()
|
||||
for i := len(stoppables) - 1; i >= 0; i-- {
|
||||
if err := stoppables[i].OnShutdown(ctx); err != nil {
|
||||
for _, s := range slices.Backward(stoppables) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
agg = errors.Join(agg, err)
|
||||
break // don't return — must still wait for background tasks below
|
||||
}
|
||||
if err := s.OnShutdown(ctx); err != nil {
|
||||
agg = errors.Join(agg, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for background tasks (PerformAsync), respecting context deadline.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
agg = errors.Join(agg, ctx.Err())
|
||||
}
|
||||
|
||||
return agg
|
||||
}
|
||||
|
||||
|
|
@ -209,6 +239,10 @@ func (c *Core) PERFORM(t Task) (any, bool, error) {
|
|||
// It returns a unique task ID that can be used to track the task's progress.
|
||||
// The result of the task will be broadcasted via an ActionTaskCompleted message.
|
||||
func (c *Core) PerformAsync(t Task) string {
|
||||
if c.shutdown.Load() {
|
||||
return ""
|
||||
}
|
||||
|
||||
taskID := fmt.Sprintf("task-%d", c.taskIDCounter.Add(1))
|
||||
|
||||
// If the task supports it, inject the ID
|
||||
|
|
@ -222,7 +256,7 @@ func (c *Core) PerformAsync(t Task) string {
|
|||
Task: t,
|
||||
})
|
||||
|
||||
go func() {
|
||||
c.wg.Go(func() {
|
||||
result, handled, err := c.PERFORM(t)
|
||||
if !handled && err == nil {
|
||||
err = fmt.Errorf("no handler found for task type %T", t)
|
||||
|
|
@ -235,7 +269,7 @@ func (c *Core) PerformAsync(t Task) string {
|
|||
Result: result,
|
||||
Error: err,
|
||||
})
|
||||
}()
|
||||
})
|
||||
|
||||
return taskID
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"embed"
|
||||
goio "io"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
|
|
@ -85,6 +86,8 @@ type Core struct {
|
|||
bus *messageBus
|
||||
|
||||
taskIDCounter atomic.Uint64
|
||||
wg sync.WaitGroup
|
||||
shutdown atomic.Bool
|
||||
}
|
||||
|
||||
// Config provides access to application configuration.
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package core
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"slices"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
|
@ -28,7 +29,7 @@ func newMessageBus(c *Core) *messageBus {
|
|||
// action dispatches a message to all registered IPC handlers.
|
||||
func (b *messageBus) action(msg Message) error {
|
||||
b.ipcMu.RLock()
|
||||
handlers := append([]func(*Core, Message) error(nil), b.ipcHandlers...)
|
||||
handlers := slices.Clone(b.ipcHandlers)
|
||||
b.ipcMu.RUnlock()
|
||||
|
||||
var agg error
|
||||
|
|
@ -57,7 +58,7 @@ func (b *messageBus) registerActions(handlers ...func(*Core, Message) error) {
|
|||
// query dispatches a query to handlers until one responds.
|
||||
func (b *messageBus) query(q Query) (any, bool, error) {
|
||||
b.queryMu.RLock()
|
||||
handlers := append([]QueryHandler(nil), b.queryHandlers...)
|
||||
handlers := slices.Clone(b.queryHandlers)
|
||||
b.queryMu.RUnlock()
|
||||
|
||||
for _, h := range handlers {
|
||||
|
|
@ -72,7 +73,7 @@ func (b *messageBus) query(q Query) (any, bool, error) {
|
|||
// queryAll dispatches a query to all handlers and collects all responses.
|
||||
func (b *messageBus) queryAll(q Query) ([]any, error) {
|
||||
b.queryMu.RLock()
|
||||
handlers := append([]QueryHandler(nil), b.queryHandlers...)
|
||||
handlers := slices.Clone(b.queryHandlers)
|
||||
b.queryMu.RUnlock()
|
||||
|
||||
var results []any
|
||||
|
|
@ -99,7 +100,7 @@ func (b *messageBus) registerQuery(handler QueryHandler) {
|
|||
// perform dispatches a task to handlers until one executes it.
|
||||
func (b *messageBus) perform(t Task) (any, bool, error) {
|
||||
b.taskMu.RLock()
|
||||
handlers := append([]TaskHandler(nil), b.taskHandlers...)
|
||||
handlers := slices.Clone(b.taskHandlers)
|
||||
b.taskMu.RUnlock()
|
||||
|
||||
for _, h := range handlers {
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ package core
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"maps"
|
||||
"slices"
|
||||
)
|
||||
|
||||
// ServiceRuntime is a helper struct embedded in services to provide access to the core application.
|
||||
|
|
@ -58,14 +59,13 @@ func NewWithFactories(app any, factories map[string]ServiceFactory) (*Runtime, e
|
|||
WithApp(app),
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(factories))
|
||||
for name := range factories {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
names := slices.Sorted(maps.Keys(factories))
|
||||
|
||||
for _, name := range names {
|
||||
factory := factories[name]
|
||||
if factory == nil {
|
||||
return nil, fmt.Errorf("failed to create service %s: factory is nil", name)
|
||||
}
|
||||
svc, err := factory()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create service %s: %w", name, err)
|
||||
|
|
@ -99,14 +99,15 @@ func (r *Runtime) ServiceName() string {
|
|||
|
||||
// ServiceStartup is called by the GUI runtime at application startup.
|
||||
// This is where the Core's startup lifecycle is initiated.
|
||||
func (r *Runtime) ServiceStartup(ctx context.Context, options any) {
|
||||
_ = r.Core.ServiceStartup(ctx, options)
|
||||
func (r *Runtime) ServiceStartup(ctx context.Context, options any) error {
|
||||
return r.Core.ServiceStartup(ctx, options)
|
||||
}
|
||||
|
||||
// ServiceShutdown is called by the GUI runtime at application shutdown.
|
||||
// This is where the Core's shutdown lifecycle is initiated.
|
||||
func (r *Runtime) ServiceShutdown(ctx context.Context) {
|
||||
func (r *Runtime) ServiceShutdown(ctx context.Context) error {
|
||||
if r.Core != nil {
|
||||
_ = r.Core.ServiceShutdown(ctx)
|
||||
return r.Core.ServiceShutdown(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -88,9 +88,9 @@ func TestNewWithFactories_Ugly(t *testing.T) {
|
|||
factories := map[string]ServiceFactory{
|
||||
"test": nil,
|
||||
}
|
||||
assert.Panics(t, func() {
|
||||
_, _ = NewWithFactories(nil, factories)
|
||||
})
|
||||
_, err := NewWithFactories(nil, factories)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "factory is nil")
|
||||
}
|
||||
|
||||
func TestRuntime_Lifecycle_Good(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package core
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
|
@ -81,7 +82,7 @@ func (m *serviceManager) applyLock() {
|
|||
// getStartables returns a snapshot copy of the startables slice.
|
||||
func (m *serviceManager) getStartables() []Startable {
|
||||
m.mu.RLock()
|
||||
out := append([]Startable(nil), m.startables...)
|
||||
out := slices.Clone(m.startables)
|
||||
m.mu.RUnlock()
|
||||
return out
|
||||
}
|
||||
|
|
@ -89,7 +90,7 @@ func (m *serviceManager) getStartables() []Startable {
|
|||
// getStoppables returns a snapshot copy of the stoppables slice.
|
||||
func (m *serviceManager) getStoppables() []Stoppable {
|
||||
m.mu.RLock()
|
||||
out := append([]Stoppable(nil), m.stoppables...)
|
||||
out := slices.Clone(m.stoppables)
|
||||
m.mu.RUnlock()
|
||||
return out
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue