diff --git a/github.go b/github.go index 31354f1..03b4cd7 100644 --- a/github.go +++ b/github.go @@ -50,7 +50,7 @@ type githubClient struct{} var NewAuthenticatedClient = func(ctx context.Context) *http.Client { token := os.Getenv("GITHUB_TOKEN") if token == "" { - return NewHTTPClient() + return http.DefaultClient } ts := oauth2.StaticTokenSource( diff --git a/github_internal_test.go b/github_internal_test.go index 2024f37..09f301e 100644 --- a/github_internal_test.go +++ b/github_internal_test.go @@ -82,6 +82,54 @@ func TestDetermineChannel_Good(t *testing.T) { } } +func TestCheckForUpdatesByTag_UsesCurrentVersionChannel(t *testing.T) { + originalVersion := Version + originalCheckForUpdates := CheckForUpdates + defer func() { + Version = originalVersion + CheckForUpdates = originalCheckForUpdates + }() + + var gotChannel string + CheckForUpdates = func(owner, repo, channel string, forceSemVerPrefix bool, releaseURLFormat string) error { + gotChannel = channel + return nil + } + + Version = "v2.0.0-rc.1" + if err := CheckForUpdatesByTag("owner", "repo"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if gotChannel != "beta" { + t.Fatalf("expected beta channel, got %q", gotChannel) + } +} + +func TestCheckOnlyByTag_UsesCurrentVersionChannel(t *testing.T) { + originalVersion := Version + originalCheckOnly := CheckOnly + defer func() { + Version = originalVersion + CheckOnly = originalCheckOnly + }() + + var gotChannel string + CheckOnly = func(owner, repo, channel string, forceSemVerPrefix bool, releaseURLFormat string) error { + gotChannel = channel + return nil + } + + Version = "v2.0.0-alpha.1" + if err := CheckOnlyByTag("owner", "repo"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if gotChannel != "alpha" { + t.Fatalf("expected alpha channel, got %q", gotChannel) + } +} + func TestGetDownloadURL_Good(t *testing.T) { osName := runtime.GOOS archName := runtime.GOARCH diff --git a/service.go b/service.go index a58f448..737e22b 100644 --- a/service.go +++ b/service.go @@ -63,6 +63,7 @@ func NewUpdateService(config UpdateServiceConfig) (*UpdateService, error) { var err error if isGitHub { + config.Channel = normaliseGitHubChannel(config.Channel) owner, repo, err = ParseRepoURL(config.RepoURL) if err != nil { return nil, coreerr.E("NewUpdateService", "failed to parse GitHub repo URL", err) @@ -127,3 +128,11 @@ func ParseRepoURL(repoURL string) (owner string, repo string, err error) { } return parts[0], parts[1], nil } + +func normaliseGitHubChannel(channel string) string { + channel = strings.ToLower(strings.TrimSpace(channel)) + if channel == "" { + return "stable" + } + return channel +} diff --git a/service_test.go b/service_test.go index ab8691a..c65458c 100644 --- a/service_test.go +++ b/service_test.go @@ -12,13 +12,15 @@ func TestNewUpdateService(t *testing.T) { config UpdateServiceConfig expectError bool isGitHub bool + wantChannel string }{ { name: "Valid GitHub URL", config: UpdateServiceConfig{ RepoURL: "https://github.com/owner/repo", }, - isGitHub: true, + isGitHub: true, + wantChannel: "stable", }, { name: "Valid non-GitHub URL", @@ -27,6 +29,15 @@ func TestNewUpdateService(t *testing.T) { }, isGitHub: false, }, + { + name: "GitHub channel is normalised", + config: UpdateServiceConfig{ + RepoURL: "https://github.com/owner/repo", + Channel: " Beta ", + }, + isGitHub: true, + wantChannel: "beta", + }, { name: "Invalid GitHub URL", config: UpdateServiceConfig{ @@ -45,6 +56,9 @@ func TestNewUpdateService(t *testing.T) { if err == nil && service.isGitHub != tc.isGitHub { t.Errorf("Expected isGitHub: %v, got: %v", tc.isGitHub, service.isGitHub) } + if err == nil && tc.wantChannel != "" && service.config.Channel != tc.wantChannel { + t.Errorf("Expected GitHub channel %q, got %q", tc.wantChannel, service.config.Channel) + } }) } } diff --git a/updater.go b/updater.go index 59a82d7..99858f3 100644 --- a/updater.go +++ b/updater.go @@ -145,14 +145,14 @@ var CheckOnly = func(owner, repo, channel string, forceSemVerPrefix bool, releas // CheckForUpdatesByTag checks for and applies updates from GitHub based on the channel // determined by the current application's version tag (e.g., 'stable' or 'prerelease'). var CheckForUpdatesByTag = func(owner, repo string) error { - channel := determineChannel(Version, false) // isPreRelease is false for current version + channel := determineChannel(Version, semver.Prerelease(formatVersionForComparison(Version)) != "") return CheckForUpdates(owner, repo, channel, true, "") } // CheckOnlyByTag checks for updates from GitHub based on the channel determined by the // current version tag, without applying them. var CheckOnlyByTag = func(owner, repo string) error { - channel := determineChannel(Version, false) // isPreRelease is false for current version + channel := determineChannel(Version, semver.Prerelease(formatVersionForComparison(Version)) != "") return CheckOnly(owner, repo, channel, true, "") }