Compare commits
1 commit
main
...
feat/retry
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a242080299 |
10 changed files with 215 additions and 10 deletions
25
cmd/root.go
25
cmd/root.go
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/Snider/Borg/pkg/retry"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
|
|
@ -16,6 +17,30 @@ packaging their contents into a single file, and managing the data within.`,
|
|||
}
|
||||
|
||||
rootCmd.PersistentFlags().BoolP("verbose", "v", false, "Enable verbose logging")
|
||||
rootCmd.PersistentFlags().Int("retries", 3, "Max retry attempts")
|
||||
rootCmd.PersistentFlags().Duration("retry-backoff", 1s, "Initial backoff duration")
|
||||
rootCmd.PersistentFlags().Duration("retry-max", 30s, "Maximum backoff duration")
|
||||
rootCmd.PersistentFlags().Float64("retry-jitter", 0.1, "Randomization factor")
|
||||
rootCmd.PersistentFlags().Bool("no-retry", false, "Disable retries")
|
||||
|
||||
rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) {
|
||||
// Configure retry settings
|
||||
retries, _ := cmd.Flags().GetInt("retries")
|
||||
retryBackoff, _ := cmd.Flags().GetDuration("retry-backoff")
|
||||
retryMax, _ := cmd.Flags().GetDuration("retry-max")
|
||||
retryJitter, _ := cmd.Flags().GetFloat64("retry-jitter")
|
||||
noRetry, _ := cmd.Flags().GetBool("no-retry")
|
||||
|
||||
if noRetry {
|
||||
retry.DefaultTransport.Retries = 0
|
||||
} else {
|
||||
retry.DefaultTransport.Retries = retries
|
||||
retry.DefaultTransport.InitialBackoff = retryBackoff
|
||||
retry.DefaultTransport.MaxBackoff = retryMax
|
||||
retry.DefaultTransport.Jitter = retryJitter
|
||||
}
|
||||
}
|
||||
|
||||
return rootCmd
|
||||
}
|
||||
|
||||
|
|
|
|||
1
main.go
1
main.go
|
|
@ -5,6 +5,7 @@ import (
|
|||
|
||||
"github.com/Snider/Borg/cmd"
|
||||
"github.com/Snider/Borg/pkg/logger"
|
||||
"github.com/Snider/Borg/pkg/retry"
|
||||
)
|
||||
|
||||
var osExit = os.Exit
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/Snider/Borg/pkg/retry"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
|
|
@ -30,12 +31,19 @@ type githubClient struct{}
|
|||
// NewAuthenticatedClient creates a new authenticated http client.
|
||||
var NewAuthenticatedClient = func(ctx context.Context) *http.Client {
|
||||
token := os.Getenv("GITHUB_TOKEN")
|
||||
|
||||
// Create a base client with retry transport.
|
||||
baseClient := &http.Client{Transport: retry.NewTransport()}
|
||||
|
||||
if token == "" {
|
||||
return http.DefaultClient
|
||||
return baseClient
|
||||
}
|
||||
|
||||
ts := oauth2.StaticTokenSource(
|
||||
&oauth2.Token{AccessToken: token},
|
||||
)
|
||||
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, baseClient)
|
||||
return oauth2.NewClient(ctx, ts)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/Snider/Borg/pkg/mocks"
|
||||
"github.com/Snider/Borg/pkg/retry"
|
||||
)
|
||||
|
||||
func TestGetPublicRepos_Good(t *testing.T) {
|
||||
|
|
@ -164,8 +165,8 @@ func TestNewAuthenticatedClient_Bad(t *testing.T) {
|
|||
// Unset the variable to ensure it's not present
|
||||
t.Setenv("GITHUB_TOKEN", "")
|
||||
client := NewAuthenticatedClient(context.Background())
|
||||
if client != http.DefaultClient {
|
||||
t.Error("expected http.DefaultClient when no token is set, but got something else")
|
||||
if _, ok := client.Transport.(*retry.Transport); !ok {
|
||||
t.Errorf("expected transport to be *retry.Transport, but got %T", client.Transport)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Snider/Borg/pkg/retry"
|
||||
"github.com/google/go-github/v39/github"
|
||||
)
|
||||
|
||||
|
|
@ -22,12 +23,12 @@ var (
|
|||
return http.NewRequest(method, url, body)
|
||||
}
|
||||
// DefaultClient is the default http client
|
||||
DefaultClient = &http.Client{}
|
||||
DefaultClient = retry.NewClient(retry.NewTransport())
|
||||
)
|
||||
|
||||
// GetLatestRelease gets the latest release for a repository.
|
||||
func GetLatestRelease(owner, repo string) (*github.RepositoryRelease, error) {
|
||||
client := NewClient(nil)
|
||||
client := NewClient(NewAuthenticatedClient(context.Background()))
|
||||
release, _, err := client.Repositories.GetLatestRelease(context.Background(), owner, repo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Borg/pkg/retry"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
"golang.org/x/net/html"
|
||||
)
|
||||
|
|
@ -32,7 +33,7 @@ type PWAClient interface {
|
|||
|
||||
// NewPWAClient creates a new PWAClient.
|
||||
func NewPWAClient() PWAClient {
|
||||
return &pwaClient{client: http.DefaultClient}
|
||||
return &pwaClient{client: retry.NewClient(retry.NewTransport())}
|
||||
}
|
||||
|
||||
type pwaClient struct {
|
||||
|
|
|
|||
99
pkg/retry/client.go
Normal file
99
pkg/retry/client.go
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
package retry
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Backoff is a time.Duration that represents the backoff strategy.
|
||||
type Backoff time.Duration
|
||||
|
||||
// Exponential returns a new Backoff duration that is the current duration
|
||||
// multiplied by 2.
|
||||
func (b Backoff) Exponential() Backoff {
|
||||
return Backoff(time.Duration(b) * 2)
|
||||
}
|
||||
|
||||
// Jitter returns a new Backoff duration with a random jitter added.
|
||||
func (b Backoff) Jitter(factor float64) Backoff {
|
||||
if factor <= 0 {
|
||||
return b
|
||||
}
|
||||
jitter := time.Duration(rand.Float64() * factor * float64(b))
|
||||
return Backoff(time.Duration(b) + jitter)
|
||||
}
|
||||
|
||||
// Transport is an http.RoundTripper that automatically retries requests.
|
||||
type Transport struct {
|
||||
// Transport is the underlying http.RoundTripper to use for requests.
|
||||
// If nil, http.DefaultTransport is used.
|
||||
Transport http.RoundTripper
|
||||
|
||||
// Retries is the maximum number of retries to attempt.
|
||||
Retries int
|
||||
|
||||
// InitialBackoff is the initial backoff duration.
|
||||
InitialBackoff time.Duration
|
||||
|
||||
// MaxBackoff is the maximum backoff duration.
|
||||
MaxBackoff time.Duration
|
||||
|
||||
// Jitter is the jitter factor to apply to backoff durations.
|
||||
Jitter float64
|
||||
}
|
||||
|
||||
// NewTransport creates a new Transport with default values.
|
||||
func NewTransport() *Transport {
|
||||
return &Transport{
|
||||
Transport: http.DefaultTransport,
|
||||
Retries: 3,
|
||||
InitialBackoff: 1 * time.Second,
|
||||
MaxBackoff: 30 * time.Second,
|
||||
Jitter: 0.1,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface.
|
||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
var resp *http.Response
|
||||
var err error
|
||||
var backoff = Backoff(t.InitialBackoff)
|
||||
|
||||
for i := 0; i < t.Retries; i++ {
|
||||
resp, err = t.transport().RoundTrip(req)
|
||||
if err == nil && resp.StatusCode < 500 {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
if i < t.Retries-1 {
|
||||
time.Sleep(time.Duration(backoff))
|
||||
backoff = backoff.Exponential().Jitter(t.Jitter)
|
||||
if backoff > Backoff(t.MaxBackoff) {
|
||||
backoff = Backoff(t.MaxBackoff)
|
||||
}
|
||||
}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (t *Transport) transport() http.RoundTripper {
|
||||
if t.Transport != nil {
|
||||
return t.Transport
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
||||
// NewClient returns a new http.Client that uses the Transport.
|
||||
func NewClient(transport *Transport) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// DefaultTransport is the default transport with retry logic.
|
||||
DefaultTransport = NewTransport()
|
||||
// DefaultClient is the default client that uses the default transport.
|
||||
DefaultClient = NewClient(DefaultTransport)
|
||||
)
|
||||
42
pkg/retry/client_test.go
Normal file
42
pkg/retry/client_test.go
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
package retry
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTransport_RoundTrip(t *testing.T) {
|
||||
var requestCount int
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount++
|
||||
if requestCount <= 2 {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
transport := NewTransport()
|
||||
transport.Retries = 3
|
||||
transport.InitialBackoff = 1 * time.Millisecond
|
||||
transport.MaxBackoff = 10 * time.Millisecond
|
||||
|
||||
client := NewClient(transport)
|
||||
req, _ := http.NewRequest("GET", server.URL, nil)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
if requestCount != 3 {
|
||||
t.Errorf("expected 3 requests, got %d", requestCount)
|
||||
}
|
||||
}
|
||||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
|
||||
|
|
@ -37,12 +38,37 @@ func (g *gitCloner) CloneGitRepository(repoURL string, progress io.Writer) (*dat
|
|||
cloneOptions.Progress = progress
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
retries := 3
|
||||
backoff := 1 * time.Second
|
||||
maxBackoff := 30 * time.Second
|
||||
|
||||
for i := 0; i < retries; i++ {
|
||||
_, err = git.PlainClone(tempPath, false, cloneOptions)
|
||||
if err != nil {
|
||||
if err == nil {
|
||||
lastErr = nil
|
||||
break
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
|
||||
// Handle non-retryable error
|
||||
if err.Error() == "remote repository is empty" {
|
||||
return datanode.New(), nil
|
||||
}
|
||||
return nil, err
|
||||
|
||||
// Don't wait on the last attempt
|
||||
if i < retries-1 {
|
||||
time.Sleep(backoff)
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
dn := datanode.New()
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Borg/pkg/retry"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
|
||||
"golang.org/x/net/html"
|
||||
|
|
@ -28,7 +29,7 @@ type Downloader struct {
|
|||
|
||||
// NewDownloader creates a new Downloader.
|
||||
func NewDownloader(maxDepth int) *Downloader {
|
||||
return NewDownloaderWithClient(maxDepth, http.DefaultClient)
|
||||
return NewDownloaderWithClient(maxDepth, retry.NewClient(retry.NewTransport()))
|
||||
}
|
||||
|
||||
// NewDownloaderWithClient creates a new Downloader with a custom http.Client.
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue