refactor: replace os.* and fmt.Errorf with go-io/go-log conventions

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-03-16 19:50:03 +00:00
parent ce48e74340
commit 2d51fb4d0e
5 changed files with 116 additions and 112 deletions

View file

@ -11,7 +11,8 @@ import (
"text/template"
"time"
"forge.lthn.ai/core/go-log"
coreio "forge.lthn.ai/core/go-io"
coreerr "forge.lthn.ai/core/go-log"
)
// Executor runs Ansible playbooks.
@ -80,12 +81,12 @@ func (e *Executor) SetVar(key string, value any) {
func (e *Executor) Run(ctx context.Context, playbookPath string) error {
plays, err := e.parser.ParsePlaybook(playbookPath)
if err != nil {
return fmt.Errorf("parse playbook: %w", err)
return coreerr.E("Executor.Run", "parse playbook", err)
}
for i := range plays {
if err := e.runPlay(ctx, &plays[i]); err != nil {
return fmt.Errorf("play %d (%s): %w", i, plays[i].Name, err)
return coreerr.E("Executor.Run", fmt.Sprintf("play %d (%s)", i, plays[i].Name), err)
}
}
@ -121,7 +122,7 @@ func (e *Executor) runPlay(ctx context.Context, play *Play) error {
if err := e.gatherFacts(ctx, host, play); err != nil {
// Non-fatal
if e.Verbose > 0 {
log.Warn("gather facts failed", "host", host, "err", err)
coreerr.Warn("gather facts failed", "host", host, "err", err)
}
}
}
@ -179,7 +180,7 @@ func (e *Executor) runRole(ctx context.Context, hosts []string, roleRef *RoleRef
// Parse role tasks
tasks, err := e.parser.ParseRole(roleRef.Role, roleRef.TasksFrom)
if err != nil {
return log.E("executor.runRole", fmt.Sprintf("parse role %s", roleRef.Role), err)
return coreerr.E("executor.runRole", fmt.Sprintf("parse role %s", roleRef.Role), err)
}
// Merge role vars
@ -266,7 +267,7 @@ func (e *Executor) runTaskOnHost(ctx context.Context, host string, task *Task, p
// Get SSH client
client, err := e.getClient(host, play)
if err != nil {
return fmt.Errorf("get client for %s: %w", host, err)
return coreerr.E("Executor.runTaskOnHost", fmt.Sprintf("get client for %s", host), err)
}
// Handle loops
@ -296,7 +297,7 @@ func (e *Executor) runTaskOnHost(ctx context.Context, host string, task *Task, p
}
if result.Failed && !task.IgnoreErrors {
return fmt.Errorf("task failed: %s", result.Msg)
return coreerr.E("Executor.runTaskOnHost", "task failed: "+result.Msg, nil)
}
return nil
@ -427,7 +428,7 @@ func (e *Executor) runIncludeTasks(ctx context.Context, hosts []string, task *Ta
tasks, err := e.parser.ParseTasks(path)
if err != nil {
return fmt.Errorf("include_tasks %s: %w", path, err)
return coreerr.E("Executor.runIncludeTasks", "include_tasks "+path, err)
}
for _, t := range tasks {
@ -881,8 +882,8 @@ func (e *Executor) handleLookup(expr string) string {
case "env":
return os.Getenv(arg)
case "file":
if data, err := os.ReadFile(arg); err == nil {
return string(data)
if data, err := coreio.Local.Read(arg); err == nil {
return data
}
}
@ -970,13 +971,13 @@ func (e *Executor) Close() {
// TemplateFile processes a template file.
func (e *Executor) TemplateFile(src, host string, task *Task) (string, error) {
content, err := os.ReadFile(src)
content, err := coreio.Local.Read(src)
if err != nil {
return "", err
}
// Convert Jinja2 to Go template syntax (basic conversion)
tmplContent := string(content)
tmplContent := content
tmplContent = strings.ReplaceAll(tmplContent, "{{", "{{ .")
tmplContent = strings.ReplaceAll(tmplContent, "{%", "{{")
tmplContent = strings.ReplaceAll(tmplContent, "%}", "}}")
@ -984,7 +985,7 @@ func (e *Executor) TemplateFile(src, host string, task *Task) (string, error) {
tmpl, err := template.New("template").Parse(tmplContent)
if err != nil {
// Fall back to simple replacement
return e.templateString(string(content), host, task), nil
return e.templateString(content, host, task), nil
}
// Build context map
@ -1011,7 +1012,7 @@ func (e *Executor) TemplateFile(src, host string, task *Task) (string, error) {
var buf strings.Builder
if err := tmpl.Execute(&buf, context); err != nil {
return e.templateString(string(content), host, task), nil
return e.templateString(content, host, task), nil
}
return buf.String(), nil

2
go.mod
View file

@ -4,6 +4,7 @@ go 1.26.0
require (
forge.lthn.ai/core/cli v0.3.1
forge.lthn.ai/core/go-io v0.1.2
forge.lthn.ai/core/go-log v0.0.4
github.com/stretchr/testify v1.11.1
golang.org/x/crypto v0.49.0
@ -15,7 +16,6 @@ require (
forge.lthn.ai/core/go-crypt v0.1.7 // indirect
forge.lthn.ai/core/go-i18n v0.1.4 // indirect
forge.lthn.ai/core/go-inference v0.1.4 // indirect
forge.lthn.ai/core/go-io v0.1.2 // indirect
forge.lthn.ai/core/go-process v0.2.2 // indirect
github.com/ProtonMail/go-crypto v1.4.0 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect

View file

@ -3,12 +3,14 @@ package ansible
import (
"context"
"encoding/base64"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
coreio "forge.lthn.ai/core/go-io"
coreerr "forge.lthn.ai/core/go-log"
)
// executeModule dispatches to the appropriate module handler.
@ -136,7 +138,7 @@ func (e *Executor) executeModule(ctx context.Context, host string, client *SSHCl
if strings.Contains(task.Module, " ") || task.Module == "" {
return e.moduleShell(ctx, client, args)
}
return nil, fmt.Errorf("unsupported module: %s", module)
return nil, coreerr.E("Executor.executeModule", "unsupported module: "+module, nil)
}
}
@ -179,7 +181,7 @@ func (e *Executor) moduleShell(ctx context.Context, client *SSHClient, args map[
cmd = getStringArg(args, "cmd", "")
}
if cmd == "" {
return nil, errors.New("shell: no command specified")
return nil, coreerr.E("Executor.moduleShell", "no command specified", nil)
}
// Handle chdir
@ -207,7 +209,7 @@ func (e *Executor) moduleCommand(ctx context.Context, client *SSHClient, args ma
cmd = getStringArg(args, "cmd", "")
}
if cmd == "" {
return nil, errors.New("command: no command specified")
return nil, coreerr.E("Executor.moduleCommand", "no command specified", nil)
}
// Handle chdir
@ -232,7 +234,7 @@ func (e *Executor) moduleCommand(ctx context.Context, client *SSHClient, args ma
func (e *Executor) moduleRaw(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) {
cmd := getStringArg(args, "_raw_params", "")
if cmd == "" {
return nil, errors.New("raw: no command specified")
return nil, coreerr.E("Executor.moduleRaw", "no command specified", nil)
}
stdout, stderr, rc, err := client.Run(ctx, cmd)
@ -251,16 +253,16 @@ func (e *Executor) moduleRaw(ctx context.Context, client *SSHClient, args map[st
func (e *Executor) moduleScript(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) {
script := getStringArg(args, "_raw_params", "")
if script == "" {
return nil, errors.New("script: no script specified")
return nil, coreerr.E("Executor.moduleScript", "no script specified", nil)
}
// Read local script
content, err := os.ReadFile(script)
data, err := coreio.Local.Read(script)
if err != nil {
return nil, fmt.Errorf("read script: %w", err)
return nil, coreerr.E("Executor.moduleScript", "read script", err)
}
stdout, stderr, rc, err := client.RunScript(ctx, string(content))
stdout, stderr, rc, err := client.RunScript(ctx, data)
if err != nil {
return &TaskResult{Failed: true, Msg: err.Error()}, nil
}
@ -279,21 +281,21 @@ func (e *Executor) moduleScript(ctx context.Context, client *SSHClient, args map
func (e *Executor) moduleCopy(ctx context.Context, client *SSHClient, args map[string]any, host string, task *Task) (*TaskResult, error) {
dest := getStringArg(args, "dest", "")
if dest == "" {
return nil, errors.New("copy: dest required")
return nil, coreerr.E("Executor.moduleCopy", "dest required", nil)
}
var content []byte
var content string
var err error
if src := getStringArg(args, "src", ""); src != "" {
content, err = os.ReadFile(src)
content, err = coreio.Local.Read(src)
if err != nil {
return nil, fmt.Errorf("read src: %w", err)
return nil, coreerr.E("Executor.moduleCopy", "read src", err)
}
} else if c := getStringArg(args, "content", ""); c != "" {
content = []byte(c)
content = c
} else {
return nil, errors.New("copy: src or content required")
return nil, coreerr.E("Executor.moduleCopy", "src or content required", nil)
}
mode := os.FileMode(0644)
@ -303,7 +305,7 @@ func (e *Executor) moduleCopy(ctx context.Context, client *SSHClient, args map[s
}
}
err = client.Upload(ctx, strings.NewReader(string(content)), dest, mode)
err = client.Upload(ctx, strings.NewReader(content), dest, mode)
if err != nil {
return nil, err
}
@ -323,13 +325,13 @@ func (e *Executor) moduleTemplate(ctx context.Context, client *SSHClient, args m
src := getStringArg(args, "src", "")
dest := getStringArg(args, "dest", "")
if src == "" || dest == "" {
return nil, errors.New("template: src and dest required")
return nil, coreerr.E("Executor.moduleTemplate", "src and dest required", nil)
}
// Process template
content, err := e.TemplateFile(src, host, task)
if err != nil {
return nil, fmt.Errorf("template: %w", err)
return nil, coreerr.E("Executor.moduleTemplate", "template", err)
}
mode := os.FileMode(0644)
@ -353,7 +355,7 @@ func (e *Executor) moduleFile(ctx context.Context, client *SSHClient, args map[s
path = getStringArg(args, "dest", "")
}
if path == "" {
return nil, errors.New("file: path required")
return nil, coreerr.E("Executor.moduleFile", "path required", nil)
}
state := getStringArg(args, "state", "file")
@ -384,7 +386,7 @@ func (e *Executor) moduleFile(ctx context.Context, client *SSHClient, args map[s
case "link":
src := getStringArg(args, "src", "")
if src == "" {
return nil, errors.New("file: src required for link state")
return nil, coreerr.E("Executor.moduleFile", "src required for link state", nil)
}
cmd := fmt.Sprintf("ln -sf %q %q", src, path)
_, stderr, rc, err := client.Run(ctx, cmd)
@ -421,7 +423,7 @@ func (e *Executor) moduleLineinfile(ctx context.Context, client *SSHClient, args
path = getStringArg(args, "dest", "")
}
if path == "" {
return nil, errors.New("lineinfile: path required")
return nil, coreerr.E("Executor.moduleLineinfile", "path required", nil)
}
line := getStringArg(args, "line", "")
@ -461,7 +463,7 @@ func (e *Executor) moduleLineinfile(ctx context.Context, client *SSHClient, args
func (e *Executor) moduleStat(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) {
path := getStringArg(args, "path", "")
if path == "" {
return nil, errors.New("stat: path required")
return nil, coreerr.E("Executor.moduleStat", "path required", nil)
}
stat, err := client.Stat(ctx, path)
@ -481,7 +483,7 @@ func (e *Executor) moduleSlurp(ctx context.Context, client *SSHClient, args map[
path = getStringArg(args, "src", "")
}
if path == "" {
return nil, errors.New("slurp: path required")
return nil, coreerr.E("Executor.moduleSlurp", "path required", nil)
}
content, err := client.Download(ctx, path)
@ -501,7 +503,7 @@ func (e *Executor) moduleFetch(ctx context.Context, client *SSHClient, args map[
src := getStringArg(args, "src", "")
dest := getStringArg(args, "dest", "")
if src == "" || dest == "" {
return nil, errors.New("fetch: src and dest required")
return nil, coreerr.E("Executor.moduleFetch", "src and dest required", nil)
}
content, err := client.Download(ctx, src)
@ -510,11 +512,11 @@ func (e *Executor) moduleFetch(ctx context.Context, client *SSHClient, args map[
}
// Create dest directory
if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil {
if err := coreio.Local.EnsureDir(filepath.Dir(dest)); err != nil {
return nil, err
}
if err := os.WriteFile(dest, content, 0644); err != nil {
if err := coreio.Local.Write(dest, string(content)); err != nil {
return nil, err
}
@ -525,7 +527,7 @@ func (e *Executor) moduleGetURL(ctx context.Context, client *SSHClient, args map
url := getStringArg(args, "url", "")
dest := getStringArg(args, "dest", "")
if url == "" || dest == "" {
return nil, errors.New("get_url: url and dest required")
return nil, coreerr.E("Executor.moduleGetURL", "url and dest required", nil)
}
// Use curl or wget
@ -592,7 +594,7 @@ func (e *Executor) moduleAptKey(ctx context.Context, client *SSHClient, args map
}
if url == "" {
return nil, errors.New("apt_key: url required")
return nil, coreerr.E("Executor.moduleAptKey", "url required", nil)
}
var cmd string
@ -616,7 +618,7 @@ func (e *Executor) moduleAptRepository(ctx context.Context, client *SSHClient, a
state := getStringArg(args, "state", "present")
if repo == "" {
return nil, errors.New("apt_repository: repo required")
return nil, coreerr.E("Executor.moduleAptRepository", "repo required", nil)
}
if filename == "" {
@ -691,7 +693,7 @@ func (e *Executor) moduleService(ctx context.Context, client *SSHClient, args ma
enabled := args["enabled"]
if name == "" {
return nil, errors.New("service: name required")
return nil, coreerr.E("Executor.moduleService", "name required", nil)
}
var cmds []string
@ -743,7 +745,7 @@ func (e *Executor) moduleUser(ctx context.Context, client *SSHClient, args map[s
state := getStringArg(args, "state", "present")
if name == "" {
return nil, errors.New("user: name required")
return nil, coreerr.E("Executor.moduleUser", "name required", nil)
}
if state == "absent" {
@ -800,7 +802,7 @@ func (e *Executor) moduleGroup(ctx context.Context, client *SSHClient, args map[
state := getStringArg(args, "state", "present")
if name == "" {
return nil, errors.New("group: name required")
return nil, coreerr.E("Executor.moduleGroup", "name required", nil)
}
if state == "absent" {
@ -835,7 +837,7 @@ func (e *Executor) moduleURI(ctx context.Context, client *SSHClient, args map[st
method := getStringArg(args, "method", "GET")
if url == "" {
return nil, errors.New("uri: url required")
return nil, coreerr.E("Executor.moduleURI", "url required", nil)
}
var curlOpts []string
@ -913,7 +915,7 @@ func (e *Executor) moduleFail(args map[string]any) (*TaskResult, error) {
func (e *Executor) moduleAssert(args map[string]any, host string) (*TaskResult, error) {
that, ok := args["that"]
if !ok {
return nil, errors.New("assert: 'that' required")
return nil, coreerr.E("Executor.moduleAssert", "'that' required", nil)
}
conditions := normalizeConditions(that)
@ -1014,7 +1016,7 @@ func (e *Executor) moduleGit(ctx context.Context, client *SSHClient, args map[st
version := getStringArg(args, "version", "HEAD")
if repo == "" || dest == "" {
return nil, errors.New("git: repo and dest required")
return nil, coreerr.E("Executor.moduleGit", "repo and dest required", nil)
}
// Check if dest exists
@ -1043,7 +1045,7 @@ func (e *Executor) moduleUnarchive(ctx context.Context, client *SSHClient, args
remote := getBoolArg(args, "remote_src", false)
if src == "" || dest == "" {
return nil, errors.New("unarchive: src and dest required")
return nil, coreerr.E("Executor.moduleUnarchive", "src and dest required", nil)
}
// Create dest directory (best-effort)
@ -1052,12 +1054,12 @@ func (e *Executor) moduleUnarchive(ctx context.Context, client *SSHClient, args
var cmd string
if !remote {
// Upload local file first
content, err := os.ReadFile(src)
data, err := coreio.Local.Read(src)
if err != nil {
return nil, fmt.Errorf("read src: %w", err)
return nil, coreerr.E("Executor.moduleUnarchive", "read src", err)
}
tmpPath := "/tmp/ansible_unarchive_" + filepath.Base(src)
err = client.Upload(ctx, strings.NewReader(string(content)), tmpPath, 0644)
err = client.Upload(ctx, strings.NewReader(data), tmpPath, 0644)
if err != nil {
return nil, err
}
@ -1118,7 +1120,7 @@ func getBoolArg(args map[string]any, key string, def bool) bool {
func (e *Executor) moduleHostname(ctx context.Context, client *SSHClient, args map[string]any) (*TaskResult, error) {
name := getStringArg(args, "name", "")
if name == "" {
return nil, errors.New("hostname: name required")
return nil, coreerr.E("Executor.moduleHostname", "name required", nil)
}
// Set hostname
@ -1140,7 +1142,7 @@ func (e *Executor) moduleSysctl(ctx context.Context, client *SSHClient, args map
state := getStringArg(args, "state", "present")
if name == "" {
return nil, errors.New("sysctl: name required")
return nil, coreerr.E("Executor.moduleSysctl", "name required", nil)
}
if state == "absent" {
@ -1210,7 +1212,7 @@ func (e *Executor) moduleBlockinfile(ctx context.Context, client *SSHClient, arg
path = getStringArg(args, "dest", "")
}
if path == "" {
return nil, errors.New("blockinfile: path required")
return nil, coreerr.E("Executor.moduleBlockinfile", "path required", nil)
}
block := getStringArg(args, "block", "")
@ -1358,13 +1360,13 @@ func (e *Executor) moduleAuthorizedKey(ctx context.Context, client *SSHClient, a
state := getStringArg(args, "state", "present")
if user == "" || key == "" {
return nil, errors.New("authorized_key: user and key required")
return nil, coreerr.E("Executor.moduleAuthorizedKey", "user and key required", nil)
}
// Get user's home directory
stdout, _, _, err := client.Run(ctx, fmt.Sprintf("getent passwd %s | cut -d: -f6", user))
if err != nil {
return nil, fmt.Errorf("get home dir: %w", err)
return nil, coreerr.E("Executor.moduleAuthorizedKey", "get home dir", err)
}
home := strings.TrimSpace(stdout)
if home == "" {
@ -1408,7 +1410,7 @@ func (e *Executor) moduleDockerCompose(ctx context.Context, client *SSHClient, a
state := getStringArg(args, "state", "present")
if projectSrc == "" {
return nil, errors.New("docker_compose: project_src required")
return nil, coreerr.E("Executor.moduleDockerCompose", "project_src required", nil)
}
var cmd string

View file

@ -4,12 +4,12 @@ import (
"fmt"
"iter"
"maps"
"os"
"path/filepath"
"slices"
"strings"
"forge.lthn.ai/core/go-log"
coreio "forge.lthn.ai/core/go-io"
coreerr "forge.lthn.ai/core/go-log"
"gopkg.in/yaml.v3"
)
@ -29,20 +29,20 @@ func NewParser(basePath string) *Parser {
// ParsePlaybook parses an Ansible playbook file.
func (p *Parser) ParsePlaybook(path string) ([]Play, error) {
data, err := os.ReadFile(path)
data, err := coreio.Local.Read(path)
if err != nil {
return nil, fmt.Errorf("read playbook: %w", err)
return nil, coreerr.E("Parser.ParsePlaybook", "read playbook", err)
}
var plays []Play
if err := yaml.Unmarshal(data, &plays); err != nil {
return nil, fmt.Errorf("parse playbook: %w", err)
if err := yaml.Unmarshal([]byte(data), &plays); err != nil {
return nil, coreerr.E("Parser.ParsePlaybook", "parse playbook", err)
}
// Process each play
for i := range plays {
if err := p.processPlay(&plays[i]); err != nil {
return nil, fmt.Errorf("process play %d: %w", i, err)
return nil, coreerr.E("Parser.ParsePlaybook", fmt.Sprintf("process play %d", i), err)
}
}
@ -66,14 +66,14 @@ func (p *Parser) ParsePlaybookIter(path string) (iter.Seq[Play], error) {
// ParseInventory parses an Ansible inventory file.
func (p *Parser) ParseInventory(path string) (*Inventory, error) {
data, err := os.ReadFile(path)
data, err := coreio.Local.Read(path)
if err != nil {
return nil, fmt.Errorf("read inventory: %w", err)
return nil, coreerr.E("Parser.ParseInventory", "read inventory", err)
}
var inv Inventory
if err := yaml.Unmarshal(data, &inv); err != nil {
return nil, fmt.Errorf("parse inventory: %w", err)
if err := yaml.Unmarshal([]byte(data), &inv); err != nil {
return nil, coreerr.E("Parser.ParseInventory", "parse inventory", err)
}
return &inv, nil
@ -81,19 +81,19 @@ func (p *Parser) ParseInventory(path string) (*Inventory, error) {
// ParseTasks parses a tasks file (used by include_tasks).
func (p *Parser) ParseTasks(path string) ([]Task, error) {
data, err := os.ReadFile(path)
data, err := coreio.Local.Read(path)
if err != nil {
return nil, fmt.Errorf("read tasks: %w", err)
return nil, coreerr.E("Parser.ParseTasks", "read tasks", err)
}
var tasks []Task
if err := yaml.Unmarshal(data, &tasks); err != nil {
return nil, fmt.Errorf("parse tasks: %w", err)
if err := yaml.Unmarshal([]byte(data), &tasks); err != nil {
return nil, coreerr.E("Parser.ParseTasks", "parse tasks", err)
}
for i := range tasks {
if err := p.extractModule(&tasks[i]); err != nil {
return nil, fmt.Errorf("task %d: %w", i, err)
return nil, coreerr.E("Parser.ParseTasks", fmt.Sprintf("task %d", i), err)
}
}
@ -139,21 +139,21 @@ func (p *Parser) ParseRole(name string, tasksFrom string) ([]Task, error) {
for _, sp := range searchPaths {
// Clean the path to resolve .. segments
sp = filepath.Clean(sp)
if _, err := os.Stat(sp); err == nil {
if coreio.Local.Exists(sp) {
tasksPath = sp
break
}
}
if tasksPath == "" {
return nil, log.E("parser.ParseRole", fmt.Sprintf("role %s not found in search paths: %v", name, searchPaths), nil)
return nil, coreerr.E("Parser.ParseRole", fmt.Sprintf("role %s not found in search paths: %v", name, searchPaths), nil)
}
// Load role defaults
defaultsPath := filepath.Join(filepath.Dir(filepath.Dir(tasksPath)), "defaults", "main.yml")
if data, err := os.ReadFile(defaultsPath); err == nil {
if data, err := coreio.Local.Read(defaultsPath); err == nil {
var defaults map[string]any
if yaml.Unmarshal(data, &defaults) == nil {
if yaml.Unmarshal([]byte(data), &defaults) == nil {
for k, v := range defaults {
if _, exists := p.vars[k]; !exists {
p.vars[k] = v
@ -164,9 +164,9 @@ func (p *Parser) ParseRole(name string, tasksFrom string) ([]Task, error) {
// Load role vars
varsPath := filepath.Join(filepath.Dir(filepath.Dir(tasksPath)), "vars", "main.yml")
if data, err := os.ReadFile(varsPath); err == nil {
if data, err := coreio.Local.Read(varsPath); err == nil {
var roleVars map[string]any
if yaml.Unmarshal(data, &roleVars) == nil {
if yaml.Unmarshal([]byte(data), &roleVars) == nil {
for k, v := range roleVars {
p.vars[k] = v
}
@ -185,25 +185,25 @@ func (p *Parser) processPlay(play *Play) error {
for i := range play.PreTasks {
if err := p.extractModule(&play.PreTasks[i]); err != nil {
return fmt.Errorf("pre_task %d: %w", i, err)
return coreerr.E("Parser.processPlay", fmt.Sprintf("pre_task %d", i), err)
}
}
for i := range play.Tasks {
if err := p.extractModule(&play.Tasks[i]); err != nil {
return fmt.Errorf("task %d: %w", i, err)
return coreerr.E("Parser.processPlay", fmt.Sprintf("task %d", i), err)
}
}
for i := range play.PostTasks {
if err := p.extractModule(&play.PostTasks[i]); err != nil {
return fmt.Errorf("post_task %d: %w", i, err)
return coreerr.E("Parser.processPlay", fmt.Sprintf("post_task %d", i), err)
}
}
for i := range play.Handlers {
if err := p.extractModule(&play.Handlers[i]); err != nil {
return fmt.Errorf("handler %d: %w", i, err)
return coreerr.E("Parser.processPlay", fmt.Sprintf("handler %d", i), err)
}
}

53
ssh.go
View file

@ -12,7 +12,8 @@ import (
"sync"
"time"
"forge.lthn.ai/core/go-log"
coreio "forge.lthn.ai/core/go-io"
coreerr "forge.lthn.ai/core/go-log"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)
@ -91,8 +92,8 @@ func (c *SSHClient) Connect(ctx context.Context) error {
keyPath = filepath.Join(home, keyPath[1:])
}
if key, err := os.ReadFile(keyPath); err == nil {
if signer, err := ssh.ParsePrivateKey(key); err == nil {
if key, err := coreio.Local.Read(keyPath); err == nil {
if signer, err := ssh.ParsePrivateKey([]byte(key)); err == nil {
authMethods = append(authMethods, ssh.PublicKeys(signer))
}
}
@ -106,8 +107,8 @@ func (c *SSHClient) Connect(ctx context.Context) error {
filepath.Join(home, ".ssh", "id_rsa"),
}
for _, keyPath := range defaultKeys {
if key, err := os.ReadFile(keyPath); err == nil {
if signer, err := ssh.ParsePrivateKey(key); err == nil {
if key, err := coreio.Local.Read(keyPath); err == nil {
if signer, err := ssh.ParsePrivateKey([]byte(key)); err == nil {
authMethods = append(authMethods, ssh.PublicKeys(signer))
break
}
@ -128,7 +129,7 @@ func (c *SSHClient) Connect(ctx context.Context) error {
}
if len(authMethods) == 0 {
return log.E("ssh.Connect", "no authentication method available", nil)
return coreerr.E("ssh.Connect", "no authentication method available", nil)
}
// Host key verification
@ -136,23 +137,23 @@ func (c *SSHClient) Connect(ctx context.Context) error {
home, err := os.UserHomeDir()
if err != nil {
return log.E("ssh.Connect", "failed to get user home dir", err)
return coreerr.E("ssh.Connect", "failed to get user home dir", err)
}
knownHostsPath := filepath.Join(home, ".ssh", "known_hosts")
// Ensure known_hosts file exists
if _, err := os.Stat(knownHostsPath); os.IsNotExist(err) {
if err := os.MkdirAll(filepath.Dir(knownHostsPath), 0700); err != nil {
return log.E("ssh.Connect", "failed to create .ssh dir", err)
if !coreio.Local.Exists(knownHostsPath) {
if err := coreio.Local.EnsureDir(filepath.Dir(knownHostsPath)); err != nil {
return coreerr.E("ssh.Connect", "failed to create .ssh dir", err)
}
if err := os.WriteFile(knownHostsPath, nil, 0600); err != nil {
return log.E("ssh.Connect", "failed to create known_hosts file", err)
if err := coreio.Local.Write(knownHostsPath, ""); err != nil {
return coreerr.E("ssh.Connect", "failed to create known_hosts file", err)
}
}
cb, err := knownhosts.New(knownHostsPath)
if err != nil {
return log.E("ssh.Connect", "failed to load known_hosts", err)
return coreerr.E("ssh.Connect", "failed to load known_hosts", err)
}
hostKeyCallback = cb
@ -169,13 +170,13 @@ func (c *SSHClient) Connect(ctx context.Context) error {
var d net.Dialer
conn, err := d.DialContext(ctx, "tcp", addr)
if err != nil {
return log.E("ssh.Connect", fmt.Sprintf("dial %s", addr), err)
return coreerr.E("ssh.Connect", fmt.Sprintf("dial %s", addr), err)
}
sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
if err != nil {
// conn is closed by NewClientConn on error
return log.E("ssh.Connect", fmt.Sprintf("ssh connect %s", addr), err)
return coreerr.E("ssh.Connect", fmt.Sprintf("ssh connect %s", addr), err)
}
c.client = ssh.NewClient(sshConn, chans, reqs)
@ -203,7 +204,7 @@ func (c *SSHClient) Run(ctx context.Context, cmd string) (stdout, stderr string,
session, err := c.client.NewSession()
if err != nil {
return "", "", -1, log.E("ssh.Run", "new session", err)
return "", "", -1, coreerr.E("ssh.Run", "new session", err)
}
defer func() { _ = session.Close() }()
@ -225,7 +226,7 @@ func (c *SSHClient) Run(ctx context.Context, cmd string) (stdout, stderr string,
cmd = fmt.Sprintf("sudo -S -u %s bash -c '%s'", becomeUser, escapedCmd)
stdin, err := session.StdinPipe()
if err != nil {
return "", "", -1, log.E("ssh.Run", "stdin pipe", err)
return "", "", -1, coreerr.E("ssh.Run", "stdin pipe", err)
}
go func() {
defer func() { _ = stdin.Close() }()
@ -236,7 +237,7 @@ func (c *SSHClient) Run(ctx context.Context, cmd string) (stdout, stderr string,
cmd = fmt.Sprintf("sudo -S -u %s bash -c '%s'", becomeUser, escapedCmd)
stdin, err := session.StdinPipe()
if err != nil {
return "", "", -1, log.E("ssh.Run", "stdin pipe", err)
return "", "", -1, coreerr.E("ssh.Run", "stdin pipe", err)
}
go func() {
defer func() { _ = stdin.Close() }()
@ -287,7 +288,7 @@ func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string,
// Read content
content, err := io.ReadAll(local)
if err != nil {
return log.E("ssh.Upload", "read content", err)
return coreerr.E("ssh.Upload", "read content", err)
}
// Create parent directory
@ -297,7 +298,7 @@ func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string,
dirCmd = fmt.Sprintf("sudo mkdir -p %q", dir)
}
if _, _, _, err := c.Run(ctx, dirCmd); err != nil {
return log.E("ssh.Upload", "create parent dir", err)
return coreerr.E("ssh.Upload", "create parent dir", err)
}
// Use cat to write the file (simpler than SCP)
@ -309,13 +310,13 @@ func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string,
session2, err := c.client.NewSession()
if err != nil {
return log.E("ssh.Upload", "new session for write", err)
return coreerr.E("ssh.Upload", "new session for write", err)
}
defer func() { _ = session2.Close() }()
stdin, err := session2.StdinPipe()
if err != nil {
return log.E("ssh.Upload", "stdin pipe", err)
return coreerr.E("ssh.Upload", "stdin pipe", err)
}
var stderrBuf bytes.Buffer
@ -343,7 +344,7 @@ func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string,
}
if err := session2.Start(writeCmd); err != nil {
return log.E("ssh.Upload", "start write", err)
return coreerr.E("ssh.Upload", "start write", err)
}
go func() {
@ -356,7 +357,7 @@ func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string,
} else {
// Normal write
if err := session2.Start(writeCmd); err != nil {
return log.E("ssh.Upload", "start write", err)
return coreerr.E("ssh.Upload", "start write", err)
}
go func() {
@ -366,7 +367,7 @@ func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string,
}
if err := session2.Wait(); err != nil {
return log.E("ssh.Upload", fmt.Sprintf("write failed (stderr: %s)", stderrBuf.String()), err)
return coreerr.E("ssh.Upload", fmt.Sprintf("write failed (stderr: %s)", stderrBuf.String()), err)
}
return nil
@ -385,7 +386,7 @@ func (c *SSHClient) Download(ctx context.Context, remote string) ([]byte, error)
return nil, err
}
if exitCode != 0 {
return nil, log.E("ssh.Download", fmt.Sprintf("cat failed: %s", stderr), nil)
return nil, coreerr.E("ssh.Download", fmt.Sprintf("cat failed: %s", stderr), nil)
}
return []byte(stdout), nil