refactor(update): use watcher pattern for auto-restart
Replace the direct exec-based restart with a spawned watcher process: - Add hidden --watch-pid flag for internal use - spawnWatcher() spawns background process before update - watchAndRestart() polls for parent death, then restarts binary - Uses signal 0 on Unix to check if process is alive - Windows fallback spawns new process and exits This approach is safer because: - Parent exits cleanly before restart (no file locking issues) - Watcher is detached from parent process group - Works reliably across platforms Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
e41ed47264
commit
180ce7428f
1 changed files with 116 additions and 35 deletions
|
|
@ -3,8 +3,11 @@ package updater
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/host-uk/core/pkg/cli"
|
"github.com/host-uk/core/pkg/cli"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
@ -18,9 +21,10 @@ const (
|
||||||
|
|
||||||
// Command flags
|
// Command flags
|
||||||
var (
|
var (
|
||||||
updateChannel string
|
updateChannel string
|
||||||
updateForce bool
|
updateForce bool
|
||||||
updateCheck bool
|
updateCheck bool
|
||||||
|
updateWatchPID int
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
@ -48,6 +52,8 @@ Examples:
|
||||||
updateCmd.PersistentFlags().StringVar(&updateChannel, "channel", "stable", "Release channel: stable, beta, alpha, or dev")
|
updateCmd.PersistentFlags().StringVar(&updateChannel, "channel", "stable", "Release channel: stable, beta, alpha, or dev")
|
||||||
updateCmd.PersistentFlags().BoolVar(&updateForce, "force", false, "Force update even if already on latest version")
|
updateCmd.PersistentFlags().BoolVar(&updateForce, "force", false, "Force update even if already on latest version")
|
||||||
updateCmd.Flags().BoolVar(&updateCheck, "check", false, "Only check for updates, don't apply")
|
updateCmd.Flags().BoolVar(&updateCheck, "check", false, "Only check for updates, don't apply")
|
||||||
|
updateCmd.Flags().IntVar(&updateWatchPID, "watch-pid", 0, "Internal: watch for parent PID to die then restart")
|
||||||
|
_ = updateCmd.Flags().MarkHidden("watch-pid")
|
||||||
|
|
||||||
updateCmd.AddCommand(&cobra.Command{
|
updateCmd.AddCommand(&cobra.Command{
|
||||||
Use: "check",
|
Use: "check",
|
||||||
|
|
@ -62,6 +68,11 @@ Examples:
|
||||||
}
|
}
|
||||||
|
|
||||||
func runUpdate(cmd *cobra.Command, args []string) error {
|
func runUpdate(cmd *cobra.Command, args []string) error {
|
||||||
|
// If we're in watch mode, wait for parent to die then restart
|
||||||
|
if updateWatchPID > 0 {
|
||||||
|
return watchAndRestart(updateWatchPID)
|
||||||
|
}
|
||||||
|
|
||||||
currentVersion := cli.AppVersion
|
currentVersion := cli.AppVersion
|
||||||
|
|
||||||
cli.Print("%s %s\n", cli.DimStyle.Render("Current version:"), cli.ValueStyle.Render(currentVersion))
|
cli.Print("%s %s\n", cli.DimStyle.Render("Current version:"), cli.ValueStyle.Render(currentVersion))
|
||||||
|
|
@ -104,6 +115,12 @@ func runUpdate(cmd *cobra.Command, args []string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Spawn watcher before applying update
|
||||||
|
if err := spawnWatcher(); err != nil {
|
||||||
|
// If watcher fails, continue anyway - update will still work
|
||||||
|
cli.Print("%s Could not spawn restart watcher: %v\n", cli.DimStyle.Render("!"), err)
|
||||||
|
}
|
||||||
|
|
||||||
// Apply update
|
// Apply update
|
||||||
cli.Print("\n%s Downloading update...\n", cli.DimStyle.Render("→"))
|
cli.Print("\n%s Downloading update...\n", cli.DimStyle.Render("→"))
|
||||||
|
|
||||||
|
|
@ -117,8 +134,11 @@ func runUpdate(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
cli.Print("%s Updated to %s\n", cli.SuccessStyle.Render(cli.Glyph(":check:")), release.TagName)
|
cli.Print("%s Updated to %s\n", cli.SuccessStyle.Render(cli.Glyph(":check:")), release.TagName)
|
||||||
|
cli.Print("%s Restarting...\n", cli.DimStyle.Render("→"))
|
||||||
|
|
||||||
return restartBinary()
|
// Exit so the watcher can restart us
|
||||||
|
os.Exit(0)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleDevUpdate handles updates from the dev release (rolling prerelease)
|
// handleDevUpdate handles updates from the dev release (rolling prerelease)
|
||||||
|
|
@ -143,6 +163,11 @@ func handleDevUpdate(currentVersion string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Spawn watcher before applying update
|
||||||
|
if err := spawnWatcher(); err != nil {
|
||||||
|
cli.Print("%s Could not spawn restart watcher: %v\n", cli.DimStyle.Render("!"), err)
|
||||||
|
}
|
||||||
|
|
||||||
cli.Print("\n%s Downloading update...\n", cli.DimStyle.Render("→"))
|
cli.Print("\n%s Downloading update...\n", cli.DimStyle.Render("→"))
|
||||||
|
|
||||||
downloadURL, err := GetDownloadURL(release, "")
|
downloadURL, err := GetDownloadURL(release, "")
|
||||||
|
|
@ -155,8 +180,10 @@ func handleDevUpdate(currentVersion string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
cli.Print("%s Updated to %s\n", cli.SuccessStyle.Render(cli.Glyph(":check:")), release.TagName)
|
cli.Print("%s Updated to %s\n", cli.SuccessStyle.Render(cli.Glyph(":check:")), release.TagName)
|
||||||
|
cli.Print("%s Restarting...\n", cli.DimStyle.Render("→"))
|
||||||
|
|
||||||
return restartBinary()
|
os.Exit(0)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleDevTagUpdate fetches the dev release using the direct tag
|
// handleDevTagUpdate fetches the dev release using the direct tag
|
||||||
|
|
@ -178,6 +205,11 @@ func handleDevTagUpdate(currentVersion string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Spawn watcher before applying update
|
||||||
|
if err := spawnWatcher(); err != nil {
|
||||||
|
cli.Print("%s Could not spawn restart watcher: %v\n", cli.DimStyle.Render("!"), err)
|
||||||
|
}
|
||||||
|
|
||||||
cli.Print("\n%s Downloading from dev release...\n", cli.DimStyle.Render("→"))
|
cli.Print("\n%s Downloading from dev release...\n", cli.DimStyle.Render("→"))
|
||||||
|
|
||||||
if err := DoUpdate(downloadURL); err != nil {
|
if err := DoUpdate(downloadURL); err != nil {
|
||||||
|
|
@ -185,36 +217,85 @@ func handleDevTagUpdate(currentVersion string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
cli.Print("%s Updated to latest dev build\n", cli.SuccessStyle.Render(cli.Glyph(":check:")))
|
cli.Print("%s Updated to latest dev build\n", cli.SuccessStyle.Render(cli.Glyph(":check:")))
|
||||||
|
cli.Print("%s Restarting...\n", cli.DimStyle.Render("→"))
|
||||||
|
|
||||||
return restartBinary()
|
os.Exit(0)
|
||||||
}
|
|
||||||
|
|
||||||
// restartBinary re-executes the current binary to load the new version.
|
|
||||||
// On Unix systems, it uses syscall.Exec to replace the current process.
|
|
||||||
// On Windows, it prints a message to restart manually.
|
|
||||||
func restartBinary() error {
|
|
||||||
executable, err := os.Executable()
|
|
||||||
if err != nil {
|
|
||||||
cli.Print("Restart the CLI to use the new version.\n")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// On Windows, exec doesn't work the same way - just ask to restart
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
cli.Print("Restart the CLI to use the new version.\n")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cli.Print("\n%s Restarting...\n", cli.DimStyle.Render("→"))
|
|
||||||
|
|
||||||
// Re-exec with --version to confirm the update
|
|
||||||
err = syscall.Exec(executable, []string{executable, "--version"}, os.Environ())
|
|
||||||
if err != nil {
|
|
||||||
// If exec fails, just tell user to restart
|
|
||||||
cli.Print("Restart the CLI to use the new version.\n")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// This line is never reached if exec succeeds
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// spawnWatcher spawns a background process that watches for the current process
|
||||||
|
// to exit, then restarts the binary with --version to confirm the update.
|
||||||
|
func spawnWatcher() error {
|
||||||
|
executable, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pid := os.Getpid()
|
||||||
|
|
||||||
|
// Spawn: core update --watch-pid=<pid>
|
||||||
|
cmd := exec.Command(executable, "update", "--watch-pid", strconv.Itoa(pid))
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
// Detach from parent process group
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
|
Setpgid: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
return cmd.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
// watchAndRestart waits for the given PID to exit, then restarts the binary.
|
||||||
|
func watchAndRestart(pid int) error {
|
||||||
|
// Wait for the parent process to die
|
||||||
|
for {
|
||||||
|
if !isProcessRunning(pid) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Small delay to ensure file handle is released
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
// Get executable path
|
||||||
|
executable, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// On Unix, use exec to replace this process
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
return syscall.Exec(executable, []string{executable, "--version"}, os.Environ())
|
||||||
|
}
|
||||||
|
|
||||||
|
// On Windows, spawn new process and exit
|
||||||
|
cmd := exec.Command(executable, "--version")
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Exit(0)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isProcessRunning checks if a process with the given PID is still running.
|
||||||
|
func isProcessRunning(pid int) bool {
|
||||||
|
process, err := os.FindProcess(pid)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// On Unix, FindProcess always succeeds, so we need to send signal 0
|
||||||
|
// to check if the process actually exists
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
err = process.Signal(syscall.Signal(0))
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// On Windows, FindProcess returns an error if process doesn't exist
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue