diff --git a/pkg/updater/cmd.go b/pkg/updater/cmd.go index e72bf3bf..55eae0c5 100644 --- a/pkg/updater/cmd.go +++ b/pkg/updater/cmd.go @@ -3,8 +3,11 @@ package updater import ( "fmt" "os" + "os/exec" "runtime" + "strconv" "syscall" + "time" "github.com/host-uk/core/pkg/cli" "github.com/spf13/cobra" @@ -18,9 +21,10 @@ const ( // Command flags var ( - updateChannel string - updateForce bool - updateCheck bool + updateChannel string + updateForce bool + updateCheck bool + updateWatchPID int ) func init() { @@ -48,6 +52,8 @@ Examples: updateCmd.PersistentFlags().StringVar(&updateChannel, "channel", "stable", "Release channel: stable, beta, alpha, or dev") updateCmd.PersistentFlags().BoolVar(&updateForce, "force", false, "Force update even if already on latest version") updateCmd.Flags().BoolVar(&updateCheck, "check", false, "Only check for updates, don't apply") + updateCmd.Flags().IntVar(&updateWatchPID, "watch-pid", 0, "Internal: watch for parent PID to die then restart") + _ = updateCmd.Flags().MarkHidden("watch-pid") updateCmd.AddCommand(&cobra.Command{ Use: "check", @@ -62,6 +68,11 @@ Examples: } func runUpdate(cmd *cobra.Command, args []string) error { + // If we're in watch mode, wait for parent to die then restart + if updateWatchPID > 0 { + return watchAndRestart(updateWatchPID) + } + currentVersion := cli.AppVersion cli.Print("%s %s\n", cli.DimStyle.Render("Current version:"), cli.ValueStyle.Render(currentVersion)) @@ -104,6 +115,12 @@ func runUpdate(cmd *cobra.Command, args []string) error { return nil } + // Spawn watcher before applying update + if err := spawnWatcher(); err != nil { + // If watcher fails, continue anyway - update will still work + cli.Print("%s Could not spawn restart watcher: %v\n", cli.DimStyle.Render("!"), err) + } + // Apply update cli.Print("\n%s Downloading update...\n", cli.DimStyle.Render("→")) @@ -117,8 +134,11 @@ func runUpdate(cmd *cobra.Command, args []string) error { } cli.Print("%s Updated to %s\n", cli.SuccessStyle.Render(cli.Glyph(":check:")), release.TagName) + cli.Print("%s Restarting...\n", cli.DimStyle.Render("→")) - return restartBinary() + // Exit so the watcher can restart us + os.Exit(0) + return nil } // handleDevUpdate handles updates from the dev release (rolling prerelease) @@ -143,6 +163,11 @@ func handleDevUpdate(currentVersion string) error { return nil } + // Spawn watcher before applying update + if err := spawnWatcher(); err != nil { + cli.Print("%s Could not spawn restart watcher: %v\n", cli.DimStyle.Render("!"), err) + } + cli.Print("\n%s Downloading update...\n", cli.DimStyle.Render("→")) downloadURL, err := GetDownloadURL(release, "") @@ -155,8 +180,10 @@ func handleDevUpdate(currentVersion string) error { } cli.Print("%s Updated to %s\n", cli.SuccessStyle.Render(cli.Glyph(":check:")), release.TagName) + cli.Print("%s Restarting...\n", cli.DimStyle.Render("→")) - return restartBinary() + os.Exit(0) + return nil } // handleDevTagUpdate fetches the dev release using the direct tag @@ -178,6 +205,11 @@ func handleDevTagUpdate(currentVersion string) error { return nil } + // Spawn watcher before applying update + if err := spawnWatcher(); err != nil { + cli.Print("%s Could not spawn restart watcher: %v\n", cli.DimStyle.Render("!"), err) + } + cli.Print("\n%s Downloading from dev release...\n", cli.DimStyle.Render("→")) if err := DoUpdate(downloadURL); err != nil { @@ -185,36 +217,85 @@ func handleDevTagUpdate(currentVersion string) error { } cli.Print("%s Updated to latest dev build\n", cli.SuccessStyle.Render(cli.Glyph(":check:"))) + cli.Print("%s Restarting...\n", cli.DimStyle.Render("→")) - return restartBinary() -} - -// restartBinary re-executes the current binary to load the new version. -// On Unix systems, it uses syscall.Exec to replace the current process. -// On Windows, it prints a message to restart manually. -func restartBinary() error { - executable, err := os.Executable() - if err != nil { - cli.Print("Restart the CLI to use the new version.\n") - return nil - } - - // On Windows, exec doesn't work the same way - just ask to restart - if runtime.GOOS == "windows" { - cli.Print("Restart the CLI to use the new version.\n") - return nil - } - - cli.Print("\n%s Restarting...\n", cli.DimStyle.Render("→")) - - // Re-exec with --version to confirm the update - err = syscall.Exec(executable, []string{executable, "--version"}, os.Environ()) - if err != nil { - // If exec fails, just tell user to restart - cli.Print("Restart the CLI to use the new version.\n") - return nil - } - - // This line is never reached if exec succeeds + os.Exit(0) return nil } + +// spawnWatcher spawns a background process that watches for the current process +// to exit, then restarts the binary with --version to confirm the update. +func spawnWatcher() error { + executable, err := os.Executable() + if err != nil { + return err + } + + pid := os.Getpid() + + // Spawn: core update --watch-pid= + cmd := exec.Command(executable, "update", "--watch-pid", strconv.Itoa(pid)) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + // Detach from parent process group + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + } + + return cmd.Start() +} + +// watchAndRestart waits for the given PID to exit, then restarts the binary. +func watchAndRestart(pid int) error { + // Wait for the parent process to die + for { + if !isProcessRunning(pid) { + break + } + time.Sleep(100 * time.Millisecond) + } + + // Small delay to ensure file handle is released + time.Sleep(200 * time.Millisecond) + + // Get executable path + executable, err := os.Executable() + if err != nil { + return err + } + + // On Unix, use exec to replace this process + if runtime.GOOS != "windows" { + return syscall.Exec(executable, []string{executable, "--version"}, os.Environ()) + } + + // On Windows, spawn new process and exit + cmd := exec.Command(executable, "--version") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Start(); err != nil { + return err + } + + os.Exit(0) + return nil +} + +// isProcessRunning checks if a process with the given PID is still running. +func isProcessRunning(pid int) bool { + process, err := os.FindProcess(pid) + if err != nil { + return false + } + + // On Unix, FindProcess always succeeds, so we need to send signal 0 + // to check if the process actually exists + if runtime.GOOS != "windows" { + err = process.Signal(syscall.Signal(0)) + return err == nil + } + + // On Windows, FindProcess returns an error if process doesn't exist + return true +}