package mining import ( "context" "sync" "time" "forge.lthn.ai/Snider/Mining/pkg/logging" ) // supervisor.RegisterTask("cleanup", func(ctx context.Context) { runCleanup(ctx) }, 30*time.Second, 10) type TaskFunc func(ctx context.Context) // task := &SupervisedTask{name: "stats-collector", restartDelay: 5 * time.Second, maxRestarts: -1} type SupervisedTask struct { name string task TaskFunc restartDelay time.Duration maxRestarts int restartCount int running bool lastStartTime time.Time cancel context.CancelFunc mutex sync.Mutex } // supervisor := NewTaskSupervisor() // supervisor.RegisterTask("stats-collector", collectStats, 5*time.Second, -1) // supervisor.Start(); defer supervisor.Stop() type TaskSupervisor struct { tasks map[string]*SupervisedTask ctx context.Context cancel context.CancelFunc waitGroup sync.WaitGroup mutex sync.RWMutex started bool } // supervisor := NewTaskSupervisor() // supervisor.RegisterTask("stats-collector", collectStats, 5*time.Second, -1) // supervisor.Start() func NewTaskSupervisor() *TaskSupervisor { ctx, cancel := context.WithCancel(context.Background()) return &TaskSupervisor{ tasks: make(map[string]*SupervisedTask), ctx: ctx, cancel: cancel, } } // supervisor.RegisterTask("stats-collector", collectStats, 5*time.Second, -1) // supervisor.RegisterTask("cleanup", runCleanup, 30*time.Second, 10) func (supervisor *TaskSupervisor) RegisterTask(name string, task TaskFunc, restartDelay time.Duration, maxRestarts int) { supervisor.mutex.Lock() defer supervisor.mutex.Unlock() supervisor.tasks[name] = &SupervisedTask{ name: name, task: task, restartDelay: restartDelay, maxRestarts: maxRestarts, } } // supervisor.Start() // begins all registered tasks; no-op if already started func (supervisor *TaskSupervisor) Start() { supervisor.mutex.Lock() if supervisor.started { supervisor.mutex.Unlock() return } supervisor.started = true supervisor.mutex.Unlock() supervisor.mutex.RLock() for name, task := range supervisor.tasks { supervisor.startTask(name, task) } supervisor.mutex.RUnlock() } // supervisor.startTask("stats-collector", supervisor.tasks["stats-collector"]) func (supervisor *TaskSupervisor) startTask(name string, st *SupervisedTask) { st.mutex.Lock() if st.running { st.mutex.Unlock() return } st.running = true st.lastStartTime = time.Now() taskCtx, taskCancel := context.WithCancel(supervisor.ctx) st.cancel = taskCancel st.mutex.Unlock() supervisor.waitGroup.Add(1) go func() { defer supervisor.waitGroup.Done() for { select { case <-supervisor.ctx.Done(): return default: } // Run the task with panic recovery func() { defer func() { if r := recover(); r != nil { logging.Error("supervised task panicked", logging.Fields{ "task": name, "panic": r, }) } }() st.task(taskCtx) }() // Check if we should restart st.mutex.Lock() st.restartCount++ shouldRestart := st.restartCount <= st.maxRestarts || st.maxRestarts < 0 restartDelay := st.restartDelay st.mutex.Unlock() if !shouldRestart { logging.Warn("supervised task reached max restarts", logging.Fields{ "task": name, "maxRestart": st.maxRestarts, }) return } select { case <-supervisor.ctx.Done(): return case <-time.After(restartDelay): logging.Info("restarting supervised task", logging.Fields{ "task": name, "restartCount": st.restartCount, }) } } }() logging.Info("started supervised task", logging.Fields{"task": name}) } // supervisor.Stop() // cancels all tasks and waits for clean exit func (supervisor *TaskSupervisor) Stop() { supervisor.cancel() supervisor.waitGroup.Wait() supervisor.mutex.Lock() supervisor.started = false for _, task := range supervisor.tasks { task.mutex.Lock() task.running = false task.mutex.Unlock() } supervisor.mutex.Unlock() logging.Info("task supervisor stopped") } // running, restarts, ok := supervisor.GetTaskStatus("stats-collector") func (supervisor *TaskSupervisor) GetTaskStatus(name string) (running bool, restartCount int, found bool) { supervisor.mutex.RLock() task, ok := supervisor.tasks[name] supervisor.mutex.RUnlock() if !ok { return false, 0, false } task.mutex.Lock() defer task.mutex.Unlock() return task.running, task.restartCount, true } // for name, status := range supervisor.GetAllTaskStatuses() { log(name, status.Running) } func (supervisor *TaskSupervisor) GetAllTaskStatuses() map[string]TaskStatus { supervisor.mutex.RLock() defer supervisor.mutex.RUnlock() statuses := make(map[string]TaskStatus, len(supervisor.tasks)) for name, task := range supervisor.tasks { task.mutex.Lock() statuses[name] = TaskStatus{ Name: name, Running: task.running, RestartCount: task.restartCount, LastStart: task.lastStartTime, } task.mutex.Unlock() } return statuses } // status := supervisor.GetAllTaskStatuses()["stats-collector"] // if status.Running { log(status.RestartCount, status.LastStart) } type TaskStatus struct { Name string `json:"name"` Running bool `json:"running"` RestartCount int `json:"restartCount"` LastStart time.Time `json:"lastStart"` }