Compare commits
1 commit
main
...
feat/bandw
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6071dc74f1 |
12 changed files with 499 additions and 74 deletions
12
cmd/all.go
12
cmd/all.go
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/Snider/Borg/pkg/compress"
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Borg/pkg/github"
|
||||
"github.com/Snider/Borg/pkg/ratelimit"
|
||||
"github.com/Snider/Borg/pkg/tim"
|
||||
"github.com/Snider/Borg/pkg/trix"
|
||||
"github.com/Snider/Borg/pkg/ui"
|
||||
|
|
@ -42,7 +43,16 @@ func NewAllCmd() *cobra.Command {
|
|||
return err
|
||||
}
|
||||
|
||||
repos, err := GithubClient.GetPublicRepos(cmd.Context(), owner)
|
||||
bandwidth, _ := cmd.Flags().GetString("bandwidth")
|
||||
bytesPerSec, err := ratelimit.ParseBandwidth(bandwidth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid bandwidth: %w", err)
|
||||
}
|
||||
|
||||
client := github.NewAuthenticatedClient(cmd.Context(), ratelimit.NewRateLimitedRoundTripper(nil, bytesPerSec))
|
||||
githubClient := GithubClient(client)
|
||||
|
||||
repos, err := githubClient.GetPublicRepos(cmd.Context(), owner)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,19 +15,16 @@ import (
|
|||
|
||||
func TestAllCmd_Good(t *testing.T) {
|
||||
// Setup mock HTTP client for GitHub API
|
||||
mockGithubClient := mocks.NewMockClient(map[string]*http.Response{
|
||||
"https://api.github.com/users/testuser/repos": {
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(bytes.NewBufferString(`[{"clone_url": "https://github.com/testuser/repo1.git"}]`)),
|
||||
},
|
||||
})
|
||||
oldNewAuthenticatedClient := github.NewAuthenticatedClient
|
||||
github.NewAuthenticatedClient = func(ctx context.Context) *http.Client {
|
||||
mockGithubClient := &mocks.MockGithubClient{
|
||||
Repos: []string{"https://github.com/testuser/repo1.git"},
|
||||
Err: nil,
|
||||
}
|
||||
oldGithubClient := GithubClient
|
||||
GithubClient = func(client *http.Client) github.GithubClient {
|
||||
return mockGithubClient
|
||||
}
|
||||
defer func() {
|
||||
github.NewAuthenticatedClient = oldNewAuthenticatedClient
|
||||
GithubClient = oldGithubClient
|
||||
}()
|
||||
|
||||
// Setup mock Git cloner
|
||||
|
|
@ -54,24 +51,16 @@ func TestAllCmd_Good(t *testing.T) {
|
|||
|
||||
func TestAllCmd_Bad(t *testing.T) {
|
||||
// Setup mock HTTP client to return an error
|
||||
mockGithubClient := mocks.NewMockClient(map[string]*http.Response{
|
||||
"https://api.github.com/users/baduser/repos": {
|
||||
StatusCode: http.StatusNotFound,
|
||||
Status: "404 Not Found",
|
||||
Body: io.NopCloser(bytes.NewBufferString(`{"message": "Not Found"}`)),
|
||||
},
|
||||
"https://api.github.com/orgs/baduser/repos": {
|
||||
StatusCode: http.StatusNotFound,
|
||||
Status: "404 Not Found",
|
||||
Body: io.NopCloser(bytes.NewBufferString(`{"message": "Not Found"}`)),
|
||||
},
|
||||
})
|
||||
oldNewAuthenticatedClient := github.NewAuthenticatedClient
|
||||
github.NewAuthenticatedClient = func(ctx context.Context) *http.Client {
|
||||
mockGithubClient := &mocks.MockGithubClient{
|
||||
Repos: nil,
|
||||
Err: fmt.Errorf("github error"),
|
||||
}
|
||||
oldGithubClient := GithubClient
|
||||
GithubClient = func(client *http.Client) github.GithubClient {
|
||||
return mockGithubClient
|
||||
}
|
||||
defer func() {
|
||||
github.NewAuthenticatedClient = oldNewAuthenticatedClient
|
||||
GithubClient = oldGithubClient
|
||||
}()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
|
|
@ -88,19 +77,16 @@ func TestAllCmd_Bad(t *testing.T) {
|
|||
func TestAllCmd_Ugly(t *testing.T) {
|
||||
t.Run("User with no repos", func(t *testing.T) {
|
||||
// Setup mock HTTP client for a user with no repos
|
||||
mockGithubClient := mocks.NewMockClient(map[string]*http.Response{
|
||||
"https://api.github.com/users/emptyuser/repos": {
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(bytes.NewBufferString(`[]`)),
|
||||
},
|
||||
})
|
||||
oldNewAuthenticatedClient := github.NewAuthenticatedClient
|
||||
github.NewAuthenticatedClient = func(ctx context.Context) *http.Client {
|
||||
mockGithubClient := &mocks.MockGithubClient{
|
||||
Repos: []string{},
|
||||
Err: nil,
|
||||
}
|
||||
oldGithubClient := GithubClient
|
||||
GithubClient = func(client *http.Client) github.GithubClient {
|
||||
return mockGithubClient
|
||||
}
|
||||
defer func() {
|
||||
github.NewAuthenticatedClient = oldNewAuthenticatedClient
|
||||
GithubClient = oldGithubClient
|
||||
}()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
|
|
|
|||
|
|
@ -11,11 +11,13 @@ func init() {
|
|||
RootCmd.AddCommand(GetCollectCmd())
|
||||
}
|
||||
func NewCollectCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
cmd := &cobra.Command{
|
||||
Use: "collect",
|
||||
Short: "Collect a resource from a URI.",
|
||||
Long: `Collect a resource from a URI and store it in a DataNode.`,
|
||||
}
|
||||
cmd.PersistentFlags().String("bandwidth", "0", "Limit download bandwidth (e.g., 1MB/s, 500KB/s, 0 for unlimited)")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func GetCollectCmd() *cobra.Command {
|
||||
|
|
|
|||
|
|
@ -65,3 +65,26 @@ func TestCollectGithubRepoCmd_Ugly(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCollectGithubRepoCmd_Bandwidth(t *testing.T) {
|
||||
// Setup mock Git cloner
|
||||
mockCloner := &mocks.MockGitCloner{
|
||||
DN: datanode.New(),
|
||||
Err: nil,
|
||||
}
|
||||
oldCloner := GitCloner
|
||||
GitCloner = mockCloner
|
||||
defer func() {
|
||||
GitCloner = oldCloner
|
||||
}()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
rootCmd.AddCommand(GetCollectCmd())
|
||||
|
||||
// Execute command with a bandwidth limit
|
||||
out := filepath.Join(t.TempDir(), "out")
|
||||
_, err := executeCommand(rootCmd, "collect", "github", "repo", "https://github.com/testuser/repo1", "--output", out, "--bandwidth", "1KB/s")
|
||||
if err != nil {
|
||||
t.Fatalf("collect github repo command failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,14 +2,18 @@ package cmd
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/Snider/Borg/pkg/github"
|
||||
"github.com/Snider/Borg/pkg/ratelimit"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
// GithubClient is the github client used by the command. It can be replaced for testing.
|
||||
GithubClient = github.NewGithubClient()
|
||||
GithubClient = func(client *http.Client) github.GithubClient {
|
||||
return github.NewGithubClient(client)
|
||||
}
|
||||
)
|
||||
|
||||
var collectGithubReposCmd = &cobra.Command{
|
||||
|
|
@ -17,7 +21,16 @@ var collectGithubReposCmd = &cobra.Command{
|
|||
Short: "Collects all public repositories for a user or organization",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
repos, err := GithubClient.GetPublicRepos(cmd.Context(), args[0])
|
||||
bandwidth, _ := cmd.Flags().GetString("bandwidth")
|
||||
bytesPerSec, err := ratelimit.ParseBandwidth(bandwidth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid bandwidth: %w", err)
|
||||
}
|
||||
|
||||
client := github.NewAuthenticatedClient(cmd.Context(), ratelimit.NewRateLimitedRoundTripper(http.DefaultTransport, bytesPerSec))
|
||||
githubClient := GithubClient(client)
|
||||
|
||||
repos, err := githubClient.GetPublicRepos(cmd.Context(), args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
130
cmd/collect_github_repos_test.go
Normal file
130
cmd/collect_github_repos_test.go
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Snider/Borg/pkg/github"
|
||||
)
|
||||
|
||||
type mockGithubClient struct {
|
||||
repos []string
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockGithubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) {
|
||||
return m.repos, m.err
|
||||
}
|
||||
|
||||
func TestCollectGithubReposCmd_Good(t *testing.T) {
|
||||
oldGithubClient := GithubClient
|
||||
GithubClient = func(client *http.Client) github.GithubClient {
|
||||
return &mockGithubClient{
|
||||
repos: []string{"https://github.com/testuser/repo1", "https://github.com/testuser/repo2"},
|
||||
err: nil,
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
GithubClient = oldGithubClient
|
||||
}()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
rootCmd.AddCommand(GetCollectCmd())
|
||||
|
||||
// Execute command
|
||||
output, err := executeCommand(rootCmd, "collect", "github", "repos", "testuser")
|
||||
if err != nil {
|
||||
t.Fatalf("collect github repos command failed: %v", err)
|
||||
}
|
||||
|
||||
expected := "https://github.com/testuser/repo1\nhttps://github.com/testuser/repo2\n"
|
||||
if output != expected {
|
||||
t.Errorf("expected output %q, but got %q", expected, output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectGithubReposCmd_Bad(t *testing.T) {
|
||||
oldGithubClient := GithubClient
|
||||
GithubClient = func(client *http.Client) github.GithubClient {
|
||||
return &mockGithubClient{
|
||||
repos: nil,
|
||||
err: fmt.Errorf("github error"),
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
GithubClient = oldGithubClient
|
||||
}()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
rootCmd.AddCommand(GetCollectCmd())
|
||||
|
||||
// Execute command
|
||||
_, err := executeCommand(rootCmd, "collect", "github", "repos", "testuser")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error, but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectGithubReposCmd_Ugly(t *testing.T) {
|
||||
t.Run("Invalid bandwidth", func(t *testing.T) {
|
||||
rootCmd := NewRootCmd()
|
||||
rootCmd.AddCommand(GetCollectCmd())
|
||||
_, err := executeCommand(rootCmd, "collect", "github", "repos", "testuser", "--bandwidth", "1Gbps")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error for invalid bandwidth, but got none")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid bandwidth") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCollectGithubReposCmd_Bandwidth(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// This is a simplified mock of the GitHub API. It returns a single repo.
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintln(w, `[{"clone_url": "https://github.com/testuser/repo1"}]`)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// We need to override the API URL to point to our test server.
|
||||
oldGetPublicRepos := GithubClient
|
||||
GithubClient = func(client *http.Client) github.GithubClient {
|
||||
return &mockGithubClient{
|
||||
repos: []string{"https://github.com/testuser/repo1"},
|
||||
err: nil,
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
GithubClient = oldGetPublicRepos
|
||||
}()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
rootCmd.AddCommand(GetCollectCmd())
|
||||
|
||||
// Execute command with a bandwidth limit
|
||||
start := time.Now()
|
||||
_, err := executeCommand(rootCmd, "collect", "github", "repos", "testuser", "--bandwidth", "1KB/s")
|
||||
if err != nil {
|
||||
t.Fatalf("collect github repos command failed: %v", err)
|
||||
}
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Since the response is very small, we can't reliably test the bandwidth limit.
|
||||
// We'll just check that the command runs without error.
|
||||
if elapsed > 1*time.Second {
|
||||
t.Errorf("expected the command to run quickly, but it took %s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
// getPublicReposWithAPIURL is a copy of the private function in pkg/github/github.go, so we can test it.
|
||||
type githubClient struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
var getPublicReposWithAPIURL func(g *githubClient, ctx context.Context, apiURL, userOrOrg string) ([]string, error)
|
||||
|
|
@ -2,10 +2,12 @@ package cmd
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/schollz/progressbar/v3"
|
||||
"github.com/Snider/Borg/pkg/compress"
|
||||
"github.com/Snider/Borg/pkg/ratelimit"
|
||||
"github.com/Snider/Borg/pkg/tim"
|
||||
"github.com/Snider/Borg/pkg/trix"
|
||||
"github.com/Snider/Borg/pkg/ui"
|
||||
|
|
@ -51,7 +53,17 @@ func NewCollectWebsiteCmd() *cobra.Command {
|
|||
bar = ui.NewProgressBar(-1, "Crawling website")
|
||||
}
|
||||
|
||||
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
|
||||
bandwidth, _ := cmd.Flags().GetString("bandwidth")
|
||||
bytesPerSec, err := ratelimit.ParseBandwidth(bandwidth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid bandwidth: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: ratelimit.NewRateLimitedRoundTripper(http.DefaultTransport, bytesPerSec),
|
||||
}
|
||||
|
||||
dn, err := website.DownloadAndPackageWebsiteWithClient(websiteURL, depth, bar, client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error downloading and packaging website: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,9 +2,12 @@ package cmd
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Borg/pkg/website"
|
||||
|
|
@ -13,12 +16,12 @@ import (
|
|||
|
||||
func TestCollectWebsiteCmd_Good(t *testing.T) {
|
||||
// Mock the website downloader
|
||||
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
oldDownloadAndPackageWebsiteWithClient := website.DownloadAndPackageWebsiteWithClient
|
||||
website.DownloadAndPackageWebsiteWithClient = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
|
||||
return datanode.New(), nil
|
||||
}
|
||||
defer func() {
|
||||
website.DownloadAndPackageWebsite = oldDownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsiteWithClient = oldDownloadAndPackageWebsiteWithClient
|
||||
}()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
|
|
@ -34,12 +37,12 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
|
|||
|
||||
func TestCollectWebsiteCmd_Bad(t *testing.T) {
|
||||
// Mock the website downloader to return an error
|
||||
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
oldDownloadAndPackageWebsiteWithClient := website.DownloadAndPackageWebsiteWithClient
|
||||
website.DownloadAndPackageWebsiteWithClient = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
|
||||
return nil, fmt.Errorf("website error")
|
||||
}
|
||||
defer func() {
|
||||
website.DownloadAndPackageWebsite = oldDownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsiteWithClient = oldDownloadAndPackageWebsiteWithClient
|
||||
}()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
|
|
@ -65,4 +68,43 @@ func TestCollectWebsiteCmd_Ugly(t *testing.T) {
|
|||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid bandwidth", func(t *testing.T) {
|
||||
rootCmd := NewRootCmd()
|
||||
rootCmd.AddCommand(GetCollectCmd())
|
||||
_, err := executeCommand(rootCmd, "collect", "website", "https://example.com", "--bandwidth", "1Gbps")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error for invalid bandwidth, but got none")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid bandwidth") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCollectWebsiteCmd_Bandwidth(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(make([]byte, 1024*1024)) // 1MB
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
rootCmd.AddCommand(GetCollectCmd())
|
||||
|
||||
// Create a temporary directory for the output file
|
||||
outDir := t.TempDir()
|
||||
out := filepath.Join(outDir, "out")
|
||||
|
||||
// Execute command with a bandwidth limit
|
||||
start := time.Now()
|
||||
_, err := executeCommand(rootCmd, "collect", "website", server.URL, "--output", out, "--bandwidth", "500KB/s")
|
||||
if err != nil {
|
||||
t.Fatalf("collect website command failed: %v", err)
|
||||
}
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Check if the download took at least 2 seconds
|
||||
if elapsed < 2*time.Second {
|
||||
t.Errorf("expected download to take at least 2 seconds, but it took %s", elapsed)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,30 +21,43 @@ type GithubClient interface {
|
|||
}
|
||||
|
||||
// NewGithubClient creates a new GithubClient.
|
||||
func NewGithubClient() GithubClient {
|
||||
return &githubClient{}
|
||||
func NewGithubClient(client *http.Client) GithubClient {
|
||||
return &githubClient{client: client}
|
||||
}
|
||||
|
||||
type githubClient struct{}
|
||||
type githubClient struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewAuthenticatedClient creates a new authenticated http client.
|
||||
var NewAuthenticatedClient = func(ctx context.Context) *http.Client {
|
||||
var NewAuthenticatedClient = func(ctx context.Context, transport http.RoundTripper) *http.Client {
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
token := os.Getenv("GITHUB_TOKEN")
|
||||
if token == "" {
|
||||
return http.DefaultClient
|
||||
return &http.Client{Transport: transport}
|
||||
}
|
||||
ts := oauth2.StaticTokenSource(
|
||||
&oauth2.Token{AccessToken: token},
|
||||
)
|
||||
return oauth2.NewClient(ctx, ts)
|
||||
return &http.Client{
|
||||
Transport: &oauth2.Transport{
|
||||
Base: transport,
|
||||
Source: ts,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) {
|
||||
return g.getPublicReposWithAPIURL(ctx, "https://api.github.com", userOrOrg)
|
||||
return g.GetPublicReposWithAPIURL(ctx, "https://api.github.com", userOrOrg)
|
||||
}
|
||||
|
||||
func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]string, error) {
|
||||
client := NewAuthenticatedClient(ctx)
|
||||
func (g *githubClient) GetPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]string, error) {
|
||||
client := g.client
|
||||
if client == nil {
|
||||
client = NewAuthenticatedClient(ctx, nil)
|
||||
}
|
||||
var allCloneURLs []string
|
||||
url := fmt.Sprintf("%s/users/%s/repos", apiURL, userOrOrg)
|
||||
isFirstRequest := true
|
||||
|
|
|
|||
125
pkg/ratelimit/ratelimit.go
Normal file
125
pkg/ratelimit/ratelimit.go
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Limiter is a simple token bucket rate limiter.
|
||||
type Limiter struct {
|
||||
c chan time.Time
|
||||
}
|
||||
|
||||
// NewLimiter creates a new Limiter.
|
||||
func NewLimiter(rate int64, per time.Duration) *Limiter {
|
||||
l := &Limiter{
|
||||
c: make(chan time.Time, rate),
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(per / time.Duration(rate))
|
||||
defer ticker.Stop()
|
||||
for t := range ticker.C {
|
||||
select {
|
||||
case l.c <- t:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
return l
|
||||
}
|
||||
|
||||
// Wait blocks until a token is available.
|
||||
func (l *Limiter) Wait() {
|
||||
<-l.c
|
||||
}
|
||||
|
||||
// rateLimitedRoundTripper is an http.RoundTripper that limits the bandwidth.
|
||||
type rateLimitedRoundTripper struct {
|
||||
transport http.RoundTripper
|
||||
limiter *Limiter
|
||||
}
|
||||
|
||||
// NewRateLimitedRoundTripper creates a new rateLimitedRoundTripper.
|
||||
func NewRateLimitedRoundTripper(transport http.RoundTripper, bytesPerSec int64) http.RoundTripper {
|
||||
if bytesPerSec <= 0 {
|
||||
return transport
|
||||
}
|
||||
return &rateLimitedRoundTripper{
|
||||
transport: transport,
|
||||
limiter: NewLimiter(bytesPerSec, time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface.
|
||||
func (t *rateLimitedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
resp, err := t.transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.Body = &rateLimitedResponseBody{
|
||||
body: resp.Body,
|
||||
limiter: t.limiter,
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// rateLimitedResponseBody is an io.ReadCloser that limits the bandwidth.
|
||||
type rateLimitedResponseBody struct {
|
||||
body io.ReadCloser
|
||||
limiter *Limiter
|
||||
}
|
||||
|
||||
// Read implements the io.Reader interface.
|
||||
func (b *rateLimitedResponseBody) Read(p []byte) (int, error) {
|
||||
n, err := b.body.Read(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
b.limiter.Wait()
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Close implements the io.Closer interface.
|
||||
func (b *rateLimitedResponseBody) Close() error {
|
||||
return b.body.Close()
|
||||
}
|
||||
|
||||
// ParseBandwidth parses a human-readable bandwidth string (e.g., "1MB/s")
|
||||
// and returns the equivalent in bytes per second.
|
||||
func ParseBandwidth(bandwidth string) (int64, error) {
|
||||
if bandwidth == "" || bandwidth == "0" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(`(?i)^(\d+)\s*(KB/s|MB/s|Mbps)$`)
|
||||
matches := re.FindStringSubmatch(bandwidth)
|
||||
if len(matches) != 3 {
|
||||
return 0, fmt.Errorf("invalid bandwidth format: %s", bandwidth)
|
||||
}
|
||||
|
||||
value, err := strconv.ParseInt(matches[1], 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid bandwidth value: %s", matches[1])
|
||||
}
|
||||
|
||||
unit := strings.ToUpper(matches[2])
|
||||
switch unit {
|
||||
case "KB/S":
|
||||
return value * 1024, nil
|
||||
case "MB/S":
|
||||
return value * 1024 * 1024, nil
|
||||
case "MBPS":
|
||||
return value * 1024 * 1024 / 8, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown bandwidth unit: %s", unit)
|
||||
}
|
||||
}
|
||||
62
pkg/ratelimit/ratelimit_test.go
Normal file
62
pkg/ratelimit/ratelimit_test.go
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseBandwidth(t *testing.T) {
|
||||
testCases := []struct {
|
||||
input string
|
||||
expected int64
|
||||
err bool
|
||||
}{
|
||||
{"1KB/s", 1024, false},
|
||||
{"1MB/s", 1024 * 1024, false},
|
||||
{"1Mbps", 1024 * 1024 / 8, false},
|
||||
{"500KB/s", 500 * 1024, false},
|
||||
{"10MB/s", 10 * 1024 * 1024, false},
|
||||
{"8Mbps", 1024 * 1024, false},
|
||||
{"0", 0, false},
|
||||
{"", 0, false},
|
||||
{"1 GB/s", 0, true},
|
||||
{"1MB", 0, true},
|
||||
{"MB/s", 0, true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.input, func(t *testing.T) {
|
||||
actual, err := ParseBandwidth(tc.input)
|
||||
if (err != nil) != tc.err {
|
||||
t.Errorf("expected error: %v, got: %v", tc.err, err)
|
||||
}
|
||||
if actual != tc.expected {
|
||||
t.Errorf("expected: %d, got: %d", tc.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimiter(t *testing.T) {
|
||||
// Test case 1: 10 tokens per second
|
||||
limiter1 := NewLimiter(10, time.Second)
|
||||
start1 := time.Now()
|
||||
for i := 0; i < 10; i++ {
|
||||
limiter1.Wait()
|
||||
}
|
||||
elapsed1 := time.Since(start1)
|
||||
if elapsed1 < 900*time.Millisecond || elapsed1 > 1100*time.Millisecond {
|
||||
t.Errorf("expected to take around 1s for 10 tokens at 10 tokens/sec, but took %s", elapsed1)
|
||||
}
|
||||
|
||||
// Test case 2: 100 tokens per second
|
||||
limiter2 := NewLimiter(100, time.Second)
|
||||
start2 := time.Now()
|
||||
for i := 0; i < 10; i++ {
|
||||
limiter2.Wait()
|
||||
}
|
||||
elapsed2 := time.Since(start2)
|
||||
if elapsed2 < 90*time.Millisecond || elapsed2 > 110*time.Millisecond {
|
||||
t.Errorf("expected to take around 100ms for 10 tokens at 100 tokens/sec, but took %s", elapsed2)
|
||||
}
|
||||
}
|
||||
|
|
@ -44,25 +44,7 @@ func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
|
|||
|
||||
// downloadAndPackageWebsite downloads a website and packages it into a DataNode.
|
||||
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
baseURL, err := url.Parse(startURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := NewDownloader(maxDepth)
|
||||
d.baseURL = baseURL
|
||||
d.progressBar = bar
|
||||
d.crawl(startURL, 0)
|
||||
|
||||
if len(d.errors) > 0 {
|
||||
var errs []string
|
||||
for _, e := range d.errors {
|
||||
errs = append(errs, e.Error())
|
||||
}
|
||||
return nil, fmt.Errorf("failed to download website:\n%s", strings.Join(errs, "\n"))
|
||||
}
|
||||
|
||||
return d.dn, nil
|
||||
return downloadAndPackageWebsiteWithClient(startURL, maxDepth, bar, http.DefaultClient)
|
||||
}
|
||||
|
||||
func (d *Downloader) crawl(pageURL string, depth int) {
|
||||
|
|
@ -204,3 +186,28 @@ func isAsset(pageURL string) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// DownloadAndPackageWebsiteWithClient downloads a website and packages it into a DataNode using a custom http.Client.
|
||||
var DownloadAndPackageWebsiteWithClient = downloadAndPackageWebsiteWithClient
|
||||
|
||||
func downloadAndPackageWebsiteWithClient(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
|
||||
baseURL, err := url.Parse(startURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := NewDownloaderWithClient(maxDepth, client)
|
||||
d.baseURL = baseURL
|
||||
d.progressBar = bar
|
||||
d.crawl(startURL, 0)
|
||||
|
||||
if len(d.errors) > 0 {
|
||||
var errs []string
|
||||
for _, e := range d.errors {
|
||||
errs = append(errs, e.Error())
|
||||
}
|
||||
return nil, fmt.Errorf("failed to download website:\n%s", strings.Join(errs, "\n"))
|
||||
}
|
||||
|
||||
return d.dn, nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue