// File: cli/filesystem/sftp/client.go (Go, 126 lines, 3.4 KiB)

package sftp
import (
"fmt"
"io"
"net"
"os"
"path/filepath"
"github.com/pkg/sftp"
"github.com/skeema/knownhosts"
"golang.org/x/crypto/ssh"
)
// New creates a new, connected instance of the SFTP storage medium.
//
// Authentication uses the private key file named by cfg.KeyFile when it is
// set, and falls back to cfg.Password otherwise; if neither is set an error is
// returned. The server's host key is verified against the known_hosts file in
// the current user's ~/.ssh directory.
//
// On success the returned Medium wraps an SFTP client running over a freshly
// dialed SSH connection to cfg.Host:cfg.Port.
func New(cfg ConnectionConfig) (*Medium, error) {
	var authMethods []ssh.AuthMethod
	switch {
	case cfg.KeyFile != "":
		key, err := os.ReadFile(cfg.KeyFile)
		if err != nil {
			return nil, fmt.Errorf("unable to read private key: %w", err)
		}
		signer, err := ssh.ParsePrivateKey(key)
		if err != nil {
			return nil, fmt.Errorf("unable to parse private key: %w", err)
		}
		authMethods = append(authMethods, ssh.PublicKeys(signer))
	case cfg.Password != "":
		authMethods = append(authMethods, ssh.Password(cfg.Password))
	default:
		return nil, fmt.Errorf("no authentication method provided (password or keyfile)")
	}

	// os.UserHomeDir works across platforms; $HOME alone is not set on
	// Windows and would silently produce a relative known_hosts path.
	home, err := os.UserHomeDir()
	if err != nil {
		return nil, fmt.Errorf("unable to determine home directory: %w", err)
	}
	kh, err := knownhosts.New(filepath.Join(home, ".ssh", "known_hosts"))
	if err != nil {
		return nil, fmt.Errorf("failed to read known_hosts: %w", err)
	}

	addr := net.JoinHostPort(cfg.Host, cfg.Port)
	sshConfig := &ssh.ClientConfig{
		User:            cfg.User,
		Auth:            authMethods,
		HostKeyCallback: kh.HostKeyCallback(),
		// Restrict negotiation to the key types already recorded for this
		// host so the server does not offer a key of a different type than
		// the one in known_hosts, which would look like a mismatch.
		HostKeyAlgorithms: kh.HostKeyAlgorithms(addr),
	}

	conn, err := ssh.Dial("tcp", addr, sshConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to dial ssh: %w", err)
	}

	sftpClient, err := sftp.NewClient(conn)
	if err != nil {
		// Best-effort cleanup of the underlying SSH connection; the
		// NewClient error is the one worth reporting.
		_ = conn.Close()
		return nil, fmt.Errorf("failed to create sftp client: %w", err)
	}
	return &Medium{client: sftpClient}, nil
}
// Read opens the file at path on the SFTP server and returns its entire
// content as a string. Open and read failures are wrapped with the path.
func (m *Medium) Read(path string) (string, error) {
	remote, openErr := m.client.Open(path)
	if openErr != nil {
		return "", fmt.Errorf("sftp: failed to open file %s: %w", path, openErr)
	}
	defer remote.Close()

	contents, readErr := io.ReadAll(remote)
	if readErr != nil {
		return "", fmt.Errorf("sftp: failed to read file %s: %w", path, readErr)
	}

	return string(contents), nil
}
// Write saves the given content to a file on the SFTP server.
//
// The parent directory of path is created first (via EnsureDir) if needed.
// The remote file is then created through the client's Create call and
// content is written to it. An error is returned if the directory cannot be
// ensured, or if creating, writing, or closing the file fails.
func (m *Medium) Write(path, content string) error {
	// Remote SFTP paths are slash-separated. filepath.Dir uses the local OS
	// separator, so normalise back to forward slashes to keep the directory
	// correct when the client runs on Windows.
	dir := filepath.ToSlash(filepath.Dir(path))
	if err := m.EnsureDir(dir); err != nil {
		return err
	}

	file, err := m.client.Create(path)
	if err != nil {
		return fmt.Errorf("sftp: failed to create file %s: %w", path, err)
	}

	if _, err := file.Write([]byte(content)); err != nil {
		// Best-effort close; the write error is the one worth reporting.
		_ = file.Close()
		return fmt.Errorf("sftp: failed to write to file %s: %w", path, err)
	}

	// Check Close explicitly instead of deferring it, so a failure reported
	// when the handle is released is not silently dropped.
	if err := file.Close(); err != nil {
		return fmt.Errorf("sftp: failed to close file %s: %w", path, err)
	}
	return nil
}
// EnsureDir makes sure the directory at path exists on the SFTP server by
// delegating to the client's MkdirAll, and returns its error unchanged.
func (m *Medium) EnsureDir(path string) error {
	// An already-existing directory is not treated as an error by MkdirAll.
	if err := m.client.MkdirAll(path); err != nil {
		return err
	}
	return nil
}
// IsFile checks if a path exists and is a regular file on the SFTP server.
//
// It returns false when the path cannot be stat'ed for any reason (missing,
// permission denied, connection error) and for anything that is not a regular
// file, such as directories, named pipes, sockets, and devices.
func (m *Medium) IsFile(path string) bool {
	info, err := m.client.Stat(path)
	if err != nil {
		// "Not found" means it is definitely not a file; any other error is
		// conservatively reported the same way.
		return false
	}
	// IsRegular is stricter than !IsDir: it also rejects pipes, sockets and
	// device nodes, matching the documented "regular file" contract.
	return info.Mode().IsRegular()
}
// FileGet is a convenience alias for Read: it returns the content of the file
// at path on the medium, or the error Read produced.
func (m *Medium) FileGet(path string) (string, error) {
	content, err := m.Read(path)
	return content, err
}
// FileSet is a convenience alias for Write: it stores content in the file at
// path on the medium and returns whatever error Write reports.
func (m *Medium) FileSet(path, content string) error {
	err := m.Write(path, content)
	return err
}