Compare commits
3 commits
ax/review-
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0613a3aae8 | ||
|
|
bbd06db2df | ||
|
|
00f30708ac |
15 changed files with 354 additions and 638 deletions
16
cmd.go
16
cmd.go
|
|
@ -2,8 +2,8 @@ package updater
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"github.com/spf13/cobra"
|
||||
|
|
@ -27,9 +27,7 @@ func init() {
|
|||
cli.RegisterCommands(AddUpdateCommands)
|
||||
}
|
||||
|
||||
// AddUpdateCommands registers the update command and its subcommands with the root command.
|
||||
//
|
||||
// cli.RegisterCommands(updater.AddUpdateCommands)
|
||||
// AddUpdateCommands registers the update command and subcommands.
|
||||
func AddUpdateCommands(root *cobra.Command) {
|
||||
updateCmd := &cobra.Command{
|
||||
Use: "update",
|
||||
|
|
@ -47,7 +45,7 @@ Examples:
|
|||
RunE: runUpdate,
|
||||
}
|
||||
|
||||
updateCmd.PersistentFlags().StringVar(&updateChannel, "channel", "stable", "Release channel: stable, beta, alpha, or dev")
|
||||
updateCmd.PersistentFlags().StringVar(&updateChannel, "channel", "stable", "Release channel: stable, beta, alpha, prerelease, or dev")
|
||||
updateCmd.PersistentFlags().BoolVar(&updateForce, "force", false, "Force update even if already on latest version")
|
||||
updateCmd.Flags().BoolVar(&updateCheck, "check", false, "Only check for updates, do not apply")
|
||||
updateCmd.Flags().IntVar(&updateWatchPID, "watch-pid", 0, "Internal: watch for parent PID to die then restart")
|
||||
|
|
@ -74,7 +72,7 @@ func runUpdate(cmd *cobra.Command, args []string) error {
|
|||
}
|
||||
|
||||
currentVersion := cli.AppVersion
|
||||
normalizedChannel := strings.TrimSpace(strings.ToLower(updateChannel))
|
||||
normalizedChannel := normaliseGitHubChannel(updateChannel)
|
||||
|
||||
cli.Print("%s %s\n", cli.DimStyle.Render("Current version:"), cli.ValueStyle.Render(currentVersion))
|
||||
cli.Print("%s %s/%s\n", cli.DimStyle.Render("Platform:"), runtime.GOOS, runtime.GOARCH)
|
||||
|
|
@ -187,8 +185,10 @@ func handleDevUpdate(currentVersion string) error {
|
|||
// handleDevTagUpdate fetches the dev release using the direct tag
|
||||
func handleDevTagUpdate(currentVersion string) error {
|
||||
// Construct download URL directly for dev release
|
||||
downloadURL := "https://github.com/" + repoOwner + "/" + repoName +
|
||||
"/releases/download/dev/core-" + runtime.GOOS + "-" + runtime.GOARCH
|
||||
downloadURL := fmt.Sprintf(
|
||||
"https://github.com/%s/%s/releases/download/dev/core-%s-%s",
|
||||
repoOwner, repoName, runtime.GOOS, runtime.GOARCH,
|
||||
)
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
downloadURL += ".exe"
|
||||
|
|
|
|||
13
cmd_unix.go
13
cmd_unix.go
|
|
@ -10,9 +10,8 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// spawnWatcher spawns a detached background process that restarts the binary after the current process exits.
|
||||
//
|
||||
// if err := spawnWatcher(); err != nil { /* non-fatal: update proceeds anyway */ }
|
||||
// spawnWatcher spawns a background process that watches for the current process
|
||||
// to exit, then restarts the binary with --version to confirm the update.
|
||||
func spawnWatcher() error {
|
||||
executable, err := os.Executable()
|
||||
if err != nil {
|
||||
|
|
@ -34,9 +33,7 @@ func spawnWatcher() error {
|
|||
return cmd.Start()
|
||||
}
|
||||
|
||||
// watchAndRestart polls until the given PID exits, then exec-replaces itself with --version.
|
||||
//
|
||||
// return watchAndRestart(os.Getppid())
|
||||
// watchAndRestart waits for the given PID to exit, then restarts the binary.
|
||||
func watchAndRestart(pid int) error {
|
||||
// Wait for the parent process to die
|
||||
for isProcessRunning(pid) {
|
||||
|
|
@ -57,9 +54,7 @@ func watchAndRestart(pid int) error {
|
|||
return syscall.Exec(executable, []string{executable, "--version"}, os.Environ())
|
||||
}
|
||||
|
||||
// isProcessRunning returns true if a process with the given PID is still running.
|
||||
//
|
||||
// for isProcessRunning(parentPID) { time.Sleep(100 * time.Millisecond) }
|
||||
// isProcessRunning checks if a process with the given PID is still running.
|
||||
func isProcessRunning(pid int) bool {
|
||||
process, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -10,9 +10,8 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// spawnWatcher spawns a detached background process that restarts the binary after the current process exits.
|
||||
//
|
||||
// if err := spawnWatcher(); err != nil { /* non-fatal: update proceeds anyway */ }
|
||||
// spawnWatcher spawns a background process that watches for the current process
|
||||
// to exit, then restarts the binary with --version to confirm the update.
|
||||
func spawnWatcher() error {
|
||||
executable, err := os.Executable()
|
||||
if err != nil {
|
||||
|
|
@ -34,9 +33,7 @@ func spawnWatcher() error {
|
|||
return cmd.Start()
|
||||
}
|
||||
|
||||
// watchAndRestart polls until the given PID exits, then spawns --version and exits.
|
||||
//
|
||||
// return watchAndRestart(os.Getppid())
|
||||
// watchAndRestart waits for the given PID to exit, then restarts the binary.
|
||||
func watchAndRestart(pid int) error {
|
||||
// Wait for the parent process to die
|
||||
for {
|
||||
|
|
@ -67,9 +64,7 @@ func watchAndRestart(pid int) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// isProcessRunning returns true if a process with the given PID is still running.
|
||||
//
|
||||
// for isProcessRunning(parentPID) { time.Sleep(100 * time.Millisecond) }
|
||||
// isProcessRunning checks if a process with the given PID is still running.
|
||||
func isProcessRunning(pid int) bool {
|
||||
// On Windows, try to open the process with query rights
|
||||
handle, err := syscall.OpenProcess(syscall.PROCESS_QUERY_INFORMATION, false, uint32(pid))
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
|
|
@ -19,9 +19,15 @@ type GenericUpdateInfo struct {
|
|||
}
|
||||
|
||||
// GetLatestUpdateFromURL fetches and parses a latest.json file from a base URL.
|
||||
// The server at the baseURL should host a 'latest.json' file that contains
|
||||
// the version and download URL for the latest update.
|
||||
//
|
||||
// info, err := updater.GetLatestUpdateFromURL("https://my-server.com/updates")
|
||||
// // info.Version == "1.2.3", info.URL == "https://..."
|
||||
// Example of latest.json:
|
||||
//
|
||||
// {
|
||||
// "version": "1.2.3",
|
||||
// "url": "https://your-server.com/path/to/release-asset"
|
||||
// }
|
||||
func GetLatestUpdateFromURL(baseURL string) (*GenericUpdateInfo, error) {
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
|
|
@ -43,7 +49,7 @@ func GetLatestUpdateFromURL(baseURL string) (*GenericUpdateInfo, error) {
|
|||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, coreerr.E("GetLatestUpdateFromURL", "failed to fetch latest.json: status "+strconv.Itoa(resp.StatusCode), nil)
|
||||
return nil, coreerr.E("GetLatestUpdateFromURL", fmt.Sprintf("failed to fetch latest.json: status code %d", resp.StatusCode), nil)
|
||||
}
|
||||
|
||||
var info GenericUpdateInfo
|
||||
|
|
|
|||
|
|
@ -7,52 +7,49 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func TestGenericHttp_GetLatestUpdateFromURL_Good(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = fmt.Fprintln(w, `{"version": "v1.1.0", "url": "http://example.com/release.zip"}`)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
info, err := GetLatestUpdateFromURL(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if info.Version != "v1.1.0" {
|
||||
t.Errorf("expected version v1.1.0, got %q", info.Version)
|
||||
}
|
||||
if info.URL != "http://example.com/release.zip" {
|
||||
t.Errorf("expected URL http://example.com/release.zip, got %q", info.URL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericHttp_GetLatestUpdateFromURL_Bad(t *testing.T) {
|
||||
func TestGetLatestUpdateFromURL(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
expectError bool
|
||||
expectedVersion string
|
||||
expectedURL string
|
||||
}{
|
||||
{
|
||||
name: "invalid JSON",
|
||||
name: "Valid latest.json",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = fmt.Fprintln(w, `{"version": "v1.1.0", "url": "http://example.com/release.zip"`) // missing closing brace
|
||||
_, _ = fmt.Fprintln(w, `{"version": "v1.1.0", "url": "http://example.com/release.zip"}`)
|
||||
},
|
||||
expectedVersion: "v1.1.0",
|
||||
expectedURL: "http://example.com/release.zip",
|
||||
},
|
||||
{
|
||||
name: "missing version field",
|
||||
name: "Invalid JSON",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = fmt.Fprintln(w, `{"version": "v1.1.0", "url": "http://example.com/release.zip"`) // Missing closing brace
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Missing version",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = fmt.Fprintln(w, `{"url": "http://example.com/release.zip"}`)
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "missing url field",
|
||||
name: "Missing URL",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = fmt.Fprintln(w, `{"version": "v1.1.0"}`)
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "server error",
|
||||
name: "Server error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -61,29 +58,20 @@ func TestGenericHttp_GetLatestUpdateFromURL_Bad(t *testing.T) {
|
|||
server := httptest.NewServer(tc.handler)
|
||||
defer server.Close()
|
||||
|
||||
_, err := GetLatestUpdateFromURL(server.URL)
|
||||
if err == nil {
|
||||
t.Errorf("expected error for case %q, got nil", tc.name)
|
||||
info, err := GetLatestUpdateFromURL(server.URL)
|
||||
|
||||
if (err != nil) != tc.expectError {
|
||||
t.Errorf("Expected error: %v, got: %v", tc.expectError, err)
|
||||
}
|
||||
|
||||
if !tc.expectError {
|
||||
if info.Version != tc.expectedVersion {
|
||||
t.Errorf("Expected version: %s, got: %s", tc.expectedVersion, info.Version)
|
||||
}
|
||||
if info.URL != tc.expectedURL {
|
||||
t.Errorf("Expected URL: %s, got: %s", tc.expectedURL, info.URL)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericHttp_GetLatestUpdateFromURL_Ugly(t *testing.T) {
|
||||
// Invalid base URL
|
||||
_, err := GetLatestUpdateFromURL("://invalid-url")
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid URL, got nil")
|
||||
}
|
||||
|
||||
// Server returns empty body
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err = GetLatestUpdateFromURL(server.URL)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty response body, got nil")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
45
github.go
45
github.go
|
|
@ -3,10 +3,10 @@ package updater
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
|
|
@ -44,13 +44,13 @@ type GithubClient interface {
|
|||
|
||||
type githubClient struct{}
|
||||
|
||||
// NewAuthenticatedClient returns an HTTP client using GITHUB_TOKEN if set, or the default client.
|
||||
//
|
||||
// client := updater.NewAuthenticatedClient(ctx)
|
||||
// NewAuthenticatedClient creates a new HTTP client that authenticates with the GitHub API.
|
||||
// It uses the GITHUB_TOKEN environment variable for authentication.
|
||||
// If the token is not set, it returns the default HTTP client.
|
||||
var NewAuthenticatedClient = func(ctx context.Context) *http.Client {
|
||||
token := os.Getenv("GITHUB_TOKEN")
|
||||
if token == "" {
|
||||
return NewHTTPClient()
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
ts := oauth2.StaticTokenSource(
|
||||
|
|
@ -68,7 +68,7 @@ func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]
|
|||
func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]string, error) {
|
||||
client := NewAuthenticatedClient(ctx)
|
||||
var allCloneURLs []string
|
||||
url := apiURL + "/users/" + userOrOrg + "/repos"
|
||||
url := fmt.Sprintf("%s/users/%s/repos", apiURL, userOrOrg)
|
||||
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
|
|
@ -85,8 +85,8 @@ func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, use
|
|||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_ = resp.Body.Close()
|
||||
// Try organisation endpoint
|
||||
url = apiURL + "/orgs/" + userOrOrg + "/repos"
|
||||
// Try organization endpoint
|
||||
url = fmt.Sprintf("%s/orgs/%s/repos", apiURL, userOrOrg)
|
||||
req, err = newAgentRequest(ctx, "GET", url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -99,7 +99,7 @@ func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, use
|
|||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_ = resp.Body.Close()
|
||||
return nil, coreerr.E("github.getPublicReposWithAPIURL", "failed to fetch repos: "+resp.Status, nil)
|
||||
return nil, coreerr.E("github.getPublicReposWithAPIURL", fmt.Sprintf("failed to fetch repos: %s", resp.Status), nil)
|
||||
}
|
||||
|
||||
var repos []Repo
|
||||
|
|
@ -138,14 +138,13 @@ func (g *githubClient) findNextURL(linkHeader string) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
// GetLatestRelease fetches the first release matching the given channel ("stable", "beta", or "alpha").
|
||||
//
|
||||
// release, err := client.GetLatestRelease(ctx, "myorg", "myrepo", "stable")
|
||||
// GetLatestRelease fetches the latest release for a given repository and channel.
|
||||
// The channel can be "stable", "beta", or "alpha".
|
||||
func (g *githubClient) GetLatestRelease(ctx context.Context, owner, repo, channel string) (*Release, error) {
|
||||
client := NewAuthenticatedClient(ctx)
|
||||
releaseURL := "https://api.github.com/repos/" + owner + "/" + repo + "/releases"
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases", owner, repo)
|
||||
|
||||
req, err := newAgentRequest(ctx, "GET", releaseURL)
|
||||
req, err := newAgentRequest(ctx, "GET", url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -157,7 +156,7 @@ func (g *githubClient) GetLatestRelease(ctx context.Context, owner, repo, channe
|
|||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, coreerr.E("github.GetLatestRelease", "failed to fetch releases: "+resp.Status, nil)
|
||||
return nil, coreerr.E("github.GetLatestRelease", fmt.Sprintf("failed to fetch releases: %s", resp.Status), nil)
|
||||
}
|
||||
|
||||
var releases []Release
|
||||
|
|
@ -194,14 +193,12 @@ func determineChannel(tagName string, isPreRelease bool) string {
|
|||
return "stable"
|
||||
}
|
||||
|
||||
// GetReleaseByPullRequest fetches the release whose tag contains .pr.<prNumber>.
|
||||
//
|
||||
// release, err := client.GetReleaseByPullRequest(ctx, "myorg", "myrepo", 42)
|
||||
// GetReleaseByPullRequest fetches a release associated with a specific pull request number.
|
||||
func (g *githubClient) GetReleaseByPullRequest(ctx context.Context, owner, repo string, prNumber int) (*Release, error) {
|
||||
client := NewAuthenticatedClient(ctx)
|
||||
releaseURL := "https://api.github.com/repos/" + owner + "/" + repo + "/releases"
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases", owner, repo)
|
||||
|
||||
req, err := newAgentRequest(ctx, "GET", releaseURL)
|
||||
req, err := newAgentRequest(ctx, "GET", url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -213,7 +210,7 @@ func (g *githubClient) GetReleaseByPullRequest(ctx context.Context, owner, repo
|
|||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, coreerr.E("github.GetReleaseByPullRequest", "failed to fetch releases: "+resp.Status, nil)
|
||||
return nil, coreerr.E("github.GetReleaseByPullRequest", fmt.Sprintf("failed to fetch releases: %s", resp.Status), nil)
|
||||
}
|
||||
|
||||
var releases []Release
|
||||
|
|
@ -221,8 +218,8 @@ func (g *githubClient) GetReleaseByPullRequest(ctx context.Context, owner, repo
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// The PR number is included in the tag name: vX.Y.Z-alpha.pr.123 or vX.Y.Z-beta.pr.123
|
||||
prTagSuffix := ".pr." + strconv.Itoa(prNumber)
|
||||
// The pr number is included in the tag name with the format `vX.Y.Z-alpha.pr.123` or `vX.Y.Z-beta.pr.123`
|
||||
prTagSuffix := fmt.Sprintf(".pr.%d", prNumber)
|
||||
for _, release := range releases {
|
||||
if strings.Contains(release.TagName, prTagSuffix) {
|
||||
return &release, nil
|
||||
|
|
@ -301,5 +298,5 @@ func GetDownloadURL(release *Release, releaseURLFormat string) (string, error) {
|
|||
}
|
||||
}
|
||||
|
||||
return "", coreerr.E("GetDownloadURL", "no suitable download asset found for "+osName+"/"+archName, nil)
|
||||
return "", coreerr.E("GetDownloadURL", fmt.Sprintf("no suitable download asset found for %s/%s", osName, archName), nil)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func TestGithubInternal_FilterReleases_Good(t *testing.T) {
|
||||
func TestFilterReleases_Good(t *testing.T) {
|
||||
releases := []Release{
|
||||
{TagName: "v1.0.0-alpha.1", PreRelease: true},
|
||||
{TagName: "v1.0.0-beta.1", PreRelease: true},
|
||||
|
|
@ -34,7 +34,7 @@ func TestGithubInternal_FilterReleases_Good(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_FilterReleases_Bad(t *testing.T) {
|
||||
func TestFilterReleases_Bad(t *testing.T) {
|
||||
releases := []Release{
|
||||
{TagName: "v1.0.0", PreRelease: false},
|
||||
}
|
||||
|
|
@ -44,16 +44,11 @@ func TestGithubInternal_FilterReleases_Bad(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_FilterReleases_Ugly(t *testing.T) {
|
||||
// Empty releases slice
|
||||
got := filterReleases([]Release{}, "stable")
|
||||
if got != nil {
|
||||
t.Errorf("expected nil for empty releases, got %v", got)
|
||||
func TestFilterReleases_PreReleaseWithoutLabel(t *testing.T) {
|
||||
releases := []Release{
|
||||
{TagName: "v2.0.0-rc.1", PreRelease: true},
|
||||
}
|
||||
|
||||
// Pre-release without alpha/beta label maps to beta channel
|
||||
releases := []Release{{TagName: "v2.0.0-rc.1", PreRelease: true}}
|
||||
got = filterReleases(releases, "beta")
|
||||
got := filterReleases(releases, "beta")
|
||||
if got == nil {
|
||||
t.Fatal("expected pre-release without alpha/beta label to match beta channel")
|
||||
}
|
||||
|
|
@ -62,7 +57,7 @@ func TestGithubInternal_FilterReleases_Ugly(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_DetermineChannel_Good(t *testing.T) {
|
||||
func TestDetermineChannel_Good(t *testing.T) {
|
||||
tests := []struct {
|
||||
tag string
|
||||
isPreRelease bool
|
||||
|
|
@ -74,6 +69,7 @@ func TestGithubInternal_DetermineChannel_Good(t *testing.T) {
|
|||
{"v1.0.0-beta.1", false, "beta"},
|
||||
{"v1.0.0-BETA.1", false, "beta"},
|
||||
{"v1.0.0-rc.1", true, "beta"},
|
||||
{"v1.0.0-rc.1", false, "stable"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
|
@ -86,28 +82,55 @@ func TestGithubInternal_DetermineChannel_Good(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_DetermineChannel_Bad(t *testing.T) {
|
||||
// rc tag without prerelease flag resolves to stable (not beta)
|
||||
got := determineChannel("v1.0.0-rc.1", false)
|
||||
if got != "stable" {
|
||||
t.Errorf("expected stable for rc tag without prerelease flag, got %q", got)
|
||||
func TestCheckForUpdatesByTag_UsesCurrentVersionChannel(t *testing.T) {
|
||||
originalVersion := Version
|
||||
originalCheckForUpdates := CheckForUpdates
|
||||
defer func() {
|
||||
Version = originalVersion
|
||||
CheckForUpdates = originalCheckForUpdates
|
||||
}()
|
||||
|
||||
var gotChannel string
|
||||
CheckForUpdates = func(owner, repo, channel string, forceSemVerPrefix bool, releaseURLFormat string) error {
|
||||
gotChannel = channel
|
||||
return nil
|
||||
}
|
||||
|
||||
Version = "v2.0.0-rc.1"
|
||||
if err := CheckForUpdatesByTag("owner", "repo"); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if gotChannel != "beta" {
|
||||
t.Fatalf("expected beta channel, got %q", gotChannel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_DetermineChannel_Ugly(t *testing.T) {
|
||||
// Empty tag with prerelease=true — maps to beta
|
||||
got := determineChannel("", true)
|
||||
if got != "beta" {
|
||||
t.Errorf("expected beta for empty tag with prerelease=true, got %q", got)
|
||||
func TestCheckOnlyByTag_UsesCurrentVersionChannel(t *testing.T) {
|
||||
originalVersion := Version
|
||||
originalCheckOnly := CheckOnly
|
||||
defer func() {
|
||||
Version = originalVersion
|
||||
CheckOnly = originalCheckOnly
|
||||
}()
|
||||
|
||||
var gotChannel string
|
||||
CheckOnly = func(owner, repo, channel string, forceSemVerPrefix bool, releaseURLFormat string) error {
|
||||
gotChannel = channel
|
||||
return nil
|
||||
}
|
||||
// Empty tag with prerelease=false — maps to stable
|
||||
got = determineChannel("", false)
|
||||
if got != "stable" {
|
||||
t.Errorf("expected stable for empty tag with prerelease=false, got %q", got)
|
||||
|
||||
Version = "v2.0.0-alpha.1"
|
||||
if err := CheckOnlyByTag("owner", "repo"); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if gotChannel != "alpha" {
|
||||
t.Fatalf("expected alpha channel, got %q", gotChannel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_GetDownloadURL_Good(t *testing.T) {
|
||||
func TestGetDownloadURL_Good(t *testing.T) {
|
||||
osName := runtime.GOOS
|
||||
archName := runtime.GOARCH
|
||||
|
||||
|
|
@ -119,16 +142,50 @@ func TestGithubInternal_GetDownloadURL_Good(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
downloadURL, err := GetDownloadURL(release, "")
|
||||
url, err := GetDownloadURL(release, "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if downloadURL != "https://example.com/full-match" {
|
||||
t.Errorf("expected full match URL, got %q", downloadURL)
|
||||
if url != "https://example.com/full-match" {
|
||||
t.Errorf("expected full match URL, got %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_GetDownloadURL_Bad(t *testing.T) {
|
||||
func TestGetDownloadURL_OSOnlyFallback(t *testing.T) {
|
||||
osName := runtime.GOOS
|
||||
|
||||
release := &Release{
|
||||
TagName: "v1.2.3",
|
||||
Assets: []ReleaseAsset{
|
||||
{Name: "app-other-other", DownloadURL: "https://example.com/other"},
|
||||
{Name: "app-" + osName + "-other", DownloadURL: "https://example.com/os-only"},
|
||||
},
|
||||
}
|
||||
|
||||
url, err := GetDownloadURL(release, "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if url != "https://example.com/os-only" {
|
||||
t.Errorf("expected OS-only fallback URL, got %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDownloadURL_WithFormat(t *testing.T) {
|
||||
release := &Release{TagName: "v1.2.3"}
|
||||
|
||||
url, err := GetDownloadURL(release, "https://example.com/{tag}/{os}/{arch}")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expected := "https://example.com/v1.2.3/" + runtime.GOOS + "/" + runtime.GOARCH
|
||||
if url != expected {
|
||||
t.Errorf("expected %q, got %q", expected, url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDownloadURL_Bad(t *testing.T) {
|
||||
// nil release
|
||||
_, err := GetDownloadURL(nil, "")
|
||||
if err == nil {
|
||||
|
|
@ -148,44 +205,14 @@ func TestGithubInternal_GetDownloadURL_Bad(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_GetDownloadURL_Ugly(t *testing.T) {
|
||||
osName := runtime.GOOS
|
||||
|
||||
// OS-only fallback when arch does not match
|
||||
release := &Release{
|
||||
TagName: "v1.2.3",
|
||||
Assets: []ReleaseAsset{
|
||||
{Name: "app-other-other", DownloadURL: "https://example.com/other"},
|
||||
{Name: "app-" + osName + "-other", DownloadURL: "https://example.com/os-only"},
|
||||
},
|
||||
}
|
||||
downloadURL, err := GetDownloadURL(release, "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for OS-only fallback: %v", err)
|
||||
}
|
||||
if downloadURL != "https://example.com/os-only" {
|
||||
t.Errorf("expected OS-only fallback URL, got %q", downloadURL)
|
||||
}
|
||||
|
||||
// URL format template with placeholders
|
||||
release = &Release{TagName: "v1.2.3"}
|
||||
downloadURL, err = GetDownloadURL(release, "https://example.com/{tag}/{os}/{arch}")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for URL format: %v", err)
|
||||
}
|
||||
expected := "https://example.com/v1.2.3/" + runtime.GOOS + "/" + runtime.GOARCH
|
||||
if downloadURL != expected {
|
||||
t.Errorf("expected %q, got %q", expected, downloadURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_FormatVersionForComparison_Good(t *testing.T) {
|
||||
func TestFormatVersionForComparison(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"1.0.0", "v1.0.0"},
|
||||
{"v1.0.0", "v1.0.0"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
|
@ -198,22 +225,7 @@ func TestGithubInternal_FormatVersionForComparison_Good(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_FormatVersionForComparison_Bad(t *testing.T) {
|
||||
// Non-semver string still gets v prefix
|
||||
got := formatVersionForComparison("not-a-version")
|
||||
if got != "vnot-a-version" {
|
||||
t.Errorf("expected vnot-a-version, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_FormatVersionForComparison_Ugly(t *testing.T) {
|
||||
// Empty string returns empty (no v prefix added)
|
||||
if got := formatVersionForComparison(""); got != "" {
|
||||
t.Errorf("expected empty string, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_FormatVersionForDisplay_Good(t *testing.T) {
|
||||
func TestFormatVersionForDisplay(t *testing.T) {
|
||||
tests := []struct {
|
||||
version string
|
||||
force bool
|
||||
|
|
@ -235,23 +247,14 @@ func TestGithubInternal_FormatVersionForDisplay_Good(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_FormatVersionForDisplay_Bad(t *testing.T) {
|
||||
// force=true on already-prefixed version should not double-prefix
|
||||
got := formatVersionForDisplay("vv1.0.0", true)
|
||||
if got != "vv1.0.0" {
|
||||
t.Errorf("expected vv1.0.0 unchanged, got %q", got)
|
||||
func boolStr(b bool) string {
|
||||
if b {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
|
||||
func TestGithubInternal_FormatVersionForDisplay_Ugly(t *testing.T) {
|
||||
// Empty version with force=true — results in "v"
|
||||
got := formatVersionForDisplay("", true)
|
||||
if got != "v" {
|
||||
t.Errorf("expected \"v\" for empty version with force=true, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_StartGitHubCheck_Ugly(t *testing.T) {
|
||||
func TestStartGitHubCheck_UnknownMode(t *testing.T) {
|
||||
s := &UpdateService{
|
||||
config: UpdateServiceConfig{
|
||||
CheckOnStartup: StartupCheckMode(99),
|
||||
|
|
@ -260,12 +263,13 @@ func TestGithubInternal_StartGitHubCheck_Ugly(t *testing.T) {
|
|||
owner: "owner",
|
||||
repo: "repo",
|
||||
}
|
||||
if err := s.Start(); err == nil {
|
||||
err := s.Start()
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown startup check mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithubInternal_StartHTTPCheck_Ugly(t *testing.T) {
|
||||
func TestStartHTTPCheck_UnknownMode(t *testing.T) {
|
||||
s := &UpdateService{
|
||||
config: UpdateServiceConfig{
|
||||
RepoURL: "https://example.com/updates",
|
||||
|
|
@ -273,14 +277,8 @@ func TestGithubInternal_StartHTTPCheck_Ugly(t *testing.T) {
|
|||
},
|
||||
isGitHub: false,
|
||||
}
|
||||
if err := s.Start(); err == nil {
|
||||
err := s.Start()
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown startup check mode")
|
||||
}
|
||||
}
|
||||
|
||||
func boolStr(b bool) string {
|
||||
if b {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import (
|
|||
"github.com/Snider/Borg/pkg/mocks"
|
||||
)
|
||||
|
||||
func TestGithub_GetPublicRepos_Good(t *testing.T) {
|
||||
func TestGetPublicRepos(t *testing.T) {
|
||||
mockClient := mocks.NewMockClient(map[string]*http.Response{
|
||||
"https://api.github.com/users/testuser/repos": {
|
||||
StatusCode: http.StatusOK,
|
||||
|
|
@ -31,11 +31,13 @@ func TestGithub_GetPublicRepos_Good(t *testing.T) {
|
|||
})
|
||||
|
||||
client := &githubClient{}
|
||||
originalNewAuthenticatedClient := NewAuthenticatedClient
|
||||
oldClient := NewAuthenticatedClient
|
||||
NewAuthenticatedClient = func(ctx context.Context) *http.Client {
|
||||
return mockClient
|
||||
}
|
||||
defer func() { NewAuthenticatedClient = originalNewAuthenticatedClient }()
|
||||
defer func() {
|
||||
NewAuthenticatedClient = oldClient
|
||||
}()
|
||||
|
||||
// Test user repos
|
||||
repos, err := client.getPublicReposWithAPIURL(context.Background(), "https://api.github.com", "testuser")
|
||||
|
|
@ -55,8 +57,7 @@ func TestGithub_GetPublicRepos_Good(t *testing.T) {
|
|||
t.Errorf("unexpected org repos: %v", repos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithub_GetPublicRepos_Bad(t *testing.T) {
|
||||
func TestGetPublicRepos_Error(t *testing.T) {
|
||||
u, _ := url.Parse("https://api.github.com/users/testuser/repos")
|
||||
mockClient := mocks.NewMockClient(map[string]*http.Response{
|
||||
"https://api.github.com/users/testuser/repos": {
|
||||
|
|
@ -77,89 +78,47 @@ func TestGithub_GetPublicRepos_Bad(t *testing.T) {
|
|||
expectedErr := "github.getPublicReposWithAPIURL: failed to fetch repos: 404 Not Found"
|
||||
|
||||
client := &githubClient{}
|
||||
originalNewAuthenticatedClient := NewAuthenticatedClient
|
||||
oldClient := NewAuthenticatedClient
|
||||
NewAuthenticatedClient = func(ctx context.Context) *http.Client {
|
||||
return mockClient
|
||||
}
|
||||
defer func() { NewAuthenticatedClient = originalNewAuthenticatedClient }()
|
||||
defer func() {
|
||||
NewAuthenticatedClient = oldClient
|
||||
}()
|
||||
|
||||
// Test user repos
|
||||
_, err := client.getPublicReposWithAPIURL(context.Background(), "https://api.github.com", "testuser")
|
||||
if err == nil || err.Error() != expectedErr {
|
||||
t.Fatalf("expected %q, got %q", expectedErr, err)
|
||||
if err.Error() != expectedErr {
|
||||
t.Fatalf("getPublicReposWithAPIURL for user failed: expected %q, got %q", expectedErr, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithub_GetPublicRepos_Ugly(t *testing.T) {
|
||||
// Context already cancelled — loop should terminate immediately
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
mockClient := mocks.NewMockClient(map[string]*http.Response{})
|
||||
client := &githubClient{}
|
||||
originalNewAuthenticatedClient := NewAuthenticatedClient
|
||||
NewAuthenticatedClient = func(ctx context.Context) *http.Client { return mockClient }
|
||||
defer func() { NewAuthenticatedClient = originalNewAuthenticatedClient }()
|
||||
|
||||
_, err := client.getPublicReposWithAPIURL(ctx, "https://api.github.com", "testuser")
|
||||
if err == nil {
|
||||
t.Error("expected error from cancelled context, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithub_FindNextURL_Good(t *testing.T) {
|
||||
func TestFindNextURL(t *testing.T) {
|
||||
client := &githubClient{}
|
||||
linkHeader := `<https://api.github.com/organizations/123/repos?page=2>; rel="next", <https://api.github.com/organizations/123/repos?page=1>; rel="prev"`
|
||||
nextURL := client.findNextURL(linkHeader)
|
||||
if nextURL != "https://api.github.com/organizations/123/repos?page=2" {
|
||||
t.Errorf("unexpected next URL: %s", nextURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithub_FindNextURL_Bad(t *testing.T) {
|
||||
client := &githubClient{}
|
||||
linkHeader := `<https://api.github.com/organizations/123/repos?page=1>; rel="prev"`
|
||||
nextURL := client.findNextURL(linkHeader)
|
||||
linkHeader = `<https://api.github.com/organizations/123/repos?page=1>; rel="prev"`
|
||||
nextURL = client.findNextURL(linkHeader)
|
||||
if nextURL != "" {
|
||||
t.Errorf("expected empty string for no next link, got %q", nextURL)
|
||||
t.Errorf("unexpected next URL: %s", nextURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithub_FindNextURL_Ugly(t *testing.T) {
|
||||
client := &githubClient{}
|
||||
// Empty header
|
||||
if got := client.findNextURL(""); got != "" {
|
||||
t.Errorf("expected empty string for empty header, got %q", got)
|
||||
func TestNewAuthenticatedClient(t *testing.T) {
|
||||
// Test with no token
|
||||
client := NewAuthenticatedClient(context.Background())
|
||||
if client != http.DefaultClient {
|
||||
t.Errorf("expected http.DefaultClient, but got something else")
|
||||
}
|
||||
// Malformed header — no rel attribute
|
||||
if got := client.findNextURL(`<https://example.com/page=2>`); got != "" {
|
||||
t.Errorf("expected empty string for malformed header, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithub_NewAuthenticatedClient_Good(t *testing.T) {
|
||||
// With token set — should return an authenticated (non-default) client
|
||||
// Test with token
|
||||
t.Setenv("GITHUB_TOKEN", "test-token")
|
||||
client := NewAuthenticatedClient(context.Background())
|
||||
client = NewAuthenticatedClient(context.Background())
|
||||
if client == http.DefaultClient {
|
||||
t.Error("expected authenticated client, got http.DefaultClient")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithub_NewAuthenticatedClient_Bad(t *testing.T) {
|
||||
// Without token — should return a plain (unauthenticated) HTTP client, not http.DefaultClient
|
||||
t.Setenv("GITHUB_TOKEN", "")
|
||||
client := NewAuthenticatedClient(context.Background())
|
||||
if client == nil {
|
||||
t.Error("expected non-nil client when no token set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithub_NewAuthenticatedClient_Ugly(t *testing.T) {
|
||||
// Cancelled context still returns a client (client creation does not fail on cancelled context)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
client := NewAuthenticatedClient(ctx)
|
||||
if client == nil {
|
||||
t.Error("expected non-nil client even with cancelled context")
|
||||
t.Errorf("expected an authenticated client, but got http.DefaultClient")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
14
go.mod
14
go.mod
|
|
@ -1,11 +1,11 @@
|
|||
module forge.lthn.ai/core/go-update
|
||||
module dappco.re/go/core/update
|
||||
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
forge.lthn.ai/core/cli v0.3.6
|
||||
forge.lthn.ai/core/go-io v0.1.5
|
||||
forge.lthn.ai/core/go-log v0.0.4
|
||||
dappco.re/go/core/cli v0.3.6
|
||||
dappco.re/go/core/io v0.1.5
|
||||
dappco.re/go/core/log v0.0.4
|
||||
github.com/Snider/Borg v0.2.0
|
||||
github.com/minio/selfupdate v0.6.0
|
||||
github.com/spf13/cobra v1.10.2
|
||||
|
|
@ -16,9 +16,9 @@ require (
|
|||
require (
|
||||
aead.dev/minisign v0.3.0 // indirect
|
||||
dario.cat/mergo v1.0.0 // indirect
|
||||
forge.lthn.ai/core/go v0.3.1 // indirect
|
||||
forge.lthn.ai/core/go-i18n v0.1.6 // indirect
|
||||
forge.lthn.ai/core/go-inference v0.1.5 // indirect
|
||||
dappco.re/go/core v0.3.1 // indirect
|
||||
dappco.re/go/core/i18n v0.1.6 // indirect
|
||||
dappco.re/go/core/inference v0.1.5 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/ProtonMail/go-crypto v1.4.0 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package updater
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -27,5 +28,5 @@ func updaterUserAgent() string {
|
|||
if version == "" {
|
||||
version = "unknown"
|
||||
}
|
||||
return "agent-go-update/" + version
|
||||
return fmt.Sprintf("agent-go-update/%s", version)
|
||||
}
|
||||
|
|
|
|||
96
service.go
96
service.go
|
|
@ -5,7 +5,9 @@
|
|||
package updater
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
)
|
||||
|
|
@ -28,7 +30,8 @@ type UpdateServiceConfig struct {
|
|||
// repository URL (e.g., "https://github.com/owner/repo") or a base URL
|
||||
// for a generic HTTP update server.
|
||||
RepoURL string
|
||||
// Channel specifies the release channel to track (e.g., "stable", "prerelease").
|
||||
// Channel specifies the release channel to track (e.g., "stable", "beta", or "prerelease").
|
||||
// "prerelease" is normalised to "beta" to match the GitHub release filter.
|
||||
// This is only used for GitHub-based updates.
|
||||
Channel string
|
||||
// CheckOnStartup determines the update behavior when the service starts.
|
||||
|
|
@ -53,18 +56,15 @@ type UpdateService struct {
|
|||
}
|
||||
|
||||
// NewUpdateService creates and configures a new UpdateService.
|
||||
//
|
||||
// svc, err := updater.NewUpdateService(updater.UpdateServiceConfig{
|
||||
// RepoURL: "https://github.com/owner/repo",
|
||||
// Channel: "stable",
|
||||
// CheckOnStartup: updater.CheckAndUpdateOnStartup,
|
||||
// })
|
||||
// It parses the repository URL to determine if it's a GitHub repository
|
||||
// and extracts the owner and repo name.
|
||||
func NewUpdateService(config UpdateServiceConfig) (*UpdateService, error) {
|
||||
isGitHub := containsStr(config.RepoURL, "github.com")
|
||||
isGitHub := strings.Contains(config.RepoURL, "github.com")
|
||||
var owner, repo string
|
||||
var err error
|
||||
|
||||
if isGitHub {
|
||||
config.Channel = normaliseGitHubChannel(config.Channel)
|
||||
owner, repo, err = ParseRepoURL(config.RepoURL)
|
||||
if err != nil {
|
||||
return nil, coreerr.E("NewUpdateService", "failed to parse GitHub repo URL", err)
|
||||
|
|
@ -80,10 +80,9 @@ func NewUpdateService(config UpdateServiceConfig) (*UpdateService, error) {
|
|||
}
|
||||
|
||||
// Start initiates the update check based on the service configuration.
|
||||
//
|
||||
// if err := svc.Start(); err != nil {
|
||||
// log.Printf("update check failed: %v", err)
|
||||
// }
|
||||
// It determines whether to perform a GitHub or HTTP-based update check
|
||||
// based on the RepoURL. The behavior of the check is controlled by the
|
||||
// CheckOnStartup setting in the configuration.
|
||||
func (s *UpdateService) Start() error {
|
||||
if s.isGitHub {
|
||||
return s.startGitHubCheck()
|
||||
|
|
@ -100,7 +99,7 @@ func (s *UpdateService) startGitHubCheck() error {
|
|||
case CheckAndUpdateOnStartup:
|
||||
return CheckForUpdates(s.owner, s.repo, s.config.Channel, s.config.ForceSemVerPrefix, s.config.ReleaseURLFormat)
|
||||
default:
|
||||
return coreerr.E("startGitHubCheck", "unknown startup check mode", nil)
|
||||
return coreerr.E("startGitHubCheck", fmt.Sprintf("unknown startup check mode: %d", s.config.CheckOnStartup), nil)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -113,72 +112,31 @@ func (s *UpdateService) startHTTPCheck() error {
|
|||
case CheckAndUpdateOnStartup:
|
||||
return CheckForUpdatesHTTP(s.config.RepoURL)
|
||||
default:
|
||||
return coreerr.E("startHTTPCheck", "unknown startup check mode", nil)
|
||||
return coreerr.E("startHTTPCheck", fmt.Sprintf("unknown startup check mode: %d", s.config.CheckOnStartup), nil)
|
||||
}
|
||||
}
|
||||
|
||||
// containsStr reports whether substr appears within s.
|
||||
func containsStr(s, substr string) bool {
|
||||
if substr == "" {
|
||||
return true
|
||||
}
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// trimChars removes all leading and trailing occurrences of any rune in cutset from s.
|
||||
func trimChars(s, cutset string) string {
|
||||
start, end := 0, len(s)
|
||||
for start < end && containsStr(cutset, s[start:start+1]) {
|
||||
start++
|
||||
}
|
||||
for end > start && containsStr(cutset, s[end-1:end]) {
|
||||
end--
|
||||
}
|
||||
return s[start:end]
|
||||
}
|
||||
|
||||
// splitStr splits s by sep and returns a slice of substrings.
|
||||
func splitStr(s, sep string) []string {
|
||||
if sep == "" || s == "" {
|
||||
return []string{s}
|
||||
}
|
||||
var parts []string
|
||||
for {
|
||||
idx := -1
|
||||
for i := 0; i <= len(s)-len(sep); i++ {
|
||||
if s[i:i+len(sep)] == sep {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
parts = append(parts, s)
|
||||
break
|
||||
}
|
||||
parts = append(parts, s[:idx])
|
||||
s = s[idx+len(sep):]
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// ParseRepoURL extracts the owner and repository name from a GitHub URL.
|
||||
//
|
||||
// owner, repo, err := updater.ParseRepoURL("https://github.com/myorg/myrepo")
|
||||
// // owner == "myorg", repo == "myrepo"
|
||||
// It handles standard GitHub URL formats.
|
||||
func ParseRepoURL(repoURL string) (owner string, repo string, err error) {
|
||||
u, err := url.Parse(repoURL)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
path := trimChars(u.Path, "/")
|
||||
parts := splitStr(path, "/")
|
||||
parts := strings.Split(strings.Trim(u.Path, "/"), "/")
|
||||
if len(parts) < 2 {
|
||||
return "", "", coreerr.E("ParseRepoURL", "invalid repo URL path: "+u.Path, nil)
|
||||
return "", "", coreerr.E("ParseRepoURL", fmt.Sprintf("invalid repo URL path: %s", u.Path), nil)
|
||||
}
|
||||
return parts[0], parts[1], nil
|
||||
}
|
||||
|
||||
func normaliseGitHubChannel(channel string) string {
|
||||
channel = strings.ToLower(strings.TrimSpace(channel))
|
||||
if channel == "" {
|
||||
return "stable"
|
||||
}
|
||||
if channel == "prerelease" {
|
||||
return "beta"
|
||||
}
|
||||
return channel
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,13 +9,12 @@ import (
|
|||
|
||||
func ExampleNewUpdateService() {
|
||||
// Mock the update check functions to prevent actual updates during tests
|
||||
originalCheckForUpdates := updater.CheckForUpdates
|
||||
updater.CheckForUpdates = func(owner, repo, channel string, forceSemVerPrefix bool, releaseURLFormat string) error {
|
||||
fmt.Println("CheckForUpdates called")
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
updater.CheckForUpdates = originalCheckForUpdates
|
||||
updater.CheckForUpdates = nil // Restore original function
|
||||
}()
|
||||
|
||||
config := updater.UpdateServiceConfig{
|
||||
|
|
|
|||
136
service_test.go
136
service_test.go
|
|
@ -6,62 +6,74 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func TestService_NewUpdateService_Good(t *testing.T) {
|
||||
func TestNewUpdateService(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config UpdateServiceConfig
|
||||
isGitHub bool
|
||||
name string
|
||||
config UpdateServiceConfig
|
||||
expectError bool
|
||||
isGitHub bool
|
||||
wantChannel string
|
||||
}{
|
||||
{
|
||||
name: "github URL detected as GitHub",
|
||||
name: "Valid GitHub URL",
|
||||
config: UpdateServiceConfig{
|
||||
RepoURL: "https://github.com/owner/repo",
|
||||
},
|
||||
isGitHub: true,
|
||||
isGitHub: true,
|
||||
wantChannel: "stable",
|
||||
},
|
||||
{
|
||||
name: "non-GitHub URL not detected as GitHub",
|
||||
name: "Valid non-GitHub URL",
|
||||
config: UpdateServiceConfig{
|
||||
RepoURL: "https://example.com/updates",
|
||||
},
|
||||
isGitHub: false,
|
||||
},
|
||||
{
|
||||
name: "GitHub channel is normalised",
|
||||
config: UpdateServiceConfig{
|
||||
RepoURL: "https://github.com/owner/repo",
|
||||
Channel: " Beta ",
|
||||
},
|
||||
isGitHub: true,
|
||||
wantChannel: "beta",
|
||||
},
|
||||
{
|
||||
name: "GitHub prerelease channel maps to beta",
|
||||
config: UpdateServiceConfig{
|
||||
RepoURL: "https://github.com/owner/repo",
|
||||
Channel: " prerelease ",
|
||||
},
|
||||
isGitHub: true,
|
||||
wantChannel: "beta",
|
||||
},
|
||||
{
|
||||
name: "Invalid GitHub URL",
|
||||
config: UpdateServiceConfig{
|
||||
RepoURL: "https://github.com/owner",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
service, err := NewUpdateService(tc.config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
if (err != nil) != tc.expectError {
|
||||
t.Errorf("Expected error: %v, got: %v", tc.expectError, err)
|
||||
}
|
||||
if service.isGitHub != tc.isGitHub {
|
||||
t.Errorf("expected isGitHub=%v, got %v", tc.isGitHub, service.isGitHub)
|
||||
if err == nil && service.isGitHub != tc.isGitHub {
|
||||
t.Errorf("Expected isGitHub: %v, got: %v", tc.isGitHub, service.isGitHub)
|
||||
}
|
||||
if err == nil && tc.wantChannel != "" && service.config.Channel != tc.wantChannel {
|
||||
t.Errorf("Expected GitHub channel %q, got %q", tc.wantChannel, service.config.Channel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_NewUpdateService_Bad(t *testing.T) {
|
||||
_, err := NewUpdateService(UpdateServiceConfig{
|
||||
RepoURL: "https://github.com/owner", // missing repo segment
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid GitHub URL, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_NewUpdateService_Ugly(t *testing.T) {
|
||||
// Empty RepoURL — treated as non-GitHub, no parse error
|
||||
service, err := NewUpdateService(UpdateServiceConfig{RepoURL: ""})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for empty URL: %v", err)
|
||||
}
|
||||
if service.isGitHub {
|
||||
t.Error("expected isGitHub=false for empty URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_Start_Good(t *testing.T) {
|
||||
func TestUpdateService_Start(t *testing.T) {
|
||||
// Setup a mock server for HTTP tests
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"version": "v1.1.0", "url": "http://example.com/release.zip"}`))
|
||||
}))
|
||||
|
|
@ -74,16 +86,17 @@ func TestService_Start_Good(t *testing.T) {
|
|||
checkAndDoGitHub int
|
||||
checkOnlyHTTPCalls int
|
||||
checkAndDoHTTPCalls int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "GitHub NoCheck skips all checks",
|
||||
name: "GitHub: NoCheck",
|
||||
config: UpdateServiceConfig{
|
||||
RepoURL: "https://github.com/owner/repo",
|
||||
CheckOnStartup: NoCheck,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "GitHub CheckOnStartup calls CheckOnly",
|
||||
name: "GitHub: CheckOnStartup",
|
||||
config: UpdateServiceConfig{
|
||||
RepoURL: "https://github.com/owner/repo",
|
||||
CheckOnStartup: CheckOnStartup,
|
||||
|
|
@ -91,7 +104,7 @@ func TestService_Start_Good(t *testing.T) {
|
|||
checkOnlyGitHub: 1,
|
||||
},
|
||||
{
|
||||
name: "GitHub CheckAndUpdateOnStartup calls CheckForUpdates",
|
||||
name: "GitHub: CheckAndUpdateOnStartup",
|
||||
config: UpdateServiceConfig{
|
||||
RepoURL: "https://github.com/owner/repo",
|
||||
CheckOnStartup: CheckAndUpdateOnStartup,
|
||||
|
|
@ -99,14 +112,14 @@ func TestService_Start_Good(t *testing.T) {
|
|||
checkAndDoGitHub: 1,
|
||||
},
|
||||
{
|
||||
name: "HTTP NoCheck skips all checks",
|
||||
name: "HTTP: NoCheck",
|
||||
config: UpdateServiceConfig{
|
||||
RepoURL: server.URL,
|
||||
CheckOnStartup: NoCheck,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "HTTP CheckOnStartup calls CheckOnlyHTTP",
|
||||
name: "HTTP: CheckOnStartup",
|
||||
config: UpdateServiceConfig{
|
||||
RepoURL: server.URL,
|
||||
CheckOnStartup: CheckOnStartup,
|
||||
|
|
@ -114,7 +127,7 @@ func TestService_Start_Good(t *testing.T) {
|
|||
checkOnlyHTTPCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "HTTP CheckAndUpdateOnStartup calls CheckForUpdatesHTTP",
|
||||
name: "HTTP: CheckAndUpdateOnStartup",
|
||||
config: UpdateServiceConfig{
|
||||
RepoURL: server.URL,
|
||||
CheckOnStartup: CheckAndUpdateOnStartup,
|
||||
|
|
@ -127,6 +140,7 @@ func TestService_Start_Good(t *testing.T) {
|
|||
t.Run(tc.name, func(t *testing.T) {
|
||||
var checkOnlyGitHub, checkAndDoGitHub, checkOnlyHTTP, checkAndDoHTTP int
|
||||
|
||||
// Mock GitHub functions
|
||||
originalCheckOnly := CheckOnly
|
||||
CheckOnly = func(owner, repo, channel string, forceSemVerPrefix bool, releaseURLFormat string) error {
|
||||
checkOnlyGitHub++
|
||||
|
|
@ -141,6 +155,7 @@ func TestService_Start_Good(t *testing.T) {
|
|||
}
|
||||
defer func() { CheckForUpdates = originalCheckForUpdates }()
|
||||
|
||||
// Mock HTTP functions
|
||||
originalCheckOnlyHTTP := CheckOnlyHTTP
|
||||
CheckOnlyHTTP = func(baseURL string) error {
|
||||
checkOnlyHTTP++
|
||||
|
|
@ -156,52 +171,23 @@ func TestService_Start_Good(t *testing.T) {
|
|||
defer func() { CheckForUpdatesHTTP = originalCheckForUpdatesHTTP }()
|
||||
|
||||
service, _ := NewUpdateService(tc.config)
|
||||
if err := service.Start(); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
err := service.Start()
|
||||
|
||||
if (err != nil) != tc.expectError {
|
||||
t.Errorf("Expected error: %v, got: %v", tc.expectError, err)
|
||||
}
|
||||
if checkOnlyGitHub != tc.checkOnlyGitHub {
|
||||
t.Errorf("GitHub CheckOnly calls: want %d, got %d", tc.checkOnlyGitHub, checkOnlyGitHub)
|
||||
t.Errorf("Expected GitHub CheckOnly calls: %d, got: %d", tc.checkOnlyGitHub, checkOnlyGitHub)
|
||||
}
|
||||
if checkAndDoGitHub != tc.checkAndDoGitHub {
|
||||
t.Errorf("GitHub CheckForUpdates calls: want %d, got %d", tc.checkAndDoGitHub, checkAndDoGitHub)
|
||||
t.Errorf("Expected GitHub CheckForUpdates calls: %d, got: %d", tc.checkAndDoGitHub, checkAndDoGitHub)
|
||||
}
|
||||
if checkOnlyHTTP != tc.checkOnlyHTTPCalls {
|
||||
t.Errorf("HTTP CheckOnly calls: want %d, got %d", tc.checkOnlyHTTPCalls, checkOnlyHTTP)
|
||||
t.Errorf("Expected HTTP CheckOnly calls: %d, got: %d", tc.checkOnlyHTTPCalls, checkOnlyHTTP)
|
||||
}
|
||||
if checkAndDoHTTP != tc.checkAndDoHTTPCalls {
|
||||
t.Errorf("HTTP CheckForUpdates calls: want %d, got %d", tc.checkAndDoHTTPCalls, checkAndDoHTTP)
|
||||
t.Errorf("Expected HTTP CheckForUpdates calls: %d, got: %d", tc.checkAndDoHTTPCalls, checkAndDoHTTP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_Start_Bad(t *testing.T) {
|
||||
// GitHub unknown mode returns an error
|
||||
service := &UpdateService{
|
||||
config: UpdateServiceConfig{CheckOnStartup: StartupCheckMode(99)},
|
||||
isGitHub: true,
|
||||
owner: "owner",
|
||||
repo: "repo",
|
||||
}
|
||||
if err := service.Start(); err == nil {
|
||||
t.Error("expected error for unknown GitHub startup check mode")
|
||||
}
|
||||
|
||||
// HTTP unknown mode returns an error
|
||||
service = &UpdateService{
|
||||
config: UpdateServiceConfig{RepoURL: "https://example.com", CheckOnStartup: StartupCheckMode(99)},
|
||||
isGitHub: false,
|
||||
}
|
||||
if err := service.Start(); err == nil {
|
||||
t.Error("expected error for unknown HTTP startup check mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_Start_Ugly(t *testing.T) {
|
||||
// Start on a zero-value service (no config, no RepoURL) should not panic
|
||||
service := &UpdateService{}
|
||||
// NoCheck mode — returns nil without any action
|
||||
if err := service.Start(); err != nil {
|
||||
t.Errorf("unexpected error from zero-value service with NoCheck: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
84
updater.go
84
updater.go
|
|
@ -2,10 +2,11 @@ package updater
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
"github.com/minio/selfupdate"
|
||||
"golang.org/x/mod/semver"
|
||||
|
|
@ -27,9 +28,8 @@ var NewGithubClient = func() GithubClient {
|
|||
return &githubClient{}
|
||||
}
|
||||
|
||||
// DoUpdate performs the actual binary self-update from the given URL.
|
||||
//
|
||||
// updater.DoUpdate = func(url string) error { return nil } // stub in tests
|
||||
// DoUpdate is a variable that holds the function to perform the actual update.
|
||||
// This can be replaced in tests to prevent actual updates.
|
||||
var DoUpdate = func(url string) error {
|
||||
client := NewHTTPClient()
|
||||
req, err := newAgentRequest(context.Background(), "GET", url)
|
||||
|
|
@ -46,7 +46,7 @@ var DoUpdate = func(url string) error {
|
|||
}(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return coreerr.E("DoUpdate", "failed to download update: "+resp.Status, nil)
|
||||
return coreerr.E("DoUpdate", fmt.Sprintf("failed to download update: %s", resp.Status), nil)
|
||||
}
|
||||
|
||||
err = selfupdate.Apply(resp.Body, selfupdate.Options{})
|
||||
|
|
@ -59,10 +59,9 @@ var DoUpdate = func(url string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// CheckForNewerVersion checks GitHub for a newer release and returns whether one is available.
|
||||
//
|
||||
// release, available, err := updater.CheckForNewerVersion("myorg", "myrepo", "stable", true)
|
||||
// if available { fmt.Println("update:", release.TagName) }
|
||||
// CheckForNewerVersion checks if a newer version of the application is available on GitHub.
|
||||
// It fetches the latest release for the given owner, repository, and channel, and compares its tag
|
||||
// with the current application version.
|
||||
var CheckForNewerVersion = func(owner, repo, channel string, forceSemVerPrefix bool) (*Release, bool, error) {
|
||||
client := NewGithubClient()
|
||||
ctx := context.Background()
|
||||
|
|
@ -87,9 +86,8 @@ var CheckForNewerVersion = func(owner, repo, channel string, forceSemVerPrefix b
|
|||
return release, true, nil // A newer version is available
|
||||
}
|
||||
|
||||
// CheckForUpdates checks for and applies a GitHub update if a newer version is available.
|
||||
//
|
||||
// err := updater.CheckForUpdates("myorg", "myrepo", "stable", true, "")
|
||||
// CheckForUpdates checks for new updates on GitHub and applies them if a newer version is found.
|
||||
// It uses the provided owner, repository, and channel to find the latest release.
|
||||
var CheckForUpdates = func(owner, repo, channel string, forceSemVerPrefix bool, releaseURLFormat string) error {
|
||||
release, updateAvailable, err := CheckForNewerVersion(owner, repo, channel, forceSemVerPrefix)
|
||||
if err != nil {
|
||||
|
|
@ -98,16 +96,16 @@ var CheckForUpdates = func(owner, repo, channel string, forceSemVerPrefix bool,
|
|||
|
||||
if !updateAvailable {
|
||||
if release != nil {
|
||||
cli.Print("Current version %s is up-to-date with latest release %s.\n",
|
||||
fmt.Printf("Current version %s is up-to-date with latest release %s.\n",
|
||||
formatVersionForDisplay(Version, forceSemVerPrefix),
|
||||
formatVersionForDisplay(release.TagName, forceSemVerPrefix))
|
||||
} else {
|
||||
cli.Print("No releases found.\n")
|
||||
fmt.Println("No releases found.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
cli.Print("Newer version %s found (current: %s). Applying update...\n",
|
||||
fmt.Printf("Newer version %s found (current: %s). Applying update...\n",
|
||||
formatVersionForDisplay(release.TagName, forceSemVerPrefix),
|
||||
formatVersionForDisplay(Version, forceSemVerPrefix))
|
||||
|
||||
|
|
@ -119,9 +117,8 @@ var CheckForUpdates = func(owner, repo, channel string, forceSemVerPrefix bool,
|
|||
return DoUpdate(downloadURL)
|
||||
}
|
||||
|
||||
// CheckOnly checks for a newer GitHub release and prints the result without applying any update.
|
||||
//
|
||||
// err := updater.CheckOnly("myorg", "myrepo", "stable", true, "")
|
||||
// CheckOnly checks for new updates on GitHub without applying them.
|
||||
// It prints a message indicating if a new release is available.
|
||||
var CheckOnly = func(owner, repo, channel string, forceSemVerPrefix bool, releaseURLFormat string) error {
|
||||
release, updateAvailable, err := CheckForNewerVersion(owner, repo, channel, forceSemVerPrefix)
|
||||
if err != nil {
|
||||
|
|
@ -130,40 +127,37 @@ var CheckOnly = func(owner, repo, channel string, forceSemVerPrefix bool, releas
|
|||
|
||||
if !updateAvailable {
|
||||
if release != nil {
|
||||
cli.Print("Current version %s is up-to-date with latest release %s.\n",
|
||||
fmt.Printf("Current version %s is up-to-date with latest release %s.\n",
|
||||
formatVersionForDisplay(Version, forceSemVerPrefix),
|
||||
formatVersionForDisplay(release.TagName, forceSemVerPrefix))
|
||||
} else {
|
||||
cli.Print("No new release found.\n")
|
||||
fmt.Println("No new release found.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
cli.Print("New release found: %s (current version: %s)\n",
|
||||
fmt.Printf("New release found: %s (current version: %s)\n",
|
||||
formatVersionForDisplay(release.TagName, forceSemVerPrefix),
|
||||
formatVersionForDisplay(Version, forceSemVerPrefix))
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckForUpdatesByTag updates from GitHub using the channel inferred from the current version tag.
|
||||
//
|
||||
// err := updater.CheckForUpdatesByTag("myorg", "myrepo")
|
||||
// CheckForUpdatesByTag checks for and applies updates from GitHub based on the channel
|
||||
// determined by the current application's version tag (e.g., 'stable' or 'prerelease').
|
||||
var CheckForUpdatesByTag = func(owner, repo string) error {
|
||||
channel := determineChannel(Version, false) // isPreRelease is false for current version
|
||||
channel := determineChannel(Version, semver.Prerelease(formatVersionForComparison(Version)) != "")
|
||||
return CheckForUpdates(owner, repo, channel, true, "")
|
||||
}
|
||||
|
||||
// CheckOnlyByTag checks for a newer release using the channel inferred from the current version tag.
|
||||
//
|
||||
// err := updater.CheckOnlyByTag("myorg", "myrepo")
|
||||
// CheckOnlyByTag checks for updates from GitHub based on the channel determined by the
|
||||
// current version tag, without applying them.
|
||||
var CheckOnlyByTag = func(owner, repo string) error {
|
||||
channel := determineChannel(Version, false) // isPreRelease is false for current version
|
||||
channel := determineChannel(Version, semver.Prerelease(formatVersionForComparison(Version)) != "")
|
||||
return CheckOnly(owner, repo, channel, true, "")
|
||||
}
|
||||
|
||||
// CheckForUpdatesByPullRequest finds and applies the GitHub release associated with a specific PR.
|
||||
//
|
||||
// err := updater.CheckForUpdatesByPullRequest("myorg", "myrepo", 42, "")
|
||||
// CheckForUpdatesByPullRequest finds a release associated with a specific pull request number
|
||||
// on GitHub and applies the update.
|
||||
var CheckForUpdatesByPullRequest = func(owner, repo string, prNumber int, releaseURLFormat string) error {
|
||||
client := NewGithubClient()
|
||||
ctx := context.Background()
|
||||
|
|
@ -174,11 +168,11 @@ var CheckForUpdatesByPullRequest = func(owner, repo string, prNumber int, releas
|
|||
}
|
||||
|
||||
if release == nil {
|
||||
cli.Print("No release found for PR #%d.\n", prNumber)
|
||||
fmt.Printf("No release found for PR #%d.\n", prNumber)
|
||||
return nil
|
||||
}
|
||||
|
||||
cli.Print("Release %s found for PR #%d. Applying update...\n", release.TagName, prNumber)
|
||||
fmt.Printf("Release %s found for PR #%d. Applying update...\n", release.TagName, prNumber)
|
||||
|
||||
downloadURL, err := GetDownloadURL(release, releaseURLFormat)
|
||||
if err != nil {
|
||||
|
|
@ -188,9 +182,8 @@ var CheckForUpdatesByPullRequest = func(owner, repo string, prNumber int, releas
|
|||
return DoUpdate(downloadURL)
|
||||
}
|
||||
|
||||
// CheckForUpdatesHTTP checks for and applies updates from a generic HTTP endpoint serving latest.json.
|
||||
//
|
||||
// err := updater.CheckForUpdatesHTTP("https://my-server.com/updates")
|
||||
// CheckForUpdatesHTTP checks for and applies updates from a generic HTTP endpoint.
|
||||
// The endpoint is expected to provide update information in a structured format.
|
||||
var CheckForUpdatesHTTP = func(baseURL string) error {
|
||||
info, err := GetLatestUpdateFromURL(baseURL)
|
||||
if err != nil {
|
||||
|
|
@ -201,17 +194,16 @@ var CheckForUpdatesHTTP = func(baseURL string) error {
|
|||
vLatest := formatVersionForComparison(info.Version)
|
||||
|
||||
if semver.Compare(vCurrent, vLatest) >= 0 {
|
||||
cli.Print("Current version %s is up-to-date with latest release %s.\n", Version, info.Version)
|
||||
fmt.Printf("Current version %s is up-to-date with latest release %s.\n", Version, info.Version)
|
||||
return nil
|
||||
}
|
||||
|
||||
cli.Print("Newer version %s found (current: %s). Applying update...\n", info.Version, Version)
|
||||
fmt.Printf("Newer version %s found (current: %s). Applying update...\n", info.Version, Version)
|
||||
return DoUpdate(info.URL)
|
||||
}
|
||||
|
||||
// CheckOnlyHTTP checks for updates from a generic HTTP endpoint without applying them.
|
||||
//
|
||||
// err := updater.CheckOnlyHTTP("https://my-server.com/updates")
|
||||
// It prints a message if a new version is available.
|
||||
var CheckOnlyHTTP = func(baseURL string) error {
|
||||
info, err := GetLatestUpdateFromURL(baseURL)
|
||||
if err != nil {
|
||||
|
|
@ -222,17 +214,17 @@ var CheckOnlyHTTP = func(baseURL string) error {
|
|||
vLatest := formatVersionForComparison(info.Version)
|
||||
|
||||
if semver.Compare(vCurrent, vLatest) >= 0 {
|
||||
cli.Print("Current version %s is up-to-date with latest release %s.\n", Version, info.Version)
|
||||
fmt.Printf("Current version %s is up-to-date with latest release %s.\n", Version, info.Version)
|
||||
return nil
|
||||
}
|
||||
|
||||
cli.Print("New release found: %s (current version: %s)\n", info.Version, Version)
|
||||
fmt.Printf("New release found: %s (current version: %s)\n", info.Version, Version)
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatVersionForComparison ensures the version string has a 'v' prefix for semver comparison.
|
||||
func formatVersionForComparison(version string) string {
|
||||
if version != "" && version[0] != 'v' {
|
||||
if version != "" && !strings.HasPrefix(version, "v") {
|
||||
return "v" + version
|
||||
}
|
||||
return version
|
||||
|
|
@ -240,12 +232,12 @@ func formatVersionForComparison(version string) string {
|
|||
|
||||
// formatVersionForDisplay ensures the version string has the correct 'v' prefix based on the forceSemVerPrefix flag.
|
||||
func formatVersionForDisplay(version string, forceSemVerPrefix bool) string {
|
||||
hasV := len(version) > 0 && version[0] == 'v'
|
||||
hasV := strings.HasPrefix(version, "v")
|
||||
if forceSemVerPrefix && !hasV {
|
||||
return "v" + version
|
||||
}
|
||||
if !forceSemVerPrefix && hasV {
|
||||
return version[1:]
|
||||
return strings.TrimPrefix(version, "v")
|
||||
}
|
||||
return version
|
||||
}
|
||||
|
|
|
|||
174
updater_test.go
174
updater_test.go
|
|
@ -7,7 +7,6 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
)
|
||||
|
|
@ -40,169 +39,6 @@ func (m *mockGithubClient) GetPublicRepos(ctx context.Context, userOrOrg string)
|
|||
return nil, coreerr.E("mockGithubClient.GetPublicRepos", "not implemented", nil)
|
||||
}
|
||||
|
||||
// TestUpdater_CheckForNewerVersion_Good verifies that a newer version is detected correctly.
|
||||
func TestUpdater_CheckForNewerVersion_Good(t *testing.T) {
|
||||
originalNewGithubClient := NewGithubClient
|
||||
defer func() { NewGithubClient = originalNewGithubClient }()
|
||||
|
||||
NewGithubClient = func() GithubClient {
|
||||
return &mockGithubClient{
|
||||
getLatestRelease: func(ctx context.Context, owner, repo, channel string) (*Release, error) {
|
||||
return &Release{TagName: "v1.1.0"}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
Version = "1.0.0"
|
||||
release, available, err := CheckForNewerVersion("owner", "repo", "stable", true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !available {
|
||||
t.Error("expected update to be available")
|
||||
}
|
||||
if release == nil || release.TagName != "v1.1.0" {
|
||||
t.Errorf("expected release v1.1.0, got %v", release)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdater_CheckForNewerVersion_Bad verifies that a GitHub client error is propagated.
|
||||
func TestUpdater_CheckForNewerVersion_Bad(t *testing.T) {
|
||||
originalNewGithubClient := NewGithubClient
|
||||
defer func() { NewGithubClient = originalNewGithubClient }()
|
||||
|
||||
NewGithubClient = func() GithubClient {
|
||||
return &mockGithubClient{
|
||||
getLatestRelease: func(ctx context.Context, owner, repo, channel string) (*Release, error) {
|
||||
return nil, coreerr.E("test", "simulated network error", nil)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
_, _, err := CheckForNewerVersion("owner", "repo", "stable", true)
|
||||
if err == nil {
|
||||
t.Error("expected error from failing client, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdater_CheckForNewerVersion_Ugly verifies no-release and already-up-to-date paths.
|
||||
func TestUpdater_CheckForNewerVersion_Ugly(t *testing.T) {
|
||||
originalNewGithubClient := NewGithubClient
|
||||
defer func() { NewGithubClient = originalNewGithubClient }()
|
||||
|
||||
// No release found
|
||||
NewGithubClient = func() GithubClient {
|
||||
return &mockGithubClient{
|
||||
getLatestRelease: func(ctx context.Context, owner, repo, channel string) (*Release, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
release, available, err := CheckForNewerVersion("owner", "repo", "stable", true)
|
||||
if err != nil || release != nil || available {
|
||||
t.Errorf("expected nil/nil/false for no release; got release=%v available=%v err=%v", release, available, err)
|
||||
}
|
||||
|
||||
// Already on latest version
|
||||
NewGithubClient = func() GithubClient {
|
||||
return &mockGithubClient{
|
||||
getLatestRelease: func(ctx context.Context, owner, repo, channel string) (*Release, error) {
|
||||
return &Release{TagName: "v1.0.0"}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
Version = "1.0.0"
|
||||
_, available, err = CheckForNewerVersion("owner", "repo", "stable", true)
|
||||
if err != nil || available {
|
||||
t.Errorf("expected available=false when already on latest; err=%v available=%v", err, available)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdater_DoUpdate_Good verifies successful binary replacement.
|
||||
func TestUpdater_DoUpdate_Good(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Return minimal valid binary content (selfupdate will fail but we test the HTTP path)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("binary-content"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
originalDoUpdate := DoUpdate
|
||||
defer func() { DoUpdate = originalDoUpdate }()
|
||||
|
||||
called := false
|
||||
DoUpdate = func(url string) error {
|
||||
called = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := DoUpdate(server.URL); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("expected DoUpdate to be called")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdater_DoUpdate_Bad verifies that HTTP errors are propagated.
|
||||
func TestUpdater_DoUpdate_Bad(t *testing.T) {
|
||||
// Restore original DoUpdate for this test
|
||||
originalDoUpdate := DoUpdate
|
||||
defer func() { DoUpdate = originalDoUpdate }()
|
||||
|
||||
// Reset to real implementation
|
||||
DoUpdate = func(downloadURL string) error {
|
||||
client := NewHTTPClient()
|
||||
req, err := newAgentRequest(context.Background(), "GET", downloadURL)
|
||||
if err != nil {
|
||||
return coreerr.E("DoUpdate", "failed to create update request", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return coreerr.E("DoUpdate", "failed to download update", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return coreerr.E("DoUpdate", "failed to download update: "+resp.Status, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
if err := DoUpdate(server.URL); err == nil {
|
||||
t.Error("expected error for 404 response, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdater_DoUpdate_Ugly verifies that an invalid URL returns an error.
|
||||
func TestUpdater_DoUpdate_Ugly(t *testing.T) {
|
||||
originalDoUpdate := DoUpdate
|
||||
defer func() { DoUpdate = originalDoUpdate }()
|
||||
|
||||
// Reset to real implementation
|
||||
DoUpdate = func(downloadURL string) error {
|
||||
client := NewHTTPClient()
|
||||
req, err := newAgentRequest(context.Background(), "GET", downloadURL)
|
||||
if err != nil {
|
||||
return coreerr.E("DoUpdate", "failed to create update request", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return coreerr.E("DoUpdate", "failed to download update", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := DoUpdate("http://127.0.0.1:0/invalid"); err == nil {
|
||||
t.Error("expected error for unreachable URL, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleCheckForNewerVersion() {
|
||||
originalNewGithubClient := NewGithubClient
|
||||
defer func() { NewGithubClient = originalNewGithubClient }()
|
||||
|
|
@ -230,6 +66,7 @@ func ExampleCheckForNewerVersion() {
|
|||
}
|
||||
|
||||
func ExampleCheckForUpdates() {
|
||||
// Mock the functions to prevent actual updates and network calls
|
||||
originalDoUpdate := DoUpdate
|
||||
originalNewGithubClient := NewGithubClient
|
||||
defer func() {
|
||||
|
|
@ -284,6 +121,7 @@ func ExampleCheckOnly() {
|
|||
}
|
||||
|
||||
func ExampleCheckForUpdatesByTag() {
|
||||
// Mock the functions to prevent actual updates and network calls
|
||||
originalDoUpdate := DoUpdate
|
||||
originalNewGithubClient := NewGithubClient
|
||||
defer func() {
|
||||
|
|
@ -310,7 +148,7 @@ func ExampleCheckForUpdatesByTag() {
|
|||
return nil
|
||||
}
|
||||
|
||||
Version = "1.0.0"
|
||||
Version = "1.0.0" // A version that resolves to the "stable" channel
|
||||
err := CheckForUpdatesByTag("owner", "repo")
|
||||
if err != nil {
|
||||
log.Fatalf("CheckForUpdatesByTag failed: %v", err)
|
||||
|
|
@ -335,7 +173,7 @@ func ExampleCheckOnlyByTag() {
|
|||
}
|
||||
}
|
||||
|
||||
Version = "1.0.0"
|
||||
Version = "1.0.0" // A version that resolves to the "stable" channel
|
||||
err := CheckOnlyByTag("owner", "repo")
|
||||
if err != nil {
|
||||
log.Fatalf("CheckOnlyByTag failed: %v", err)
|
||||
|
|
@ -344,6 +182,7 @@ func ExampleCheckOnlyByTag() {
|
|||
}
|
||||
|
||||
func ExampleCheckForUpdatesByPullRequest() {
|
||||
// Mock the functions to prevent actual updates and network calls
|
||||
originalDoUpdate := DoUpdate
|
||||
originalNewGithubClient := NewGithubClient
|
||||
defer func() {
|
||||
|
|
@ -380,6 +219,7 @@ func ExampleCheckForUpdatesByPullRequest() {
|
|||
}
|
||||
|
||||
func ExampleCheckForUpdatesHTTP() {
|
||||
// Create a mock HTTP server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/latest.json" {
|
||||
_, _ = fmt.Fprintln(w, `{"version": "1.1.0", "url": "http://example.com/update"}`)
|
||||
|
|
@ -387,6 +227,7 @@ func ExampleCheckForUpdatesHTTP() {
|
|||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Mock the doUpdateFunc to prevent actual updates
|
||||
originalDoUpdate := DoUpdate
|
||||
defer func() { DoUpdate = originalDoUpdate }()
|
||||
DoUpdate = func(url string) error {
|
||||
|
|
@ -405,6 +246,7 @@ func ExampleCheckForUpdatesHTTP() {
|
|||
}
|
||||
|
||||
func ExampleCheckOnlyHTTP() {
|
||||
// Create a mock HTTP server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/latest.json" {
|
||||
_, _ = fmt.Fprintln(w, `{"version": "1.1.0", "url": "http://example.com/update"}`)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue