diff --git a/executor.go b/executor.go index d88f9b2..985171d 100644 --- a/executor.go +++ b/executor.go @@ -11,7 +11,8 @@ import ( "text/template" "time" - "forge.lthn.ai/core/go-log" + coreio "forge.lthn.ai/core/go-io" + coreerr "forge.lthn.ai/core/go-log" ) // Executor runs Ansible playbooks. @@ -80,12 +81,12 @@ func (e *Executor) SetVar(key string, value any) { func (e *Executor) Run(ctx context.Context, playbookPath string) error { plays, err := e.parser.ParsePlaybook(playbookPath) if err != nil { - return fmt.Errorf("parse playbook: %w", err) + return coreerr.E("Executor.Run", "parse playbook", err) } for i := range plays { if err := e.runPlay(ctx, &plays[i]); err != nil { - return fmt.Errorf("play %d (%s): %w", i, plays[i].Name, err) + return coreerr.E("Executor.Run", fmt.Sprintf("play %d (%s)", i, plays[i].Name), err) } } @@ -121,7 +122,7 @@ func (e *Executor) runPlay(ctx context.Context, play *Play) error { if err := e.gatherFacts(ctx, host, play); err != nil { // Non-fatal if e.Verbose > 0 { - log.Warn("gather facts failed", "host", host, "err", err) + coreerr.Warn("gather facts failed", "host", host, "err", err) } } } @@ -179,7 +180,7 @@ func (e *Executor) runRole(ctx context.Context, hosts []string, roleRef *RoleRef // Parse role tasks tasks, err := e.parser.ParseRole(roleRef.Role, roleRef.TasksFrom) if err != nil { - return log.E("executor.runRole", fmt.Sprintf("parse role %s", roleRef.Role), err) + return coreerr.E("executor.runRole", fmt.Sprintf("parse role %s", roleRef.Role), err) } // Merge role vars @@ -266,7 +267,7 @@ func (e *Executor) runTaskOnHost(ctx context.Context, host string, task *Task, p // Get SSH client client, err := e.getClient(host, play) if err != nil { - return fmt.Errorf("get client for %s: %w", host, err) + return coreerr.E("Executor.runTaskOnHost", fmt.Sprintf("get client for %s", host), err) } // Handle loops @@ -296,7 +297,7 @@ func (e *Executor) runTaskOnHost(ctx context.Context, host string, task *Task, p } if result.Failed && !task.IgnoreErrors { - return fmt.Errorf("task failed: %s", result.Msg) + return coreerr.E("Executor.runTaskOnHost", "task failed: "+result.Msg, nil) } return nil @@ -427,7 +428,7 @@ func (e *Executor) runIncludeTasks(ctx context.Context, hosts []string, task *Ta tasks, err := e.parser.ParseTasks(path) if err != nil { - return fmt.Errorf("include_tasks %s: %w", path, err) + return coreerr.E("Executor.runIncludeTasks", "include_tasks "+path, err) } for _, t := range tasks { @@ -881,8 +882,8 @@ func (e *Executor) handleLookup(expr string) string { case "env": return os.Getenv(arg) case "file": - if data, err := os.ReadFile(arg); err == nil { - return string(data) + if data, err := coreio.Local.Read(arg); err == nil { + return data } } @@ -970,13 +971,13 @@ func (e *Executor) Close() { // TemplateFile processes a template file. func (e *Executor) TemplateFile(src, host string, task *Task) (string, error) { - content, err := os.ReadFile(src) + content, err := coreio.Local.Read(src) if err != nil { return "", err } // Convert Jinja2 to Go template syntax (basic conversion) - tmplContent := string(content) + tmplContent := content tmplContent = strings.ReplaceAll(tmplContent, "{{", "{{ .") tmplContent = strings.ReplaceAll(tmplContent, "{%", "{{") tmplContent = strings.ReplaceAll(tmplContent, "%}", "}}") @@ -984,7 +985,7 @@ func (e *Executor) TemplateFile(src, host string, task *Task) (string, error) { tmpl, err := template.New("template").Parse(tmplContent) if err != nil { // Fall back to simple replacement - return e.templateString(string(content), host, task), nil + return e.templateString(content, host, task), nil } // Build context map @@ -1011,7 +1012,7 @@ func (e *Executor) TemplateFile(src, host string, task *Task) (string, error) { var buf strings.Builder if err := tmpl.Execute(&buf, context); err != nil { - return e.templateString(string(content), host, task), nil + return e.templateString(content, host, task), nil } return buf.String(), nil diff --git a/go.mod b/go.mod index 7f5b43e..94fae76 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.26.0 require ( forge.lthn.ai/core/cli v0.3.1 + forge.lthn.ai/core/go-io v0.1.2 forge.lthn.ai/core/go-log v0.0.4 github.com/stretchr/testify v1.11.1 golang.org/x/crypto v0.49.0 @@ -15,7 +16,6 @@ require ( forge.lthn.ai/core/go-crypt v0.1.7 // indirect forge.lthn.ai/core/go-i18n v0.1.4 // indirect forge.lthn.ai/core/go-inference v0.1.4 // indirect - forge.lthn.ai/core/go-io v0.1.2 // indirect forge.lthn.ai/core/go-process v0.2.2 // indirect github.com/ProtonMail/go-crypto v1.4.0 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect diff --git a/modules.go b/modules.go index 9a3c722..da261a1 100644 --- a/modules.go +++ b/modules.go @@ -3,12 +3,14 @@ package ansible import ( "context" "encoding/base64" - "errors" "fmt" "os" "path/filepath" "strconv" "strings" + + coreio "forge.lthn.ai/core/go-io" + coreerr "forge.lthn.ai/core/go-log" ) // executeModule dispatches to the appropriate module handler. @@ -136,7 +138,7 @@ func (e *Executor) executeModule(ctx context.Context, host string, client *SSHCl if strings.Contains(task.Module, " ") || task.Module == "" { return e.moduleShell(ctx, client, args) } - return nil, fmt.Errorf("unsupported module: %s", module) + return nil, coreerr.E("Executor.executeModule", "unsupported module: "+module, nil) } } @@ -179,7 +181,7 @@ func (e *Executor) moduleShell(ctx context.Context, client *SSHClient, args map[ cmd = getStringArg(args, "cmd", "") } if cmd == "" { - return nil, errors.New("shell: no command specified") + return nil, coreerr.E("Executor.moduleShell", "no command specified", nil) } // Handle chdir @@ -207,7 +209,7 @@ func (e *Executor) moduleCommand(ctx context.Context, client *SSHClient, args ma cmd = getStringArg(args, "cmd", "") } if cmd == "" { - return nil, errors.New("command: no command specified") + return nil, coreerr.E("Executor.moduleCommand", "no command specified", nil) } // Handle chdir @@ -232,7 +234,7 @@ func (e *Executor) moduleCommand(ctx context.Context, client *SSHClient, args ma func (e *Executor) moduleRaw(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { cmd := getStringArg(args, "_raw_params", "") if cmd == "" { - return nil, errors.New("raw: no command specified") + return nil, coreerr.E("Executor.moduleRaw", "no command specified", nil) } stdout, stderr, rc, err := client.Run(ctx, cmd) @@ -251,16 +253,16 @@ func (e *Executor) moduleRaw(ctx context.Context, client *SSHClient, args map[st func (e *Executor) moduleScript(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { script := getStringArg(args, "_raw_params", "") if script == "" { - return nil, errors.New("script: no script specified") + return nil, coreerr.E("Executor.moduleScript", "no script specified", nil) } // Read local script - content, err := os.ReadFile(script) + data, err := coreio.Local.Read(script) if err != nil { - return nil, fmt.Errorf("read script: %w", err) + return nil, coreerr.E("Executor.moduleScript", "read script", err) } - stdout, stderr, rc, err := client.RunScript(ctx, string(content)) + stdout, stderr, rc, err := client.RunScript(ctx, data) if err != nil { return &TaskResult{Failed: true, Msg: err.Error()}, nil } @@ -279,21 +281,21 @@ func (e *Executor) moduleScript(ctx context.Context, client *SSHClient, args map func (e *Executor) moduleCopy(ctx context.Context, client *SSHClient, args map[string]any, host string, task *Task) (*TaskResult, error) { dest := getStringArg(args, "dest", "") if dest == "" { - return nil, errors.New("copy: dest required") + return nil, coreerr.E("Executor.moduleCopy", "dest required", nil) } - var content []byte + var content string var err error if src := getStringArg(args, "src", ""); src != "" { - content, err = os.ReadFile(src) + content, err = coreio.Local.Read(src) if err != nil { - return nil, fmt.Errorf("read src: %w", err) + return nil, coreerr.E("Executor.moduleCopy", "read src", err) } } else if c := getStringArg(args, "content", ""); c != "" { - content = []byte(c) + content = c } else { - return nil, errors.New("copy: src or content required") + return nil, coreerr.E("Executor.moduleCopy", "src or content required", nil) } mode := os.FileMode(0644) @@ -303,7 +305,7 @@ func (e *Executor) moduleCopy(ctx context.Context, client *SSHClient, args map[s } } - err = client.Upload(ctx, strings.NewReader(string(content)), dest, mode) + err = client.Upload(ctx, strings.NewReader(content), dest, mode) if err != nil { return nil, err } @@ -323,13 +325,13 @@ func (e *Executor) moduleTemplate(ctx context.Context, client *SSHClient, args m src := getStringArg(args, "src", "") dest := getStringArg(args, "dest", "") if src == "" || dest == "" { - return nil, errors.New("template: src and dest required") + return nil, coreerr.E("Executor.moduleTemplate", "src and dest required", nil) } // Process template content, err := e.TemplateFile(src, host, task) if err != nil { - return nil, fmt.Errorf("template: %w", err) + return nil, coreerr.E("Executor.moduleTemplate", "template", err) } mode := os.FileMode(0644) @@ -353,7 +355,7 @@ func (e *Executor) moduleFile(ctx context.Context, client *SSHClient, args map[s path = getStringArg(args, "dest", "") } if path == "" { - return nil, errors.New("file: path required") + return nil, coreerr.E("Executor.moduleFile", "path required", nil) } state := getStringArg(args, "state", "file") @@ -384,7 +386,7 @@ func (e *Executor) moduleFile(ctx context.Context, client *SSHClient, args map[s case "link": src := getStringArg(args, "src", "") if src == "" { - return nil, errors.New("file: src required for link state") + return nil, coreerr.E("Executor.moduleFile", "src required for link state", nil) } cmd := fmt.Sprintf("ln -sf %q %q", src, path) _, stderr, rc, err := client.Run(ctx, cmd) @@ -421,7 +423,7 @@ func (e *Executor) moduleLineinfile(ctx context.Context, client *SSHClient, args path = getStringArg(args, "dest", "") } if path == "" { - return nil, errors.New("lineinfile: path required") + return nil, coreerr.E("Executor.moduleLineinfile", "path required", nil) } line := getStringArg(args, "line", "") @@ -461,7 +463,7 @@ func (e *Executor) moduleLineinfile(ctx context.Context, client *SSHClient, args func (e *Executor) moduleStat(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { path := getStringArg(args, "path", "") if path == "" { - return nil, errors.New("stat: path required") + return nil, coreerr.E("Executor.moduleStat", "path required", nil) } stat, err := client.Stat(ctx, path) @@ -481,7 +483,7 @@ func (e *Executor) moduleSlurp(ctx context.Context, client *SSHClient, args map[ path = getStringArg(args, "src", "") } if path == "" { - return nil, errors.New("slurp: path required") + return nil, coreerr.E("Executor.moduleSlurp", "path required", nil) } content, err := client.Download(ctx, path) @@ -501,7 +503,7 @@ func (e *Executor) moduleFetch(ctx context.Context, client *SSHClient, args map[ src := getStringArg(args, "src", "") dest := getStringArg(args, "dest", "") if src == "" || dest == "" { - return nil, errors.New("fetch: src and dest required") + return nil, coreerr.E("Executor.moduleFetch", "src and dest required", nil) } content, err := client.Download(ctx, src) @@ -510,11 +512,11 @@ func (e *Executor) moduleFetch(ctx context.Context, client *SSHClient, args map[ } // Create dest directory - if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil { + if err := coreio.Local.EnsureDir(filepath.Dir(dest)); err != nil { return nil, err } - if err := os.WriteFile(dest, content, 0644); err != nil { + if err := coreio.Local.Write(dest, string(content)); err != nil { return nil, err } @@ -525,7 +527,7 @@ func (e *Executor) moduleGetURL(ctx context.Context, client *SSHClient, args map url := getStringArg(args, "url", "") dest := getStringArg(args, "dest", "") if url == "" || dest == "" { - return nil, errors.New("get_url: url and dest required") + return nil, coreerr.E("Executor.moduleGetURL", "url and dest required", nil) } // Use curl or wget @@ -592,7 +594,7 @@ func (e *Executor) moduleAptKey(ctx context.Context, client *SSHClient, args map } if url == "" { - return nil, errors.New("apt_key: url required") + return nil, coreerr.E("Executor.moduleAptKey", "url required", nil) } var cmd string @@ -616,7 +618,7 @@ func (e *Executor) moduleAptRepository(ctx context.Context, client *SSHClient, a state := getStringArg(args, "state", "present") if repo == "" { - return nil, errors.New("apt_repository: repo required") + return nil, coreerr.E("Executor.moduleAptRepository", "repo required", nil) } if filename == "" { @@ -691,7 +693,7 @@ func (e *Executor) moduleService(ctx context.Context, client *SSHClient, args ma enabled := args["enabled"] if name == "" { - return nil, errors.New("service: name required") + return nil, coreerr.E("Executor.moduleService", "name required", nil) } var cmds []string @@ -743,7 +745,7 @@ func (e *Executor) moduleUser(ctx context.Context, client *SSHClient, args map[s state := getStringArg(args, "state", "present") if name == "" { - return nil, errors.New("user: name required") + return nil, coreerr.E("Executor.moduleUser", "name required", nil) } if state == "absent" { @@ -800,7 +802,7 @@ func (e *Executor) moduleGroup(ctx context.Context, client *SSHClient, args map[ state := getStringArg(args, "state", "present") if name == "" { - return nil, errors.New("group: name required") + return nil, coreerr.E("Executor.moduleGroup", "name required", nil) } if state == "absent" { @@ -835,7 +837,7 @@ func (e *Executor) moduleURI(ctx context.Context, client *SSHClient, args map[st method := getStringArg(args, "method", "GET") if url == "" { - return nil, errors.New("uri: url required") + return nil, coreerr.E("Executor.moduleURI", "url required", nil) } var curlOpts []string @@ -913,7 +915,7 @@ func (e *Executor) moduleFail(args map[string]any) (*TaskResult, error) { func (e *Executor) moduleAssert(args map[string]any, host string) (*TaskResult, error) { that, ok := args["that"] if !ok { - return nil, errors.New("assert: 'that' required") + return nil, coreerr.E("Executor.moduleAssert", "'that' required", nil) } conditions := normalizeConditions(that) @@ -1014,7 +1016,7 @@ func (e *Executor) moduleGit(ctx context.Context, client *SSHClient, args map[st version := getStringArg(args, "version", "HEAD") if repo == "" || dest == "" { - return nil, errors.New("git: repo and dest required") + return nil, coreerr.E("Executor.moduleGit", "repo and dest required", nil) } // Check if dest exists @@ -1043,7 +1045,7 @@ func (e *Executor) moduleUnarchive(ctx context.Context, client *SSHClient, args remote := getBoolArg(args, "remote_src", false) if src == "" || dest == "" { - return nil, errors.New("unarchive: src and dest required") + return nil, coreerr.E("Executor.moduleUnarchive", "src and dest required", nil) } // Create dest directory (best-effort) @@ -1052,12 +1054,12 @@ func (e *Executor) moduleUnarchive(ctx context.Context, client *SSHClient, args var cmd string if !remote { // Upload local file first - content, err := os.ReadFile(src) + data, err := coreio.Local.Read(src) if err != nil { - return nil, fmt.Errorf("read src: %w", err) + return nil, coreerr.E("Executor.moduleUnarchive", "read src", err) } tmpPath := "/tmp/ansible_unarchive_" + filepath.Base(src) - err = client.Upload(ctx, strings.NewReader(string(content)), tmpPath, 0644) + err = client.Upload(ctx, strings.NewReader(data), tmpPath, 0644) if err != nil { return nil, err } @@ -1118,7 +1120,7 @@ func getBoolArg(args map[string]any, key string, def bool) bool { func (e *Executor) moduleHostname(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) { name := getStringArg(args, "name", "") if name == "" { - return nil, errors.New("hostname: name required") + return nil, coreerr.E("Executor.moduleHostname", "name required", nil) } // Set hostname @@ -1140,7 +1142,7 @@ func (e *Executor) moduleSysctl(ctx context.Context, client *SSHClient, args map state := getStringArg(args, "state", "present") if name == "" { - return nil, errors.New("sysctl: name required") + return nil, coreerr.E("Executor.moduleSysctl", "name required", nil) } if state == "absent" { @@ -1210,7 +1212,7 @@ func (e *Executor) moduleBlockinfile(ctx context.Context, client *SSHClient, arg path = getStringArg(args, "dest", "") } if path == "" { - return nil, errors.New("blockinfile: path required") + return nil, coreerr.E("Executor.moduleBlockinfile", "path required", nil) } block := getStringArg(args, "block", "") @@ -1358,13 +1360,13 @@ func (e *Executor) moduleAuthorizedKey(ctx context.Context, client *SSHClient, a state := getStringArg(args, "state", "present") if user == "" || key == "" { - return nil, errors.New("authorized_key: user and key required") + return nil, coreerr.E("Executor.moduleAuthorizedKey", "user and key required", nil) } // Get user's home directory stdout, _, _, err := client.Run(ctx, fmt.Sprintf("getent passwd %s | cut -d: -f6", user)) if err != nil { - return nil, fmt.Errorf("get home dir: %w", err) + return nil, coreerr.E("Executor.moduleAuthorizedKey", "get home dir", err) } home := strings.TrimSpace(stdout) if home == "" { @@ -1408,7 +1410,7 @@ func (e *Executor) moduleDockerCompose(ctx context.Context, client *SSHClient, a state := getStringArg(args, "state", "present") if projectSrc == "" { - return nil, errors.New("docker_compose: project_src required") + return nil, coreerr.E("Executor.moduleDockerCompose", "project_src required", nil) } var cmd string diff --git a/parser.go b/parser.go index baecf3f..9f9d21a 100644 --- a/parser.go +++ b/parser.go @@ -4,12 +4,12 @@ import ( "fmt" "iter" "maps" - "os" "path/filepath" "slices" "strings" - "forge.lthn.ai/core/go-log" + coreio "forge.lthn.ai/core/go-io" + coreerr "forge.lthn.ai/core/go-log" "gopkg.in/yaml.v3" ) @@ -29,20 +29,20 @@ func NewParser(basePath string) *Parser { // ParsePlaybook parses an Ansible playbook file. func (p *Parser) ParsePlaybook(path string) ([]Play, error) { - data, err := os.ReadFile(path) + data, err := coreio.Local.Read(path) if err != nil { - return nil, fmt.Errorf("read playbook: %w", err) + return nil, coreerr.E("Parser.ParsePlaybook", "read playbook", err) } var plays []Play - if err := yaml.Unmarshal(data, &plays); err != nil { - return nil, fmt.Errorf("parse playbook: %w", err) + if err := yaml.Unmarshal([]byte(data), &plays); err != nil { + return nil, coreerr.E("Parser.ParsePlaybook", "parse playbook", err) } // Process each play for i := range plays { if err := p.processPlay(&plays[i]); err != nil { - return nil, fmt.Errorf("process play %d: %w", i, err) + return nil, coreerr.E("Parser.ParsePlaybook", fmt.Sprintf("process play %d", i), err) } } @@ -66,14 +66,14 @@ func (p *Parser) ParsePlaybookIter(path string) (iter.Seq[Play], error) { // ParseInventory parses an Ansible inventory file. func (p *Parser) ParseInventory(path string) (*Inventory, error) { - data, err := os.ReadFile(path) + data, err := coreio.Local.Read(path) if err != nil { - return nil, fmt.Errorf("read inventory: %w", err) + return nil, coreerr.E("Parser.ParseInventory", "read inventory", err) } var inv Inventory - if err := yaml.Unmarshal(data, &inv); err != nil { - return nil, fmt.Errorf("parse inventory: %w", err) + if err := yaml.Unmarshal([]byte(data), &inv); err != nil { + return nil, coreerr.E("Parser.ParseInventory", "parse inventory", err) } return &inv, nil @@ -81,19 +81,19 @@ func (p *Parser) ParseInventory(path string) (*Inventory, error) { // ParseTasks parses a tasks file (used by include_tasks). func (p *Parser) ParseTasks(path string) ([]Task, error) { - data, err := os.ReadFile(path) + data, err := coreio.Local.Read(path) if err != nil { - return nil, fmt.Errorf("read tasks: %w", err) + return nil, coreerr.E("Parser.ParseTasks", "read tasks", err) } var tasks []Task - if err := yaml.Unmarshal(data, &tasks); err != nil { - return nil, fmt.Errorf("parse tasks: %w", err) + if err := yaml.Unmarshal([]byte(data), &tasks); err != nil { + return nil, coreerr.E("Parser.ParseTasks", "parse tasks", err) } for i := range tasks { if err := p.extractModule(&tasks[i]); err != nil { - return nil, fmt.Errorf("task %d: %w", i, err) + return nil, coreerr.E("Parser.ParseTasks", fmt.Sprintf("task %d", i), err) } } @@ -139,21 +139,21 @@ func (p *Parser) ParseRole(name string, tasksFrom string) ([]Task, error) { for _, sp := range searchPaths { // Clean the path to resolve .. segments sp = filepath.Clean(sp) - if _, err := os.Stat(sp); err == nil { + if coreio.Local.Exists(sp) { tasksPath = sp break } } if tasksPath == "" { - return nil, log.E("parser.ParseRole", fmt.Sprintf("role %s not found in search paths: %v", name, searchPaths), nil) + return nil, coreerr.E("Parser.ParseRole", fmt.Sprintf("role %s not found in search paths: %v", name, searchPaths), nil) } // Load role defaults defaultsPath := filepath.Join(filepath.Dir(filepath.Dir(tasksPath)), "defaults", "main.yml") - if data, err := os.ReadFile(defaultsPath); err == nil { + if data, err := coreio.Local.Read(defaultsPath); err == nil { var defaults map[string]any - if yaml.Unmarshal(data, &defaults) == nil { + if yaml.Unmarshal([]byte(data), &defaults) == nil { for k, v := range defaults { if _, exists := p.vars[k]; !exists { p.vars[k] = v @@ -164,9 +164,9 @@ func (p *Parser) ParseRole(name string, tasksFrom string) ([]Task, error) { // Load role vars varsPath := filepath.Join(filepath.Dir(filepath.Dir(tasksPath)), "vars", "main.yml") - if data, err := os.ReadFile(varsPath); err == nil { + if data, err := coreio.Local.Read(varsPath); err == nil { var roleVars map[string]any - if yaml.Unmarshal(data, &roleVars) == nil { + if yaml.Unmarshal([]byte(data), &roleVars) == nil { for k, v := range roleVars { p.vars[k] = v } @@ -185,25 +185,25 @@ func (p *Parser) processPlay(play *Play) error { for i := range play.PreTasks { if err := p.extractModule(&play.PreTasks[i]); err != nil { - return fmt.Errorf("pre_task %d: %w", i, err) + return coreerr.E("Parser.processPlay", fmt.Sprintf("pre_task %d", i), err) } } for i := range play.Tasks { if err := p.extractModule(&play.Tasks[i]); err != nil { - return fmt.Errorf("task %d: %w", i, err) + return coreerr.E("Parser.processPlay", fmt.Sprintf("task %d", i), err) } } for i := range play.PostTasks { if err := p.extractModule(&play.PostTasks[i]); err != nil { - return fmt.Errorf("post_task %d: %w", i, err) + return coreerr.E("Parser.processPlay", fmt.Sprintf("post_task %d", i), err) } } for i := range play.Handlers { if err := p.extractModule(&play.Handlers[i]); err != nil { - return fmt.Errorf("handler %d: %w", i, err) + return coreerr.E("Parser.processPlay", fmt.Sprintf("handler %d", i), err) } } diff --git a/ssh.go b/ssh.go index f12d38e..e296b93 100644 --- a/ssh.go +++ b/ssh.go @@ -12,7 +12,8 @@ import ( "sync" "time" - "forge.lthn.ai/core/go-log" + coreio "forge.lthn.ai/core/go-io" + coreerr "forge.lthn.ai/core/go-log" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/knownhosts" ) @@ -91,8 +92,8 @@ func (c *SSHClient) Connect(ctx context.Context) error { keyPath = filepath.Join(home, keyPath[1:]) } - if key, err := os.ReadFile(keyPath); err == nil { - if signer, err := ssh.ParsePrivateKey(key); err == nil { + if key, err := coreio.Local.Read(keyPath); err == nil { + if signer, err := ssh.ParsePrivateKey([]byte(key)); err == nil { authMethods = append(authMethods, ssh.PublicKeys(signer)) } } @@ -106,8 +107,8 @@ func (c *SSHClient) Connect(ctx context.Context) error { filepath.Join(home, ".ssh", "id_rsa"), } for _, keyPath := range defaultKeys { - if key, err := os.ReadFile(keyPath); err == nil { - if signer, err := ssh.ParsePrivateKey(key); err == nil { + if key, err := coreio.Local.Read(keyPath); err == nil { + if signer, err := ssh.ParsePrivateKey([]byte(key)); err == nil { authMethods = append(authMethods, ssh.PublicKeys(signer)) break } @@ -128,7 +129,7 @@ func (c *SSHClient) Connect(ctx context.Context) error { } if len(authMethods) == 0 { - return log.E("ssh.Connect", "no authentication method available", nil) + return coreerr.E("ssh.Connect", "no authentication method available", nil) } // Host key verification @@ -136,23 +137,23 @@ func (c *SSHClient) Connect(ctx context.Context) error { home, err := os.UserHomeDir() if err != nil { - return log.E("ssh.Connect", "failed to get user home dir", err) + return coreerr.E("ssh.Connect", "failed to get user home dir", err) } knownHostsPath := filepath.Join(home, ".ssh", "known_hosts") // Ensure known_hosts file exists - if _, err := os.Stat(knownHostsPath); os.IsNotExist(err) { - if err := os.MkdirAll(filepath.Dir(knownHostsPath), 0700); err != nil { - return log.E("ssh.Connect", "failed to create .ssh dir", err) + if !coreio.Local.Exists(knownHostsPath) { + if err := coreio.Local.EnsureDir(filepath.Dir(knownHostsPath)); err != nil { + return coreerr.E("ssh.Connect", "failed to create .ssh dir", err) } - if err := os.WriteFile(knownHostsPath, nil, 0600); err != nil { - return log.E("ssh.Connect", "failed to create known_hosts file", err) + if err := coreio.Local.Write(knownHostsPath, ""); err != nil { + return coreerr.E("ssh.Connect", "failed to create known_hosts file", err) } } cb, err := knownhosts.New(knownHostsPath) if err != nil { - return log.E("ssh.Connect", "failed to load known_hosts", err) + return coreerr.E("ssh.Connect", "failed to load known_hosts", err) } hostKeyCallback = cb @@ -169,13 +170,13 @@ func (c *SSHClient) Connect(ctx context.Context) error { var d net.Dialer conn, err := d.DialContext(ctx, "tcp", addr) if err != nil { - return log.E("ssh.Connect", fmt.Sprintf("dial %s", addr), err) + return coreerr.E("ssh.Connect", fmt.Sprintf("dial %s", addr), err) } sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config) if err != nil { // conn is closed by NewClientConn on error - return log.E("ssh.Connect", fmt.Sprintf("ssh connect %s", addr), err) + return coreerr.E("ssh.Connect", fmt.Sprintf("ssh connect %s", addr), err) } c.client = ssh.NewClient(sshConn, chans, reqs) @@ -203,7 +204,7 @@ func (c *SSHClient) Run(ctx context.Context, cmd string) (stdout, stderr string, session, err := c.client.NewSession() if err != nil { - return "", "", -1, log.E("ssh.Run", "new session", err) + return "", "", -1, coreerr.E("ssh.Run", "new session", err) } defer func() { _ = session.Close() }() @@ -225,7 +226,7 @@ func (c *SSHClient) Run(ctx context.Context, cmd string) (stdout, stderr string, cmd = fmt.Sprintf("sudo -S -u %s bash -c '%s'", becomeUser, escapedCmd) stdin, err := session.StdinPipe() if err != nil { - return "", "", -1, log.E("ssh.Run", "stdin pipe", err) + return "", "", -1, coreerr.E("ssh.Run", "stdin pipe", err) } go func() { defer func() { _ = stdin.Close() }() @@ -236,7 +237,7 @@ func (c *SSHClient) Run(ctx context.Context, cmd string) (stdout, stderr string, cmd = fmt.Sprintf("sudo -S -u %s bash -c '%s'", becomeUser, escapedCmd) stdin, err := session.StdinPipe() if err != nil { - return "", "", -1, log.E("ssh.Run", "stdin pipe", err) + return "", "", -1, coreerr.E("ssh.Run", "stdin pipe", err) } go func() { defer func() { _ = stdin.Close() }() @@ -287,7 +288,7 @@ func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string, // Read content content, err := io.ReadAll(local) if err != nil { - return log.E("ssh.Upload", "read content", err) + return coreerr.E("ssh.Upload", "read content", err) } // Create parent directory @@ -297,7 +298,7 @@ func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string, dirCmd = fmt.Sprintf("sudo mkdir -p %q", dir) } if _, _, _, err := c.Run(ctx, dirCmd); err != nil { - return log.E("ssh.Upload", "create parent dir", err) + return coreerr.E("ssh.Upload", "create parent dir", err) } // Use cat to write the file (simpler than SCP) @@ -309,13 +310,13 @@ func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string, session2, err := c.client.NewSession() if err != nil { - return log.E("ssh.Upload", "new session for write", err) + return coreerr.E("ssh.Upload", "new session for write", err) } defer func() { _ = session2.Close() }() stdin, err := session2.StdinPipe() if err != nil { - return log.E("ssh.Upload", "stdin pipe", err) + return coreerr.E("ssh.Upload", "stdin pipe", err) } var stderrBuf bytes.Buffer @@ -343,7 +344,7 @@ func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string, } if err := session2.Start(writeCmd); err != nil { - return log.E("ssh.Upload", "start write", err) + return coreerr.E("ssh.Upload", "start write", err) } go func() { @@ -356,7 +357,7 @@ func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string, } else { // Normal write if err := session2.Start(writeCmd); err != nil { - return log.E("ssh.Upload", "start write", err) + return coreerr.E("ssh.Upload", "start write", err) } go func() { @@ -366,7 +367,7 @@ func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string, } if err := session2.Wait(); err != nil { - return log.E("ssh.Upload", fmt.Sprintf("write failed (stderr: %s)", stderrBuf.String()), err) + return coreerr.E("ssh.Upload", fmt.Sprintf("write failed (stderr: %s)", stderrBuf.String()), err) } return nil @@ -385,7 +386,7 @@ func (c *SSHClient) Download(ctx context.Context, remote string) ([]byte, error) return nil, err } if exitCode != 0 { - return nil, log.E("ssh.Download", fmt.Sprintf("cat failed: %s", stderr), nil) + return nil, coreerr.E("ssh.Download", fmt.Sprintf("cat failed: %s", stderr), nil) } return []byte(stdout), nil