126 lines
3.4 KiB
Go
126 lines
3.4 KiB
Go
|
|
package sftp
|
||
|
|
|
||
|
|
import (
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"net"
|
||
|
|
"os"
|
||
|
|
"path/filepath"
|
||
|
|
|
||
|
|
"github.com/pkg/sftp"
|
||
|
|
"github.com/skeema/knownhosts"
|
||
|
|
"golang.org/x/crypto/ssh"
|
||
|
|
)
|
||
|
|
|
||
|
|
// New creates a new, connected instance of the SFTP storage medium.
|
||
|
|
func New(cfg ConnectionConfig) (*Medium, error) {
|
||
|
|
var authMethods []ssh.AuthMethod
|
||
|
|
|
||
|
|
if cfg.KeyFile != "" {
|
||
|
|
key, err := os.ReadFile(cfg.KeyFile)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("unable to read private key: %w", err)
|
||
|
|
}
|
||
|
|
signer, err := ssh.ParsePrivateKey(key)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("unable to parse private key: %w", err)
|
||
|
|
}
|
||
|
|
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
||
|
|
} else if cfg.Password != "" {
|
||
|
|
authMethods = append(authMethods, ssh.Password(cfg.Password))
|
||
|
|
} else {
|
||
|
|
return nil, fmt.Errorf("no authentication method provided (password or keyfile)")
|
||
|
|
}
|
||
|
|
|
||
|
|
kh, err := knownhosts.New(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts"))
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to read known_hosts: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
sshConfig := &ssh.ClientConfig{
|
||
|
|
User: cfg.User,
|
||
|
|
Auth: authMethods,
|
||
|
|
HostKeyCallback: kh.HostKeyCallback(),
|
||
|
|
}
|
||
|
|
|
||
|
|
addr := net.JoinHostPort(cfg.Host, cfg.Port)
|
||
|
|
conn, err := ssh.Dial("tcp", addr, sshConfig)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to dial ssh: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
sftpClient, err := sftp.NewClient(conn)
|
||
|
|
if err != nil {
|
||
|
|
// Ensure the underlying ssh connection is closed on failure
|
||
|
|
conn.Close()
|
||
|
|
return nil, fmt.Errorf("failed to create sftp client: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return &Medium{client: sftpClient}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// Read retrieves the content of a file from the SFTP server.
|
||
|
|
func (m *Medium) Read(path string) (string, error) {
|
||
|
|
file, err := m.client.Open(path)
|
||
|
|
if err != nil {
|
||
|
|
return "", fmt.Errorf("sftp: failed to open file %s: %w", path, err)
|
||
|
|
}
|
||
|
|
defer file.Close()
|
||
|
|
|
||
|
|
data, err := io.ReadAll(file)
|
||
|
|
if err != nil {
|
||
|
|
return "", fmt.Errorf("sftp: failed to read file %s: %w", path, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return string(data), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// Write saves the given content to a file on the SFTP server.
|
||
|
|
func (m *Medium) Write(path, content string) error {
|
||
|
|
// Ensure the remote directory exists first.
|
||
|
|
dir := filepath.Dir(path)
|
||
|
|
if err := m.EnsureDir(dir); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
file, err := m.client.Create(path)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("sftp: failed to create file %s: %w", path, err)
|
||
|
|
}
|
||
|
|
defer file.Close()
|
||
|
|
|
||
|
|
if _, err := file.Write([]byte(content)); err != nil {
|
||
|
|
return fmt.Errorf("sftp: failed to write to file %s: %w", path, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// EnsureDir makes sure a directory exists on the SFTP server.
|
||
|
|
func (m *Medium) EnsureDir(path string) error {
|
||
|
|
// MkdirAll is idempotent, so it won't error if the path already exists.
|
||
|
|
return m.client.MkdirAll(path)
|
||
|
|
}
|
||
|
|
|
||
|
|
// IsFile checks if a path exists and is a regular file on the SFTP server.
|
||
|
|
func (m *Medium) IsFile(path string) bool {
|
||
|
|
info, err := m.client.Stat(path)
|
||
|
|
if err != nil {
|
||
|
|
// If the error is "not found", it's definitely not a file.
|
||
|
|
// For any other error, we also conservatively say it's not a file.
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
// Return true only if it's not a directory.
|
||
|
|
return !info.IsDir()
|
||
|
|
}
|
||
|
|
|
||
|
|
// FileGet is a convenience function that reads a file from the medium.
|
||
|
|
func (m *Medium) FileGet(path string) (string, error) {
|
||
|
|
return m.Read(path)
|
||
|
|
}
|
||
|
|
|
||
|
|
// FileSet is a convenience function that writes a file to the medium.
|
||
|
|
func (m *Medium) FileSet(path, content string) error {
|
||
|
|
return m.Write(path, content)
|
||
|
|
}
|