Compare commits

..

No commits in common. "dev" and "v0.0.1" have entirely different histories.
dev ... v0.0.1

14 changed files with 50 additions and 496 deletions

View file

@ -1,75 +0,0 @@
# CLAUDE.md — go-update
This file provides guidance to Claude Code when working with the `go-update` package.
## Package Overview
`go-update` (`forge.lthn.ai/core/go-update`) is a **self-updater library** for Go applications. It supports updates from GitHub releases and generic HTTP endpoints, with configurable startup behaviour and version channel filtering.
## Build & Test Commands
```bash
# Run all tests
go test ./...
# Run tests with coverage
go test -cover ./...
# Run a single test
go test -run TestName ./...
# Generate version.go from package.json
go generate ./...
# Vet and lint
go vet ./...
```
## Architecture
### Update Sources
| Source | Description |
|--------|-------------|
| GitHub Releases | Fetches releases via GitHub API, filters by channel (stable/beta/alpha) |
| Generic HTTP | Fetches `latest.json` from a base URL with version + download URL |
### Key Types
- **`UpdateService`** — Configured service that checks for updates on startup
- **`GithubClient`** — Interface for GitHub API interactions (mockable for tests)
- **`Release`** / **`ReleaseAsset`** — GitHub release model
- **`GenericUpdateInfo`** — HTTP update endpoint model
### Testable Function Variables
Core update logic is exposed as `var` function values so tests can replace them:
- `NewGithubClient` — Factory for GitHub client (replace with mock)
- `DoUpdate` — Performs the actual binary update
- `CheckForNewerVersion`, `CheckForUpdates`, `CheckOnly` — GitHub update flow
- `CheckForUpdatesHTTP`, `CheckOnlyHTTP` — HTTP update flow
- `NewAuthenticatedClient` — HTTP client factory (supports `GITHUB_TOKEN`)
### Error Handling
All errors **must** use `coreerr.E()` from `forge.lthn.ai/core/go-log`:
```go
import coreerr "forge.lthn.ai/core/go-log"
return coreerr.E("FunctionName", "what failed", underlyingErr)
```
Never use `fmt.Errorf` or `errors.New`.
### File I/O
Use `forge.lthn.ai/core/go-io` for file operations, not `os.ReadFile`/`os.WriteFile`.
## Coding Standards
- **UK English** in comments and strings
- **Strict types**: All parameters and return types
- **Test naming**: `_Good`, `_Bad`, `_Ugly` suffix pattern
- **License**: EUPL-1.2

View file

@ -3,33 +3,32 @@ package main
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"os"
coreio "forge.lthn.ai/core/go-io"
) )
func main() { func main() {
// Read package.json // Read package.json
data, err := coreio.Read(coreio.Local, "package.json") data, err := os.ReadFile("package.json")
if err != nil { if err != nil {
fmt.Println("Error reading package.json, skipping version file generation.") fmt.Println("Error reading package.json, skipping version file generation.")
return os.Exit(0)
} }
// Parse package.json // Parse package.json
var pkg struct { var pkg struct {
Version string `json:"version"` Version string `json:"version"`
} }
if err := json.Unmarshal([]byte(data), &pkg); err != nil { if err := json.Unmarshal(data, &pkg); err != nil {
fmt.Println("Error parsing package.json, skipping version file generation.") fmt.Println("Error parsing package.json, skipping version file generation.")
return os.Exit(0)
} }
// Create the version file // Create the version file
content := fmt.Sprintf("package updater\n\n// Generated by go:generate. DO NOT EDIT.\n\nconst PkgVersion = %q\n", pkg.Version) content := fmt.Sprintf("package updater\n\n// Generated by go:generate. DO NOT EDIT.\n\nconst PkgVersion = %q\n", pkg.Version)
err = coreio.Write(coreio.Local, "version.go", content) err = os.WriteFile("version.go", []byte(content), 0644)
if err != nil { if err != nil {
fmt.Printf("Error writing version file: %v\n", err) fmt.Printf("Error writing version file: %v\n", err)
return os.Exit(1)
} }
fmt.Println("Generated version.go with version:", pkg.Version) fmt.Println("Generated version.go with version:", pkg.Version)

17
cmd.go
View file

@ -45,9 +45,9 @@ Examples:
RunE: runUpdate, RunE: runUpdate,
} }
updateCmd.PersistentFlags().StringVar(&updateChannel, "channel", "stable", "Release channel: stable, beta, alpha, prerelease, or dev") updateCmd.PersistentFlags().StringVar(&updateChannel, "channel", "stable", "Release channel: stable, beta, alpha, or dev")
updateCmd.PersistentFlags().BoolVar(&updateForce, "force", false, "Force update even if already on latest version") updateCmd.PersistentFlags().BoolVar(&updateForce, "force", false, "Force update even if already on latest version")
updateCmd.Flags().BoolVar(&updateCheck, "check", false, "Only check for updates, do not apply") updateCmd.Flags().BoolVar(&updateCheck, "check", false, "Only check for updates, don't apply")
updateCmd.Flags().IntVar(&updateWatchPID, "watch-pid", 0, "Internal: watch for parent PID to die then restart") updateCmd.Flags().IntVar(&updateWatchPID, "watch-pid", 0, "Internal: watch for parent PID to die then restart")
_ = updateCmd.Flags().MarkHidden("watch-pid") _ = updateCmd.Flags().MarkHidden("watch-pid")
@ -55,9 +55,7 @@ Examples:
Use: "check", Use: "check",
Short: "Check for available updates", Short: "Check for available updates",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
previousCheck := updateCheck
updateCheck = true updateCheck = true
defer func() { updateCheck = previousCheck }()
return runUpdate(cmd, args) return runUpdate(cmd, args)
}, },
}) })
@ -72,25 +70,24 @@ func runUpdate(cmd *cobra.Command, args []string) error {
} }
currentVersion := cli.AppVersion currentVersion := cli.AppVersion
normalizedChannel := normaliseGitHubChannel(updateChannel)
cli.Print("%s %s\n", cli.DimStyle.Render("Current version:"), cli.ValueStyle.Render(currentVersion)) cli.Print("%s %s\n", cli.DimStyle.Render("Current version:"), cli.ValueStyle.Render(currentVersion))
cli.Print("%s %s/%s\n", cli.DimStyle.Render("Platform:"), runtime.GOOS, runtime.GOARCH) cli.Print("%s %s/%s\n", cli.DimStyle.Render("Platform:"), runtime.GOOS, runtime.GOARCH)
cli.Print("%s %s\n\n", cli.DimStyle.Render("Channel:"), normalizedChannel) cli.Print("%s %s\n\n", cli.DimStyle.Render("Channel:"), updateChannel)
// Handle dev channel specially - it's a prerelease tag, not a semver channel // Handle dev channel specially - it's a prerelease tag, not a semver channel
if normalizedChannel == "dev" { if updateChannel == "dev" {
return handleDevUpdate(currentVersion) return handleDevUpdate(currentVersion)
} }
// Check for newer version // Check for newer version
release, updateAvailable, err := CheckForNewerVersion(repoOwner, repoName, normalizedChannel, true) release, updateAvailable, err := CheckForNewerVersion(repoOwner, repoName, updateChannel, true)
if err != nil { if err != nil {
return cli.Wrap(err, "failed to check for updates") return cli.Wrap(err, "failed to check for updates")
} }
if release == nil { if release == nil {
cli.Print("%s No releases found in %s channel\n", cli.WarningStyle.Render("!"), normalizedChannel) cli.Print("%s No releases found in %s channel\n", cli.WarningStyle.Render("!"), updateChannel)
return nil return nil
} }
@ -143,7 +140,7 @@ func handleDevUpdate(currentVersion string) error {
client := NewGithubClient() client := NewGithubClient()
// Fetch the dev release directly by tag // Fetch the dev release directly by tag
release, err := client.GetLatestRelease(context.Background(), repoOwner, repoName, "beta") release, err := client.GetLatestRelease(context.TODO(), repoOwner, repoName, "beta")
if err != nil { if err != nil {
// Try fetching the "dev" tag directly // Try fetching the "dev" tag directly
return handleDevTagUpdate(currentVersion) return handleDevTagUpdate(currentVersion)

View file

@ -26,7 +26,9 @@ The `CheckOnStartup` field can take one of the following values:
If you are using the example CLI provided in `cmd/updater`, the following flags are available: If you are using the example CLI provided in `cmd/updater`, the following flags are available:
* `--check`: Check for new updates without applying them. * `--check-update`: Check for new updates without applying them.
* `--channel`: Set the update channel (e.g., stable, beta, alpha). Defaults to `stable`. * `--do-update`: Perform an update if available.
* `--force`: Force update even when already on latest. * `--channel`: Set the update channel (e.g., stable, beta, alpha). If not set, it's determined from the current version tag.
* `--watch-pid`: Internal flag used during restart after update. * `--force-semver-prefix`: Force 'v' prefix on semver tags (default `true`).
* `--release-url-format`: A URL format for release assets.
* `--pull-request`: Update to a specific pull request (integer ID).

View file

@ -3,10 +3,8 @@ package updater
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"context"
"net/http" "net/http"
"net/url" "net/url"
"strings"
coreerr "forge.lthn.ai/core/go-log" coreerr "forge.lthn.ai/core/go-log"
) )
@ -33,16 +31,10 @@ func GetLatestUpdateFromURL(baseURL string) (*GenericUpdateInfo, error) {
if err != nil { if err != nil {
return nil, coreerr.E("GetLatestUpdateFromURL", "invalid base URL", err) return nil, coreerr.E("GetLatestUpdateFromURL", "invalid base URL", err)
} }
// Append latest.json to the path // Append latest.json to the path
u.Path = strings.TrimSuffix(u.Path, "/") + "/latest.json" u.Path += "/latest.json"
req, err := newAgentRequest(context.Background(), "GET", u.String()) resp, err := http.Get(u.String())
if err != nil {
return nil, coreerr.E("GetLatestUpdateFromURL", "failed to create update check request", err)
}
resp, err := NewHTTPClient().Do(req)
if err != nil { if err != nil {
return nil, coreerr.E("GetLatestUpdateFromURL", "failed to fetch latest.json", err) return nil, coreerr.E("GetLatestUpdateFromURL", "failed to fetch latest.json", err)
} }

View file

@ -52,13 +52,10 @@ var NewAuthenticatedClient = func(ctx context.Context) *http.Client {
if token == "" { if token == "" {
return http.DefaultClient return http.DefaultClient
} }
ts := oauth2.StaticTokenSource( ts := oauth2.StaticTokenSource(
&oauth2.Token{AccessToken: token}, &oauth2.Token{AccessToken: token},
) )
client := oauth2.NewClient(ctx, ts) return oauth2.NewClient(ctx, ts)
client.Timeout = defaultHTTPTimeout
return client
} }
func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) { func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) {
@ -74,10 +71,11 @@ func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, use
if err := ctx.Err(); err != nil { if err := ctx.Err(); err != nil {
return nil, err return nil, err
} }
req, err := newAgentRequest(ctx, "GET", url) req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("User-Agent", "Borg-Data-Collector")
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
@ -87,10 +85,11 @@ func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, use
_ = resp.Body.Close() _ = resp.Body.Close()
// Try organization endpoint // Try organization endpoint
url = fmt.Sprintf("%s/orgs/%s/repos", apiURL, userOrOrg) url = fmt.Sprintf("%s/orgs/%s/repos", apiURL, userOrOrg)
req, err = newAgentRequest(ctx, "GET", url) req, err = http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("User-Agent", "Borg-Data-Collector")
resp, err = client.Do(req) resp, err = client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
@ -144,10 +143,11 @@ func (g *githubClient) GetLatestRelease(ctx context.Context, owner, repo, channe
client := NewAuthenticatedClient(ctx) client := NewAuthenticatedClient(ctx)
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases", owner, repo) url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases", owner, repo)
req, err := newAgentRequest(ctx, "GET", url) req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("User-Agent", "Borg-Data-Collector")
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
@ -198,10 +198,11 @@ func (g *githubClient) GetReleaseByPullRequest(ctx context.Context, owner, repo
client := NewAuthenticatedClient(ctx) client := NewAuthenticatedClient(ctx)
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases", owner, repo) url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases", owner, repo)
req, err := newAgentRequest(ctx, "GET", url) req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("User-Agent", "Borg-Data-Collector")
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {

View file

@ -1,284 +0,0 @@
package updater
import (
"runtime"
"testing"
)
func TestFilterReleases_Good(t *testing.T) {
releases := []Release{
{TagName: "v1.0.0-alpha.1", PreRelease: true},
{TagName: "v1.0.0-beta.1", PreRelease: true},
{TagName: "v1.0.0", PreRelease: false},
}
tests := []struct {
channel string
wantTag string
}{
{"stable", "v1.0.0"},
{"alpha", "v1.0.0-alpha.1"},
{"beta", "v1.0.0-beta.1"},
}
for _, tt := range tests {
t.Run(tt.channel, func(t *testing.T) {
got := filterReleases(releases, tt.channel)
if got == nil {
t.Fatalf("expected release for channel %q, got nil", tt.channel)
}
if got.TagName != tt.wantTag {
t.Errorf("expected tag %q, got %q", tt.wantTag, got.TagName)
}
})
}
}
func TestFilterReleases_Bad(t *testing.T) {
releases := []Release{
{TagName: "v1.0.0", PreRelease: false},
}
got := filterReleases(releases, "alpha")
if got != nil {
t.Errorf("expected nil for non-matching channel, got %v", got)
}
}
func TestFilterReleases_PreReleaseWithoutLabel(t *testing.T) {
releases := []Release{
{TagName: "v2.0.0-rc.1", PreRelease: true},
}
got := filterReleases(releases, "beta")
if got == nil {
t.Fatal("expected pre-release without alpha/beta label to match beta channel")
}
if got.TagName != "v2.0.0-rc.1" {
t.Errorf("expected tag %q, got %q", "v2.0.0-rc.1", got.TagName)
}
}
func TestDetermineChannel_Good(t *testing.T) {
tests := []struct {
tag string
isPreRelease bool
want string
}{
{"v1.0.0", false, "stable"},
{"v1.0.0-alpha.1", false, "alpha"},
{"v1.0.0-ALPHA.1", false, "alpha"},
{"v1.0.0-beta.1", false, "beta"},
{"v1.0.0-BETA.1", false, "beta"},
{"v1.0.0-rc.1", true, "beta"},
{"v1.0.0-rc.1", false, "stable"},
}
for _, tt := range tests {
t.Run(tt.tag, func(t *testing.T) {
got := determineChannel(tt.tag, tt.isPreRelease)
if got != tt.want {
t.Errorf("determineChannel(%q, %v) = %q, want %q", tt.tag, tt.isPreRelease, got, tt.want)
}
})
}
}
func TestCheckForUpdatesByTag_UsesCurrentVersionChannel(t *testing.T) {
originalVersion := Version
originalCheckForUpdates := CheckForUpdates
defer func() {
Version = originalVersion
CheckForUpdates = originalCheckForUpdates
}()
var gotChannel string
CheckForUpdates = func(owner, repo, channel string, forceSemVerPrefix bool, releaseURLFormat string) error {
gotChannel = channel
return nil
}
Version = "v2.0.0-rc.1"
if err := CheckForUpdatesByTag("owner", "repo"); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if gotChannel != "beta" {
t.Fatalf("expected beta channel, got %q", gotChannel)
}
}
func TestCheckOnlyByTag_UsesCurrentVersionChannel(t *testing.T) {
originalVersion := Version
originalCheckOnly := CheckOnly
defer func() {
Version = originalVersion
CheckOnly = originalCheckOnly
}()
var gotChannel string
CheckOnly = func(owner, repo, channel string, forceSemVerPrefix bool, releaseURLFormat string) error {
gotChannel = channel
return nil
}
Version = "v2.0.0-alpha.1"
if err := CheckOnlyByTag("owner", "repo"); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if gotChannel != "alpha" {
t.Fatalf("expected alpha channel, got %q", gotChannel)
}
}
func TestGetDownloadURL_Good(t *testing.T) {
osName := runtime.GOOS
archName := runtime.GOARCH
release := &Release{
TagName: "v1.2.3",
Assets: []ReleaseAsset{
{Name: "app-" + osName + "-" + archName, DownloadURL: "https://example.com/full-match"},
{Name: "app-" + osName, DownloadURL: "https://example.com/os-only"},
},
}
url, err := GetDownloadURL(release, "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if url != "https://example.com/full-match" {
t.Errorf("expected full match URL, got %q", url)
}
}
func TestGetDownloadURL_OSOnlyFallback(t *testing.T) {
osName := runtime.GOOS
release := &Release{
TagName: "v1.2.3",
Assets: []ReleaseAsset{
{Name: "app-other-other", DownloadURL: "https://example.com/other"},
{Name: "app-" + osName + "-other", DownloadURL: "https://example.com/os-only"},
},
}
url, err := GetDownloadURL(release, "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if url != "https://example.com/os-only" {
t.Errorf("expected OS-only fallback URL, got %q", url)
}
}
func TestGetDownloadURL_WithFormat(t *testing.T) {
release := &Release{TagName: "v1.2.3"}
url, err := GetDownloadURL(release, "https://example.com/{tag}/{os}/{arch}")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := "https://example.com/v1.2.3/" + runtime.GOOS + "/" + runtime.GOARCH
if url != expected {
t.Errorf("expected %q, got %q", expected, url)
}
}
func TestGetDownloadURL_Bad(t *testing.T) {
// nil release
_, err := GetDownloadURL(nil, "")
if err == nil {
t.Error("expected error for nil release")
}
// No matching assets
release := &Release{
TagName: "v1.2.3",
Assets: []ReleaseAsset{
{Name: "app-unknownos-unknownarch", DownloadURL: "https://example.com/other"},
},
}
_, err = GetDownloadURL(release, "")
if err == nil {
t.Error("expected error when no suitable asset found")
}
}
func TestFormatVersionForComparison(t *testing.T) {
tests := []struct {
input string
want string
}{
{"1.0.0", "v1.0.0"},
{"v1.0.0", "v1.0.0"},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := formatVersionForComparison(tt.input)
if got != tt.want {
t.Errorf("formatVersionForComparison(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestFormatVersionForDisplay(t *testing.T) {
tests := []struct {
version string
force bool
want string
}{
{"1.0.0", true, "v1.0.0"},
{"v1.0.0", true, "v1.0.0"},
{"v1.0.0", false, "1.0.0"},
{"1.0.0", false, "1.0.0"},
}
for _, tt := range tests {
t.Run(tt.version+"_force_"+boolStr(tt.force), func(t *testing.T) {
got := formatVersionForDisplay(tt.version, tt.force)
if got != tt.want {
t.Errorf("formatVersionForDisplay(%q, %v) = %q, want %q", tt.version, tt.force, got, tt.want)
}
})
}
}
func boolStr(b bool) string {
if b {
return "true"
}
return "false"
}
func TestStartGitHubCheck_UnknownMode(t *testing.T) {
s := &UpdateService{
config: UpdateServiceConfig{
CheckOnStartup: StartupCheckMode(99),
},
isGitHub: true,
owner: "owner",
repo: "repo",
}
err := s.Start()
if err == nil {
t.Error("expected error for unknown startup check mode")
}
}
func TestStartHTTPCheck_UnknownMode(t *testing.T) {
s := &UpdateService{
config: UpdateServiceConfig{
RepoURL: "https://example.com/updates",
CheckOnStartup: StartupCheckMode(99),
},
isGitHub: false,
}
err := s.Start()
if err == nil {
t.Error("expected error for unknown startup check mode")
}
}

13
go.mod
View file

@ -1,11 +1,10 @@
module dappco.re/go/core/update module forge.lthn.ai/core/go-update
go 1.26.0 go 1.26.0
require ( require (
dappco.re/go/core/cli v0.3.6 forge.lthn.ai/core/cli v0.3.6
dappco.re/go/core/io v0.1.5 forge.lthn.ai/core/go-log v0.0.4
dappco.re/go/core/log v0.0.4
github.com/Snider/Borg v0.2.0 github.com/Snider/Borg v0.2.0
github.com/minio/selfupdate v0.6.0 github.com/minio/selfupdate v0.6.0
github.com/spf13/cobra v1.10.2 github.com/spf13/cobra v1.10.2
@ -16,9 +15,9 @@ require (
require ( require (
aead.dev/minisign v0.3.0 // indirect aead.dev/minisign v0.3.0 // indirect
dario.cat/mergo v1.0.0 // indirect dario.cat/mergo v1.0.0 // indirect
dappco.re/go/core v0.3.1 // indirect forge.lthn.ai/core/go v0.3.1 // indirect
dappco.re/go/core/i18n v0.1.6 // indirect forge.lthn.ai/core/go-i18n v0.1.6 // indirect
dappco.re/go/core/inference v0.1.5 // indirect forge.lthn.ai/core/go-inference v0.1.5 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/ProtonMail/go-crypto v1.4.0 // indirect github.com/ProtonMail/go-crypto v1.4.0 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect

2
go.sum
View file

@ -11,8 +11,6 @@ forge.lthn.ai/core/go-i18n v0.1.6 h1:Z9h6sEZsgJmWlkkq3ZPZyfgWipeeqN5lDCpzltpamHU
forge.lthn.ai/core/go-i18n v0.1.6/go.mod h1:C6CbwdN7sejTx/lbutBPrxm77b8paMHBO6uHVLHOdqQ= forge.lthn.ai/core/go-i18n v0.1.6/go.mod h1:C6CbwdN7sejTx/lbutBPrxm77b8paMHBO6uHVLHOdqQ=
forge.lthn.ai/core/go-inference v0.1.5 h1:Az/Euv1DusJQJz/Eca0Ey7sVXQkFLPHW0TBrs9g+Qwg= forge.lthn.ai/core/go-inference v0.1.5 h1:Az/Euv1DusJQJz/Eca0Ey7sVXQkFLPHW0TBrs9g+Qwg=
forge.lthn.ai/core/go-inference v0.1.5/go.mod h1:jfWz+IJX55wAH98+ic6FEqqGB6/P31CHlg7VY7pxREw= forge.lthn.ai/core/go-inference v0.1.5/go.mod h1:jfWz+IJX55wAH98+ic6FEqqGB6/P31CHlg7VY7pxREw=
forge.lthn.ai/core/go-io v0.1.5 h1:+XJ1YhaGGFLGtcNbPtVlndTjk+pO0Ydi2hRDj5/cHOM=
forge.lthn.ai/core/go-io v0.1.5/go.mod h1:FRtXSsi8W+U9vewCU+LBAqqbIj3wjXA4dBdSv3SAtWI=
forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0= forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0=
forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw= forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw=
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=

View file

@ -1,32 +0,0 @@
package updater
import (
"context"
"fmt"
"net/http"
"time"
)
const defaultHTTPTimeout = 30 * time.Second
var NewHTTPClient = func() *http.Client {
return &http.Client{Timeout: defaultHTTPTimeout}
}
func newAgentRequest(ctx context.Context, method, url string) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, method, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", updaterUserAgent())
return req, nil
}
func updaterUserAgent() string {
version := formatVersionForDisplay(Version, true)
if version == "" {
version = "unknown"
}
return fmt.Sprintf("agent-go-update/%s", version)
}

View file

@ -30,8 +30,7 @@ type UpdateServiceConfig struct {
// repository URL (e.g., "https://github.com/owner/repo") or a base URL // repository URL (e.g., "https://github.com/owner/repo") or a base URL
// for a generic HTTP update server. // for a generic HTTP update server.
RepoURL string RepoURL string
// Channel specifies the release channel to track (e.g., "stable", "beta", or "prerelease"). // Channel specifies the release channel to track (e.g., "stable", "prerelease").
// "prerelease" is normalised to "beta" to match the GitHub release filter.
// This is only used for GitHub-based updates. // This is only used for GitHub-based updates.
Channel string Channel string
// CheckOnStartup determines the update behavior when the service starts. // CheckOnStartup determines the update behavior when the service starts.
@ -64,7 +63,6 @@ func NewUpdateService(config UpdateServiceConfig) (*UpdateService, error) {
var err error var err error
if isGitHub { if isGitHub {
config.Channel = normaliseGitHubChannel(config.Channel)
owner, repo, err = ParseRepoURL(config.RepoURL) owner, repo, err = ParseRepoURL(config.RepoURL)
if err != nil { if err != nil {
return nil, coreerr.E("NewUpdateService", "failed to parse GitHub repo URL", err) return nil, coreerr.E("NewUpdateService", "failed to parse GitHub repo URL", err)
@ -129,14 +127,3 @@ func ParseRepoURL(repoURL string) (owner string, repo string, err error) {
} }
return parts[0], parts[1], nil return parts[0], parts[1], nil
} }
func normaliseGitHubChannel(channel string) string {
channel = strings.ToLower(strings.TrimSpace(channel))
if channel == "" {
return "stable"
}
if channel == "prerelease" {
return "beta"
}
return channel
}

View file

@ -12,15 +12,13 @@ func TestNewUpdateService(t *testing.T) {
config UpdateServiceConfig config UpdateServiceConfig
expectError bool expectError bool
isGitHub bool isGitHub bool
wantChannel string
}{ }{
{ {
name: "Valid GitHub URL", name: "Valid GitHub URL",
config: UpdateServiceConfig{ config: UpdateServiceConfig{
RepoURL: "https://github.com/owner/repo", RepoURL: "https://github.com/owner/repo",
}, },
isGitHub: true, isGitHub: true,
wantChannel: "stable",
}, },
{ {
name: "Valid non-GitHub URL", name: "Valid non-GitHub URL",
@ -29,24 +27,6 @@ func TestNewUpdateService(t *testing.T) {
}, },
isGitHub: false, isGitHub: false,
}, },
{
name: "GitHub channel is normalised",
config: UpdateServiceConfig{
RepoURL: "https://github.com/owner/repo",
Channel: " Beta ",
},
isGitHub: true,
wantChannel: "beta",
},
{
name: "GitHub prerelease channel maps to beta",
config: UpdateServiceConfig{
RepoURL: "https://github.com/owner/repo",
Channel: " prerelease ",
},
isGitHub: true,
wantChannel: "beta",
},
{ {
name: "Invalid GitHub URL", name: "Invalid GitHub URL",
config: UpdateServiceConfig{ config: UpdateServiceConfig{
@ -65,9 +45,6 @@ func TestNewUpdateService(t *testing.T) {
if err == nil && service.isGitHub != tc.isGitHub { if err == nil && service.isGitHub != tc.isGitHub {
t.Errorf("Expected isGitHub: %v, got: %v", tc.isGitHub, service.isGitHub) t.Errorf("Expected isGitHub: %v, got: %v", tc.isGitHub, service.isGitHub)
} }
if err == nil && tc.wantChannel != "" && service.config.Channel != tc.wantChannel {
t.Errorf("Expected GitHub channel %q, got %q", tc.wantChannel, service.config.Channel)
}
}) })
} }
} }

View file

@ -31,24 +31,17 @@ var NewGithubClient = func() GithubClient {
// DoUpdate is a variable that holds the function to perform the actual update. // DoUpdate is a variable that holds the function to perform the actual update.
// This can be replaced in tests to prevent actual updates. // This can be replaced in tests to prevent actual updates.
var DoUpdate = func(url string) error { var DoUpdate = func(url string) error {
client := NewHTTPClient() resp, err := http.Get(url)
req, err := newAgentRequest(context.Background(), "GET", url)
if err != nil { if err != nil {
return coreerr.E("DoUpdate", "failed to create update request", err) return err
}
resp, err := client.Do(req)
if err != nil {
return coreerr.E("DoUpdate", "failed to download update", err)
} }
defer func(Body io.ReadCloser) { defer func(Body io.ReadCloser) {
_ = Body.Close() err := Body.Close()
if err != nil {
fmt.Printf("failed to close response body: %v\n", err)
}
}(resp.Body) }(resp.Body)
if resp.StatusCode != http.StatusOK {
return coreerr.E("DoUpdate", fmt.Sprintf("failed to download update: %s", resp.Status), nil)
}
err = selfupdate.Apply(resp.Body, selfupdate.Options{}) err = selfupdate.Apply(resp.Body, selfupdate.Options{})
if err != nil { if err != nil {
if rerr := selfupdate.RollbackError(err); rerr != nil { if rerr := selfupdate.RollbackError(err); rerr != nil {
@ -56,6 +49,8 @@ var DoUpdate = func(url string) error {
} }
return coreerr.E("DoUpdate", "update failed", err) return coreerr.E("DoUpdate", "update failed", err)
} }
fmt.Println("Update applied successfully.")
return nil return nil
} }
@ -145,14 +140,14 @@ var CheckOnly = func(owner, repo, channel string, forceSemVerPrefix bool, releas
// CheckForUpdatesByTag checks for and applies updates from GitHub based on the channel // CheckForUpdatesByTag checks for and applies updates from GitHub based on the channel
// determined by the current application's version tag (e.g., 'stable' or 'prerelease'). // determined by the current application's version tag (e.g., 'stable' or 'prerelease').
var CheckForUpdatesByTag = func(owner, repo string) error { var CheckForUpdatesByTag = func(owner, repo string) error {
channel := determineChannel(Version, semver.Prerelease(formatVersionForComparison(Version)) != "") channel := determineChannel(Version, false) // isPreRelease is false for current version
return CheckForUpdates(owner, repo, channel, true, "") return CheckForUpdates(owner, repo, channel, true, "")
} }
// CheckOnlyByTag checks for updates from GitHub based on the channel determined by the // CheckOnlyByTag checks for updates from GitHub based on the channel determined by the
// current version tag, without applying them. // current version tag, without applying them.
var CheckOnlyByTag = func(owner, repo string) error { var CheckOnlyByTag = func(owner, repo string) error {
channel := determineChannel(Version, semver.Prerelease(formatVersionForComparison(Version)) != "") channel := determineChannel(Version, false) // isPreRelease is false for current version
return CheckOnly(owner, repo, channel, true, "") return CheckOnly(owner, repo, channel, true, "")
} }

View file

@ -7,8 +7,6 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"runtime" "runtime"
coreerr "forge.lthn.ai/core/go-log"
) )
// mockGithubClient is a mock implementation of the GithubClient interface for testing. // mockGithubClient is a mock implementation of the GithubClient interface for testing.
@ -36,7 +34,7 @@ func (m *mockGithubClient) GetPublicRepos(ctx context.Context, userOrOrg string)
if m.getPublicRepos != nil { if m.getPublicRepos != nil {
return m.getPublicRepos(ctx, userOrOrg) return m.getPublicRepos(ctx, userOrOrg)
} }
return nil, coreerr.E("mockGithubClient.GetPublicRepos", "not implemented", nil) return nil, fmt.Errorf("GetPublicRepos not implemented")
} }
func ExampleCheckForNewerVersion() { func ExampleCheckForNewerVersion() {