refactor: Use dependency injection for mocking

This commit refactors the codebase to use dependency injection for mocking external dependencies, removing the need for in-code mocking with the `BORG_PLEXSUS` environment variable.

- Interfaces have been created for external dependencies in the `pkg/vcs`, `pkg/github`, and `pkg/pwa` packages.
- The `cmd` package has been refactored to use these interfaces, with dependencies exposed as public variables for easy mocking in tests.
- The tests in `TDD/collect_commands_test.go` have been updated to inject mock implementations of these interfaces.
- The `BORG_PLEXSUS` environment variable has been removed from the codebase.
This commit is contained in:
google-labs-jules[bot] 2025-11-02 22:14:52 +00:00
parent 19431ed23d
commit bb6d30122d
11 changed files with 195 additions and 89 deletions

View file

@ -4,16 +4,46 @@ import (
"bytes"
"context"
"fmt"
"github.com/Snider/Borg/cmd"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/Snider/Borg/cmd"
"github.com/Snider/Borg/pkg/datanode"
"github.com/schollz/progressbar/v3"
)
func TestCollectCommands(t *testing.T) {
t.Setenv("BORG_PLEXSUS", "0")
type mockGitCloner struct{}
func (m *mockGitCloner) CloneGitRepository(repoURL string, progress io.Writer) (*datanode.DataNode, error) {
dn := datanode.New()
dn.AddData("README.md", []byte("Mock README"))
return dn, nil
}
type mockGithubClient struct{}
func (m *mockGithubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) {
return []string{"https://github.com/test/repo1.git"}, nil
}
type mockPWAClient struct{}
func (m *mockPWAClient) FindManifest(pageURL string) (string, error) {
return "http://test.com/manifest.json", nil
}
func (m *mockPWAClient) DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
dn := datanode.New()
dn.AddData("manifest.json", []byte(`{"name": "Test PWA", "start_url": "index.html"}`))
return dn, nil
}
func TestCollectCommands(t *testing.T) {
// Setup a test server for the website test
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/page2.html" {
@ -29,21 +59,43 @@ func TestCollectCommands(t *testing.T) {
args []string
expectedStdout string
expectedStderr string
setup func(t *testing.T)
}{
{
name: "collect github repos",
args: []string{"collect", "github", "repos", "test"},
expectedStdout: "https://github.com/test/repo1.git",
setup: func(t *testing.T) {
original := cmd.GithubClient
cmd.GithubClient = &mockGithubClient{}
t.Cleanup(func() {
cmd.GithubClient = original
})
},
},
{
name: "collect github repo",
args: []string{"collect", "github", "repo", "https://github.com/test/repo1.git"},
expectedStdout: "Repository saved to repo.datanode",
setup: func(t *testing.T) {
original := cmd.GitCloner
cmd.GitCloner = &mockGitCloner{}
t.Cleanup(func() {
cmd.GitCloner = original
})
},
},
{
name: "collect pwa",
args: []string{"collect", "pwa", "--uri", "http://test.com"},
expectedStdout: "PWA saved to pwa.datanode",
setup: func(t *testing.T) {
original := cmd.PWAClient
cmd.PWAClient = &mockPWAClient{}
t.Cleanup(func() {
cmd.PWAClient = original
})
},
},
{
name: "collect website",
@ -54,6 +106,22 @@ func TestCollectCommands(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.setup != nil {
tc.setup(t)
}
t.Cleanup(func() {
files, err := filepath.Glob("*.datanode")
if err != nil {
t.Fatalf("failed to glob for datanode files: %v", err)
}
for _, f := range files {
if err := os.Remove(f); err != nil {
t.Logf("failed to remove datanode file %s: %v", f, err)
}
}
})
rootCmd := cmd.NewRootCmd()
outBuf := new(bytes.Buffer)
errBuf := new(bytes.Buffer)

View file

@ -7,9 +7,7 @@ import (
"os"
"strings"
"github.com/Snider/Borg/pkg/github"
"github.com/Snider/Borg/pkg/ui"
"github.com/Snider/Borg/pkg/vcs"
"github.com/spf13/cobra"
)
@ -27,7 +25,7 @@ var allCmd = &cobra.Command{
fmt.Fprintln(os.Stderr, "Error: logger not properly initialised")
return
}
repos, err := github.GetPublicRepos(context.Background(), args[0])
repos, err := GithubClient.GetPublicRepos(context.Background(), args[0])
if err != nil {
log.Error("failed to get public repos", "err", err)
return
@ -39,7 +37,7 @@ var allCmd = &cobra.Command{
log.Info("cloning repository", "url", repoURL)
bar := ui.NewProgressBar(-1, "Cloning repository")
dn, err := vcs.CloneGitRepository(repoURL, bar)
dn, err := GitCloner.CloneGitRepository(repoURL, bar)
bar.Finish()
if err != nil {
log.Error("failed to clone repository", "url", repoURL, "err", err)

View file

@ -13,6 +13,11 @@ import (
"github.com/spf13/cobra"
)
var (
// GitCloner is the git cloner used by the command. It can be replaced for testing.
GitCloner = vcs.NewGitCloner()
)
// collectGithubRepoCmd represents the collect github repo command
var collectGithubRepoCmd = &cobra.Command{
Use: "repo [repository-url]",
@ -35,7 +40,7 @@ var collectGithubRepoCmd = &cobra.Command{
progressWriter = ui.NewProgressWriter(bar)
}
dn, err := vcs.CloneGitRepository(repoURL, progressWriter)
dn, err := GitCloner.CloneGitRepository(repoURL, progressWriter)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error cloning repository:", err)
return

View file

@ -7,12 +7,17 @@ import (
"github.com/spf13/cobra"
)
var (
// GithubClient is the github client used by the command. It can be replaced for testing.
GithubClient = github.NewGithubClient()
)
var collectGithubReposCmd = &cobra.Command{
Use: "repos [user-or-org]",
Short: "Collects all public repositories for a user or organization",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
repos, err := github.GetPublicRepos(cmd.Context(), args[0])
repos, err := GithubClient.GetPublicRepos(cmd.Context(), args[0])
if err != nil {
return err
}

View file

@ -12,6 +12,11 @@ import (
"github.com/spf13/cobra"
)
var (
// PWAClient is the pwa client used by the command. It can be replaced for testing.
PWAClient = pwa.NewPWAClient()
)
// collectPWACmd represents the collect pwa command
var collectPWACmd = &cobra.Command{
Use: "pwa",
@ -34,13 +39,13 @@ Example:
bar := ui.NewProgressBar(-1, "Finding PWA manifest")
defer bar.Finish()
manifestURL, err := pwa.FindManifest(pwaURL)
manifestURL, err := PWAClient.FindManifest(pwaURL)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error finding manifest:", err)
return
}
bar.Describe("Downloading and packaging PWA")
dn, err := pwa.DownloadAndPackagePWA(pwaURL, manifestURL, bar)
dn, err := PWAClient.DownloadAndPackagePWA(pwaURL, manifestURL, bar)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error downloading and packaging PWA:", err)
return

View file

@ -1,12 +1,9 @@
package github
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/Snider/Borg/pkg/mocks"
"io"
"net/http"
"os"
"strings"
@ -18,27 +15,23 @@ type Repo struct {
CloneURL string `json:"clone_url"`
}
func GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) {
return GetPublicReposWithAPIURL(ctx, "https://api.github.com", userOrOrg)
// GithubClient is an interface for interacting with the Github API.
type GithubClient interface {
GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error)
}
func newAuthenticatedClient(ctx context.Context) *http.Client {
if os.Getenv("BORG_PLEXSUS") == "0" {
// Define mock responses for testing
responses := map[string]*http.Response{
"https://api.github.com/users/test/repos": {
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(`[{"clone_url": "https://github.com/test/repo1.git"}]`)),
Header: make(http.Header),
},
"https://api.github.com/orgs/test/repos": {
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(`[{"clone_url": "https://github.com/test/repo2.git"}]`)),
Header: make(http.Header),
},
}
return mocks.NewMockClient(responses)
}
// NewGithubClient creates a new GithubClient.
func NewGithubClient() GithubClient {
return &githubClient{}
}
type githubClient struct{}
func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) {
return g.getPublicReposWithAPIURL(ctx, "https://api.github.com", userOrOrg)
}
func (g *githubClient) newAuthenticatedClient(ctx context.Context) *http.Client {
token := os.Getenv("GITHUB_TOKEN")
if token == "" {
return http.DefaultClient
@ -49,8 +42,8 @@ func newAuthenticatedClient(ctx context.Context) *http.Client {
return oauth2.NewClient(ctx, ts)
}
func GetPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]string, error) {
client := newAuthenticatedClient(ctx)
func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]string, error) {
client := g.newAuthenticatedClient(ctx)
var allCloneURLs []string
url := fmt.Sprintf("%s/users/%s/repos", apiURL, userOrOrg)
@ -103,7 +96,7 @@ func GetPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]
if linkHeader == "" {
break
}
nextURL := findNextURL(linkHeader)
nextURL := g.findNextURL(linkHeader)
if nextURL == "" {
break
}
@ -113,7 +106,7 @@ func GetPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]
return allCloneURLs, nil
}
func findNextURL(linkHeader string) string {
func (g *githubClient) findNextURL(linkHeader string) string {
links := strings.Split(linkHeader, ",")
for _, link := range links {
parts := strings.Split(link, ";")

View file

@ -13,6 +13,13 @@ type MockRoundTripper struct {
responses map[string]*http.Response
}
// SetResponses sets the mock responses in a thread-safe way.
func (m *MockRoundTripper) SetResponses(responses map[string]*http.Response) {
m.mu.Lock()
defer m.mu.Unlock()
m.responses = responses
}
// RoundTrip implements the http.RoundTripper interface.
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
url := req.URL.String()
@ -59,9 +66,15 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
// NewMockClient creates a new http.Client with a MockRoundTripper.
func NewMockClient(responses map[string]*http.Response) *http.Client {
responsesCopy := make(map[string]*http.Response)
if responses != nil {
for k, v := range responses {
responsesCopy[k] = v
}
}
return &http.Client{
Transport: &MockRoundTripper{
responses: responses,
responses: responsesCopy,
},
}
}

View file

@ -1,14 +1,11 @@
package pwa
import (
"bytes"
"encoding/json"
"fmt"
"github.com/Snider/Borg/pkg/mocks"
"io"
"net/http"
"net/url"
"os"
"path"
"github.com/Snider/Borg/pkg/datanode"
@ -32,34 +29,26 @@ type Icon struct {
Type string `json:"type"`
}
var getHTTPClient = func() *http.Client {
if os.Getenv("BORG_PLEXSUS") == "0" {
responses := map[string]*http.Response{
"http://test.com": {
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(`<html><head><link rel="manifest" href="manifest.json"></head></html>`)),
Header: make(http.Header),
},
"http://test.com/manifest.json": {
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(`{"name": "Test PWA", "start_url": "index.html"}`)),
Header: make(http.Header),
},
"http://test.com/index.html": {
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(`<html><body>Hello</body></html>`)),
Header: make(http.Header),
},
}
return mocks.NewMockClient(responses)
// PWAClient is an interface for finding and downloading PWAs.
type PWAClient interface {
FindManifest(pageURL string) (string, error)
DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error)
}
// NewPWAClient creates a new PWAClient.
func NewPWAClient() PWAClient {
return &pwaClient{
client: http.DefaultClient,
}
return http.DefaultClient
}
type pwaClient struct {
client *http.Client
}
// FindManifest finds the manifest URL from a given HTML page.
func FindManifest(pageURL string) (string, error) {
client := getHTTPClient()
resp, err := client.Get(pageURL)
func (p *pwaClient) FindManifest(pageURL string) (string, error) {
resp, err := p.client.Get(pageURL)
if err != nil {
return "", err
}
@ -100,7 +89,7 @@ func FindManifest(pageURL string) (string, error) {
return "", fmt.Errorf("manifest not found")
}
resolvedURL, err := resolveURL(pageURL, manifestPath)
resolvedURL, err := p.resolveURL(pageURL, manifestPath)
if err != nil {
return "", fmt.Errorf("could not resolve manifest URL: %w", err)
}
@ -109,17 +98,16 @@ func FindManifest(pageURL string) (string, error) {
}
// DownloadAndPackagePWA downloads all assets of a PWA and packages them into a DataNode.
func DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
func (p *pwaClient) DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
if bar == nil {
return nil, fmt.Errorf("progress bar cannot be nil")
}
manifestAbsURL, err := resolveURL(baseURL, manifestURL)
manifestAbsURL, err := p.resolveURL(baseURL, manifestURL)
if err != nil {
return nil, fmt.Errorf("could not resolve manifest URL: %w", err)
}
client := getHTTPClient()
resp, err := client.Get(manifestAbsURL.String())
resp, err := p.client.Get(manifestAbsURL.String())
if err != nil {
return nil, fmt.Errorf("could not download manifest: %w", err)
}
@ -139,30 +127,30 @@ func DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.
dn.AddData("manifest.json", manifestBody)
if manifest.StartURL != "" {
startURLAbs, err := resolveURL(manifestAbsURL.String(), manifest.StartURL)
startURLAbs, err := p.resolveURL(manifestAbsURL.String(), manifest.StartURL)
if err != nil {
return nil, fmt.Errorf("could not resolve start_url: %w", err)
}
err = downloadAndAddFile(dn, startURLAbs, manifest.StartURL, bar)
err = p.downloadAndAddFile(dn, startURLAbs, manifest.StartURL, bar)
if err != nil {
return nil, fmt.Errorf("failed to download start_url asset: %w", err)
}
}
for _, icon := range manifest.Icons {
iconURLAbs, err := resolveURL(manifestAbsURL.String(), icon.Src)
iconURLAbs, err := p.resolveURL(manifestAbsURL.String(), icon.Src)
if err != nil {
fmt.Printf("Warning: could not resolve icon URL %s: %v\n", icon.Src, err)
continue
}
err = downloadAndAddFile(dn, iconURLAbs, icon.Src, bar)
err = p.downloadAndAddFile(dn, iconURLAbs, icon.Src, bar)
if err != nil {
fmt.Printf("Warning: failed to download icon %s: %v\n", icon.Src, err)
}
}
baseURLAbs, _ := url.Parse(baseURL)
err = downloadAndAddFile(dn, baseURLAbs, "index.html", bar)
err = p.downloadAndAddFile(dn, baseURLAbs, "index.html", bar)
if err != nil {
return nil, fmt.Errorf("failed to download base HTML: %w", err)
}
@ -170,7 +158,7 @@ func DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.
return dn, nil
}
func resolveURL(base, ref string) (*url.URL, error) {
func (p *pwaClient) resolveURL(base, ref string) (*url.URL, error) {
baseURL, err := url.Parse(base)
if err != nil {
return nil, err
@ -182,9 +170,8 @@ func resolveURL(base, ref string) (*url.URL, error) {
return baseURL.ResolveReference(refURL), nil
}
func downloadAndAddFile(dn *datanode.DataNode, fileURL *url.URL, internalPath string, bar *progressbar.ProgressBar) error {
client := getHTTPClient()
resp, err := client.Get(fileURL.String())
func (p *pwaClient) downloadAndAddFile(dn *datanode.DataNode, fileURL *url.URL, internalPath string, bar *progressbar.ProgressBar) error {
resp, err := p.client.Get(fileURL.String())
if err != nil {
return err
}

View file

@ -1,13 +1,34 @@
package pwa
import (
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/schollz/progressbar/v3"
)
func newTestPWAClient(serverURL string) PWAClient {
return &pwaClient{
client: &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
},
}
}
func TestFindManifest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
@ -26,8 +47,9 @@ func TestFindManifest(t *testing.T) {
}))
defer server.Close()
client := newTestPWAClient(server.URL)
expectedURL := server.URL + "/manifest.json"
actualURL, err := FindManifest(server.URL)
actualURL, err := client.FindManifest(server.URL)
if err != nil {
t.Fatalf("FindManifest failed: %v", err)
}
@ -80,8 +102,9 @@ func TestDownloadAndPackagePWA(t *testing.T) {
}))
defer server.Close()
client := newTestPWAClient(server.URL)
bar := progressbar.New(1)
dn, err := DownloadAndPackagePWA(server.URL, server.URL+"/manifest.json", bar)
dn, err := client.DownloadAndPackagePWA(server.URL, server.URL+"/manifest.json", bar)
if err != nil {
t.Fatalf("DownloadAndPackagePWA failed: %v", err)
}
@ -99,6 +122,7 @@ func TestDownloadAndPackagePWA(t *testing.T) {
}
func TestResolveURL(t *testing.T) {
client := NewPWAClient()
tests := []struct {
base string
ref string
@ -113,7 +137,7 @@ func TestResolveURL(t *testing.T) {
}
for _, tt := range tests {
got, err := resolveURL(tt.base, tt.ref)
got, err := client.(*pwaClient).resolveURL(tt.base, tt.ref)
if err != nil {
t.Errorf("resolveURL(%q, %q) returned error: %v", tt.base, tt.ref, err)
continue

View file

@ -10,13 +10,20 @@ import (
"github.com/go-git/go-git/v5"
)
// GitCloner is an interface for cloning Git repositories.
type GitCloner interface {
CloneGitRepository(repoURL string, progress io.Writer) (*datanode.DataNode, error)
}
// NewGitCloner creates a new GitCloner.
func NewGitCloner() GitCloner {
return &gitCloner{}
}
type gitCloner struct{}
// CloneGitRepository clones a Git repository from a URL and packages it into a DataNode.
func CloneGitRepository(repoURL string, progress io.Writer) (*datanode.DataNode, error) {
if os.Getenv("BORG_PLEXSUS") == "0" {
dn := datanode.New()
dn.AddData("README.md", []byte("Mock README"))
return dn, nil
}
func (g *gitCloner) CloneGitRepository(repoURL string, progress io.Writer) (*datanode.DataNode, error) {
tempPath, err := os.MkdirTemp("", "borg-clone-*")
if err != nil {
return nil, err

View file

@ -69,7 +69,8 @@ func TestCloneGitRepository(t *testing.T) {
}
// Clone the repository using the function we're testing
dn, err := CloneGitRepository("file://"+bareRepoPath, os.Stdout)
cloner := NewGitCloner()
dn, err := cloner.CloneGitRepository("file://"+bareRepoPath, os.Stdout)
if err != nil {
t.Fatalf("CloneGitRepository failed: %v", err)
}