diff --git a/go.mod b/go.mod index 0264f1a..aacd5ed 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/mattn/go-isatty v0.0.20 github.com/schollz/progressbar/v3 v3.18.0 github.com/spf13/cobra v1.10.1 + github.com/stretchr/testify v1.11.1 github.com/ulikunitz/xz v0.5.15 github.com/wailsapp/wails/v2 v2.11.0 golang.org/x/mod v0.30.0 @@ -25,6 +26,7 @@ require ( github.com/bep/debounce v1.2.1 // indirect github.com/cloudflare/circl v1.6.1 // indirect github.com/cyphar/filepath-securejoin v0.4.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect github.com/go-git/go-billy/v5 v5.6.2 // indirect @@ -49,6 +51,7 @@ require ( github.com/pjbgf/sha1cd v0.3.2 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/samber/lo v1.49.1 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect @@ -65,4 +68,7 @@ require ( golang.org/x/term v0.37.0 // indirect golang.org/x/text v0.31.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/Snider/Enchantrix => ../Enchantrix diff --git a/go.sum b/go.sum index bf4e1a3..e2b44bb 100644 --- a/go.sum +++ b/go.sum @@ -5,8 +5,6 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw= github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE= -github.com/Snider/Enchantrix v0.0.2 h1:ExZQiBhfS/p/AHFTKhY80TOd+BXZjK95EzByAEgwvjs= -github.com/Snider/Enchantrix v0.0.2/go.mod h1:CtFcLAvnDT1KcuF1JBb/DJj0KplY8jHryO06KzQ1hsQ= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= diff --git a/pkg/smsg/keyserver.go b/pkg/smsg/keyserver.go new file mode 100644 index 0000000..be56620 --- /dev/null +++ b/pkg/smsg/keyserver.go @@ -0,0 +1,309 @@ +package smsg + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/Snider/Enchantrix/pkg/enchantrix" + "github.com/Snider/Enchantrix/pkg/keyserver" + "github.com/Snider/Enchantrix/pkg/trix" +) + +// EncryptKS encrypts a message using the keyserver (V1 format). +// The key material never leaves the keyserver — only the key ID is passed. +func EncryptKS(ctx context.Context, msg *Message, keyID string, ks keyserver.KeyServer) ([]byte, error) { + if msg.Body == "" && len(msg.Attachments) == 0 { + return nil, ErrEmptyMessage + } + + if msg.Timestamp == 0 { + msg.Timestamp = time.Now().Unix() + } + + payload, err := json.Marshal(msg) + if err != nil { + return nil, fmt.Errorf("failed to marshal message: %w", err) + } + + // Keyserver encrypts — key never leaves + encrypted, err := ks.EncryptSMSG(ctx, keyID, payload) + if err != nil { + return nil, fmt.Errorf("keyserver encrypt: %w", err) + } + + t := &trix.Trix{ + Header: map[string]interface{}{ + "version": Version, + "algorithm": "chacha20poly1305", + }, + Payload: encrypted, + } + + return trix.Encode(t, Magic, nil) +} + +// DecryptKS decrypts an SMSG container using the keyserver. +// Handles both V1 and V2 formats. The key material never leaves the keyserver. +func DecryptKS(ctx context.Context, data []byte, keyID string, ks keyserver.KeyServer) (*Message, error) { + t, err := trix.Decode(data, Magic, nil) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrInvalidMagic, err) + } + + // Keyserver decrypts — key never leaves + decrypted, err := ks.DecryptSMSG(ctx, keyID, t.Payload) + if err != nil { + return nil, ErrDecryptionFailed + } + + format := "" + compression := "" + if f, ok := t.Header["format"].(string); ok { + format = f + } + if c, ok := t.Header["compression"].(string); ok { + compression = c + } + + switch compression { + case CompressionGzip: + decompressed, err := gzipDecompress(decrypted) + if err != nil { + return nil, fmt.Errorf("gzip decompression failed: %w", err) + } + decrypted = decompressed + case CompressionZstd: + decompressed, err := zstdDecompress(decrypted) + if err != nil { + return nil, fmt.Errorf("zstd decompression failed: %w", err) + } + decrypted = decompressed + } + + if format == FormatV2 { + return parseV2Payload(decrypted) + } + + var msg Message + if err := json.Unmarshal(decrypted, &msg); err != nil { + return nil, fmt.Errorf("%w: invalid message format", ErrInvalidPayload) + } + + return &msg, nil +} + +// V3ParamsKS contains keyserver-aware parameters for V3 streaming. +type V3ParamsKS struct { + License string + Fingerprint string + Cadence Cadence +} + +// EncryptV3KS encrypts using V3 streaming format. Stream keys are derived and +// CEK wrapping is done via keyserver — stream key material never leaves. +// The CEK itself is a random per-message key used for content encryption. +func EncryptV3KS(ctx context.Context, msg *Message, params V3ParamsKS, manifest *Manifest, ks keyserver.KeyServer) ([]byte, error) { + if params.License == "" { + return nil, ErrLicenseRequired + } + if msg.Body == "" && len(msg.Attachments) == 0 { + return nil, ErrEmptyMessage + } + + if msg.Timestamp == 0 { + msg.Timestamp = time.Now().Unix() + } + + cadence := params.Cadence + if cadence == "" { + cadence = CadenceDaily + } + + current, next := GetRollingPeriods(cadence, time.Now().UTC()) + + // Generate random CEK (this is a per-message ephemeral key) + cek, err := GenerateCEK() + if err != nil { + return nil, err + } + + // Derive stream keys via keyserver — stream key material never leaves + currentKeyID, err := ks.DeriveStreamKey(ctx, params.License, current, params.Fingerprint) + if err != nil { + return nil, fmt.Errorf("failed to derive current stream key: %w", err) + } + defer ks.DeleteKey(ctx, currentKeyID) + + nextKeyID, err := ks.DeriveStreamKey(ctx, params.License, next, params.Fingerprint) + if err != nil { + return nil, fmt.Errorf("failed to derive next stream key: %w", err) + } + defer ks.DeleteKey(ctx, nextKeyID) + + // Wrap CEK with stream keys via keyserver + wrappedCurrent, err := ks.WrapCEK(ctx, currentKeyID, cek) + if err != nil { + return nil, fmt.Errorf("failed to wrap CEK for current period: %w", err) + } + + wrappedNext, err := ks.WrapCEK(ctx, nextKeyID, cek) + if err != nil { + return nil, fmt.Errorf("failed to wrap CEK for next period: %w", err) + } + + // Encrypt content with CEK + payload, err := json.Marshal(msg) + if err != nil { + return nil, fmt.Errorf("failed to marshal message: %w", err) + } + + compressed, err := zstdCompress(payload) + if err != nil { + return nil, fmt.Errorf("compression failed: %w", err) + } + + sigil, err := enchantrix.NewChaChaPolySigil(cek) + if err != nil { + return nil, fmt.Errorf("failed to create sigil: %w", err) + } + + encrypted, err := sigil.In(compressed) + if err != nil { + return nil, fmt.Errorf("encryption failed: %w", err) + } + + // Build V3 header + headerMap := map[string]interface{}{ + "version": Version, + "algorithm": "chacha20poly1305", + "format": FormatV3, + "compression": CompressionZstd, + "keyMethod": KeyMethodLTHNRolling, + "cadence": string(cadence), + "wrappedKeys": []WrappedKey{ + {Date: current, Wrapped: wrappedCurrent}, + {Date: next, Wrapped: wrappedNext}, + }, + } + + if manifest != nil { + if manifest.IssuedAt == 0 { + manifest.IssuedAt = time.Now().Unix() + } + headerMap["manifest"] = manifest + } + + t := &trix.Trix{ + Header: headerMap, + Payload: encrypted, + } + + return trix.Encode(t, Magic, nil) +} + +// DecryptV3KS decrypts a V3 streaming message. Stream keys are derived via +// keyserver for CEK unwrapping — stream key material never leaves. +func DecryptV3KS(ctx context.Context, data []byte, params V3ParamsKS, ks keyserver.KeyServer) (*Message, *Header, error) { + if params.License == "" { + return nil, nil, ErrLicenseRequired + } + + t, err := trix.Decode(data, Magic, nil) + if err != nil { + return nil, nil, fmt.Errorf("failed to decode container: %w", err) + } + + headerJSON, err := json.Marshal(t.Header) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal header: %w", err) + } + + var header Header + if err := json.Unmarshal(headerJSON, &header); err != nil { + return nil, nil, fmt.Errorf("failed to parse header: %w", err) + } + + if header.Format != FormatV3 { + return nil, nil, fmt.Errorf("expected v3 format, got: %s", header.Format) + } + + cadence := header.Cadence + if cadence == "" && params.Cadence != "" { + cadence = params.Cadence + } + if cadence == "" { + cadence = CadenceDaily + } + + // Unwrap CEK using keyserver-derived stream keys + cek, err := tryUnwrapCEKKS(ctx, header.WrappedKeys, params, cadence, ks) + if err != nil { + return nil, &header, err + } + + // Decrypt payload with CEK + sigil, err := enchantrix.NewChaChaPolySigil(cek) + if err != nil { + return nil, &header, fmt.Errorf("failed to create sigil: %w", err) + } + + compressed, err := sigil.Out(t.Payload) + if err != nil { + return nil, &header, ErrDecryptionFailed + } + + var decompressed []byte + if header.Compression == CompressionZstd { + decompressed, err = zstdDecompress(compressed) + if err != nil { + return nil, &header, fmt.Errorf("decompression failed: %w", err) + } + } else { + decompressed = compressed + } + + var msg Message + if err := json.Unmarshal(decompressed, &msg); err != nil { + return nil, &header, fmt.Errorf("failed to parse message: %w", err) + } + + return &msg, &header, nil +} + +// tryUnwrapCEKKS tries to unwrap CEK using keyserver-derived stream keys. +func tryUnwrapCEKKS(ctx context.Context, wrappedKeys []WrappedKey, params V3ParamsKS, cadence Cadence, ks keyserver.KeyServer) ([]byte, error) { + current, next := GetRollingPeriods(cadence, time.Now().UTC()) + + keysByPeriod := make(map[string]string) + for _, wk := range wrappedKeys { + keysByPeriod[wk.Date] = wk.Wrapped + } + + // Try current period + if wrapped, ok := keysByPeriod[current]; ok { + streamKeyID, err := ks.DeriveStreamKey(ctx, params.License, current, params.Fingerprint) + if err == nil { + cek, unwrapErr := ks.UnwrapCEK(ctx, streamKeyID, wrapped) + ks.DeleteKey(ctx, streamKeyID) + if unwrapErr == nil { + return cek, nil + } + } + } + + // Try next period + if wrapped, ok := keysByPeriod[next]; ok { + streamKeyID, err := ks.DeriveStreamKey(ctx, params.License, next, params.Fingerprint) + if err == nil { + cek, unwrapErr := ks.UnwrapCEK(ctx, streamKeyID, wrapped) + ks.DeleteKey(ctx, streamKeyID) + if unwrapErr == nil { + return cek, nil + } + } + } + + return nil, ErrNoValidKey +} diff --git a/pkg/smsg/keyserver_test.go b/pkg/smsg/keyserver_test.go new file mode 100644 index 0000000..27ade18 --- /dev/null +++ b/pkg/smsg/keyserver_test.go @@ -0,0 +1,148 @@ +package smsg + +import ( + "context" + "path/filepath" + "testing" + + "github.com/Snider/Enchantrix/pkg/keyserver" + "github.com/Snider/Enchantrix/pkg/keystore" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestSMSGKeyServer(t *testing.T) *keyserver.Server { + t.Helper() + dir := t.TempDir() + store, err := keystore.Create(filepath.Join(dir, "keys.trix"), "test-master") + require.NoError(t, err) + t.Cleanup(func() { store.Close() }) + return keyserver.NewServer(store) +} + +func TestEncryptDecryptKSRoundTrip(t *testing.T) { + ctx := context.Background() + ks := newTestSMSGKeyServer(t) + + keyID, err := ks.ImportPassword(ctx, "smsg-password", "smsg-key") + require.NoError(t, err) + + msg := NewMessage("Hello, keyserver SMSG!") + msg.WithSubject("Test").WithFrom("alice") + + encrypted, err := EncryptKS(ctx, msg, keyID, ks) + require.NoError(t, err) + assert.NotEmpty(t, encrypted) + + decrypted, err := DecryptKS(ctx, encrypted, keyID, ks) + require.NoError(t, err) + assert.Equal(t, "Hello, keyserver SMSG!", decrypted.Body) + assert.Equal(t, "Test", decrypted.Subject) + assert.Equal(t, "alice", decrypted.From) +} + +func TestKSCompatWithOldAPI(t *testing.T) { + ctx := context.Background() + ks := newTestSMSGKeyServer(t) + + password := "compat-password" + keyID, err := ks.ImportPassword(ctx, password, "compat") + require.NoError(t, err) + + // Encrypt with old API + msg := NewMessage("compatibility test") + encrypted, err := Encrypt(msg, password) + require.NoError(t, err) + + // Decrypt with keyserver + decrypted, err := DecryptKS(ctx, encrypted, keyID, ks) + require.NoError(t, err) + assert.Equal(t, "compatibility test", decrypted.Body) +} + +func TestKSEncryptOldDecrypt(t *testing.T) { + ctx := context.Background() + ks := newTestSMSGKeyServer(t) + + password := "compat-password-2" + keyID, err := ks.ImportPassword(ctx, password, "compat2") + require.NoError(t, err) + + // Encrypt with keyserver + msg := NewMessage("keyserver encrypted") + encrypted, err := EncryptKS(ctx, msg, keyID, ks) + require.NoError(t, err) + + // Decrypt with old API + decrypted, err := Decrypt(encrypted, password) + require.NoError(t, err) + assert.Equal(t, "keyserver encrypted", decrypted.Body) +} + +func TestEncryptKSEmptyMessage(t *testing.T) { + ctx := context.Background() + ks := newTestSMSGKeyServer(t) + + keyID, _ := ks.ImportPassword(ctx, "pass", "key") + + msg := &Message{} + _, err := EncryptKS(ctx, msg, keyID, ks) + assert.ErrorIs(t, err, ErrEmptyMessage) +} + +func TestV3KSRoundTrip(t *testing.T) { + ctx := context.Background() + ks := newTestSMSGKeyServer(t) + + params := V3ParamsKS{ + License: "test-license-123", + Fingerprint: "device-fp-456", + Cadence: CadenceDaily, + } + + msg := NewMessage("V3 keyserver streaming test") + + encrypted, err := EncryptV3KS(ctx, msg, params, nil, ks) + require.NoError(t, err) + assert.NotEmpty(t, encrypted) + + decrypted, header, err := DecryptV3KS(ctx, encrypted, params, ks) + require.NoError(t, err) + assert.Equal(t, "V3 keyserver streaming test", decrypted.Body) + assert.Equal(t, FormatV3, header.Format) + assert.Equal(t, KeyMethodLTHNRolling, header.KeyMethod) +} + +func TestV3KSWithManifest(t *testing.T) { + ctx := context.Background() + ks := newTestSMSGKeyServer(t) + + params := V3ParamsKS{ + License: "license-xyz", + Fingerprint: "fp-abc", + } + + manifest := NewManifest("Test Track") + manifest.Artist = "Test Artist" + + msg := NewMessage("manifest test") + + encrypted, err := EncryptV3KS(ctx, msg, params, manifest, ks) + require.NoError(t, err) + + _, header, err := DecryptV3KS(ctx, encrypted, params, ks) + require.NoError(t, err) + assert.NotNil(t, header.Manifest) + assert.Equal(t, "Test Track", header.Manifest.Title) +} + +func TestV3KSNoLicense(t *testing.T) { + ctx := context.Background() + ks := newTestSMSGKeyServer(t) + + params := V3ParamsKS{Fingerprint: "fp"} + msg := NewMessage("test") + + _, err := EncryptV3KS(ctx, msg, params, nil, ks) + assert.ErrorIs(t, err, ErrLicenseRequired) +} diff --git a/pkg/stmf/keyserver.go b/pkg/stmf/keyserver.go new file mode 100644 index 0000000..056ec2e --- /dev/null +++ b/pkg/stmf/keyserver.go @@ -0,0 +1,43 @@ +package stmf + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/Snider/Enchantrix/pkg/keyserver" + "github.com/Snider/Enchantrix/pkg/keystore" +) + +// DecryptKS decrypts a STMF payload using the keyserver. The server's X25519 +// private key never leaves the keyserver — ECDH and decryption happen internally. +func DecryptKS(ctx context.Context, stmfData []byte, keyID string, ks keyserver.KeyServer) (*FormData, error) { + // Keyserver performs ECDH + decrypt internally — private key never leaves + plaintext, err := ks.DecryptSTMF(ctx, keyID, stmfData) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrDecryptionFailed, err) + } + + var formData FormData + if err := json.Unmarshal(plaintext, &formData); err != nil { + return nil, fmt.Errorf("%w: invalid JSON payload: %v", ErrInvalidPayload, err) + } + + return &formData, nil +} + +// GenerateKeyPairKS generates an X25519 keypair in the keyserver and returns +// the public key for distribution. The private key stays in the keyserver. +func GenerateKeyPairKS(ctx context.Context, label string, ks keyserver.KeyServer) (publicKey []byte, keyID string, err error) { + keyID, err = ks.GenerateKey(ctx, keystore.X25519, label) + if err != nil { + return nil, "", fmt.Errorf("failed to generate keypair: %w", err) + } + + publicKey, err = ks.GetPublicKey(ctx, keyID) + if err != nil { + return nil, "", fmt.Errorf("failed to get public key: %w", err) + } + + return publicKey, keyID, nil +} diff --git a/pkg/stmf/keyserver_test.go b/pkg/stmf/keyserver_test.go new file mode 100644 index 0000000..6f62754 --- /dev/null +++ b/pkg/stmf/keyserver_test.go @@ -0,0 +1,121 @@ +package stmf + +import ( + "context" + "path/filepath" + "testing" + + "github.com/Snider/Enchantrix/pkg/keyserver" + "github.com/Snider/Enchantrix/pkg/keystore" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestSTMFKeyServer(t *testing.T) *keyserver.Server { + t.Helper() + dir := t.TempDir() + store, err := keystore.Create(filepath.Join(dir, "keys.trix"), "test-master") + require.NoError(t, err) + t.Cleanup(func() { store.Close() }) + return keyserver.NewServer(store) +} + +func TestGenerateKeyPairKS(t *testing.T) { + ctx := context.Background() + ks := newTestSTMFKeyServer(t) + + pubKey, keyID, err := GenerateKeyPairKS(ctx, "stmf-server", ks) + require.NoError(t, err) + assert.Len(t, pubKey, 32) + assert.NotEmpty(t, keyID) + + // Public key should be retrievable + pubKey2, err := ks.GetPublicKey(ctx, keyID) + require.NoError(t, err) + assert.Equal(t, pubKey, pubKey2) +} + +func TestSTMFDecryptKS(t *testing.T) { + ctx := context.Background() + ks := newTestSTMFKeyServer(t) + + // Generate server keypair via keyserver + pubKey, keyID, err := GenerateKeyPairKS(ctx, "stmf-server", ks) + require.NoError(t, err) + + // Client encrypts with public key (no keyserver needed) + formData := NewFormData(). + AddField("username", "alice"). + AddFieldWithType("password", "secret123", "password") + + encrypted, err := Encrypt(formData, pubKey) + require.NoError(t, err) + + // Server decrypts via keyserver — private key never leaves + decrypted, err := DecryptKS(ctx, encrypted, keyID, ks) + require.NoError(t, err) + + assert.Equal(t, "alice", decrypted.Get("username")) + assert.Equal(t, "secret123", decrypted.Get("password")) +} + +func TestSTMFKeyRotation(t *testing.T) { + ctx := context.Background() + ks := newTestSTMFKeyServer(t) + + // Generate first keypair + pubKey1, keyID1, err := GenerateKeyPairKS(ctx, "server-v1", ks) + require.NoError(t, err) + + // Generate second keypair (rotation) + pubKey2, keyID2, err := GenerateKeyPairKS(ctx, "server-v2", ks) + require.NoError(t, err) + + assert.NotEqual(t, pubKey1, pubKey2) + + // Encrypt with old key + form1 := NewFormData().AddField("version", "1") + enc1, err := Encrypt(form1, pubKey1) + require.NoError(t, err) + + // Encrypt with new key + form2 := NewFormData().AddField("version", "2") + enc2, err := Encrypt(form2, pubKey2) + require.NoError(t, err) + + // Old key still decrypts old forms + dec1, err := DecryptKS(ctx, enc1, keyID1, ks) + require.NoError(t, err) + assert.Equal(t, "1", dec1.Get("version")) + + // New key decrypts new forms + dec2, err := DecryptKS(ctx, enc2, keyID2, ks) + require.NoError(t, err) + assert.Equal(t, "2", dec2.Get("version")) + + // New key cannot decrypt old forms (different ECDH shared secret) + _, err = DecryptKS(ctx, enc1, keyID2, ks) + assert.Error(t, err) +} + +func TestSTMFPrivateKeyNeverLeaves(t *testing.T) { + ctx := context.Background() + ks := newTestSTMFKeyServer(t) + + _, keyID, err := GenerateKeyPairKS(ctx, "secure-key", ks) + require.NoError(t, err) + + // ListKeys should not expose key data + keys, err := ks.ListKeys(ctx) + require.NoError(t, err) + for _, k := range keys { + assert.Nil(t, k.KeyData, "ListKeys must not expose key data") + } + + // GetPublicKey only returns public component + pubKey, err := ks.GetPublicKey(ctx, keyID) + require.NoError(t, err) + assert.Len(t, pubKey, 32) // public key only + + // There's no GetPrivateKey method on the interface — by design +} diff --git a/pkg/tim/keyserver.go b/pkg/tim/keyserver.go new file mode 100644 index 0000000..f4ea82c --- /dev/null +++ b/pkg/tim/keyserver.go @@ -0,0 +1,171 @@ +package tim + +import ( + "context" + "encoding/binary" + "fmt" + "io/fs" + "os" + "path/filepath" + + "github.com/Snider/Borg/pkg/datanode" + "github.com/Snider/Enchantrix/pkg/keyserver" + "github.com/Snider/Enchantrix/pkg/trix" +) + +// ToSigilKS encrypts the TIM using the keyserver. The key material never +// leaves the keyserver — only the key ID is passed. +// The output format matches ToSigil: a Trix container with "STIM" magic. +func (m *TerminalIsolationMatrix) ToSigilKS(ctx context.Context, keyID string, ks keyserver.KeyServer) ([]byte, error) { + if m.Config == nil { + return nil, ErrConfigIsNil + } + + rootfsTar, err := m.RootFS.ToTar() + if err != nil { + return nil, fmt.Errorf("failed to serialize rootfs: %w", err) + } + + // Keyserver encrypts config+rootfs atomically — key never leaves + payload, err := ks.EncryptTIM(ctx, keyID, m.Config, rootfsTar) + if err != nil { + return nil, fmt.Errorf("keyserver encrypt TIM: %w", err) + } + + // Parse encrypted config size from payload for header metadata + configSize := binary.BigEndian.Uint32(payload[:4]) + rootfsSize := len(payload) - 4 - int(configSize) + + // Wrap in Trix container with same header format as ToSigil + t := &trix.Trix{ + Header: map[string]interface{}{ + "encryption_algorithm": "chacha20poly1305", + "tim": true, + "config_size": configSize, + "rootfs_size": rootfsSize, + "version": "1.0", + }, + Payload: payload, + } + + return trix.Encode(t, "STIM", nil) +} + +// FromSigilKS decrypts and deserializes a .stim file using the keyserver. +// The key material never leaves the keyserver. +func FromSigilKS(data []byte, keyID string, ks keyserver.KeyServer) (*TerminalIsolationMatrix, error) { + return FromSigilKSCtx(context.Background(), data, keyID, ks) +} + +// FromSigilKSCtx decrypts and deserializes a .stim file using the keyserver +// with a context. +func FromSigilKSCtx(ctx context.Context, data []byte, keyID string, ks keyserver.KeyServer) (*TerminalIsolationMatrix, error) { + t, err := trix.Decode(data, "STIM", nil) + if err != nil { + return nil, fmt.Errorf("failed to decode stim: %w", err) + } + + // Keyserver decrypts config+rootfs atomically — key never leaves + config, rootfsTar, err := ks.DecryptTIM(ctx, keyID, t.Payload) + if err != nil { + return nil, fmt.Errorf("keyserver decrypt TIM: %w", err) + } + + // Reconstruct DataNode from decrypted rootfs tar + rootfs, err := datanode.FromTar(rootfsTar) + if err != nil { + return nil, fmt.Errorf("failed to parse rootfs: %w", err) + } + + return &TerminalIsolationMatrix{ + Config: config, + RootFS: rootfs, + }, nil +} + +// CacheKS provides encrypted TIM storage using a keyserver. +// The key material never leaves the keyserver — only the key ID is referenced. +type CacheKS struct { + Dir string + KeyID string + KS keyserver.KeyServer +} + +// NewCacheKS creates a keyserver-backed TIM cache. +func NewCacheKS(dir string, keyID string, ks keyserver.KeyServer) (*CacheKS, error) { + if err := os.MkdirAll(dir, 0700); err != nil { + return nil, err + } + return &CacheKS{Dir: dir, KeyID: keyID, KS: ks}, nil +} + +// Store encrypts and saves a TIM to the cache via keyserver. +func (c *CacheKS) Store(ctx context.Context, name string, m *TerminalIsolationMatrix) error { + data, err := m.ToSigilKS(ctx, c.KeyID, c.KS) + if err != nil { + return err + } + path := filepath.Join(c.Dir, name+".stim") + return os.WriteFile(path, data, 0600) +} + +// Load retrieves and decrypts a TIM from the cache via keyserver. +func (c *CacheKS) Load(ctx context.Context, name string) (*TerminalIsolationMatrix, error) { + path := filepath.Join(c.Dir, name+".stim") + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return FromSigilKSCtx(ctx, data, c.KeyID, c.KS) +} + +// RunEncryptedKS runs an encrypted .stim file, decrypting via keyserver. +// The key material never leaves the keyserver. +func RunEncryptedKS(ctx context.Context, stimPath, keyID string, ks keyserver.KeyServer) error { + data, err := os.ReadFile(stimPath) + if err != nil { + return fmt.Errorf("failed to read stim file: %w", err) + } + + m, err := FromSigilKSCtx(ctx, data, keyID, ks) + if err != nil { + return fmt.Errorf("failed to decrypt stim: %w", err) + } + + tempDir, err := os.MkdirTemp("", "borg-run-*") + if err != nil { + return fmt.Errorf("failed to create temporary directory: %w", err) + } + defer os.RemoveAll(tempDir) + + if err := os.WriteFile(filepath.Join(tempDir, "config.json"), m.Config, 0600); err != nil { + return fmt.Errorf("failed to write config: %w", err) + } + + rootfsPath := filepath.Join(tempDir, "rootfs") + if err := os.MkdirAll(rootfsPath, 0755); err != nil { + return fmt.Errorf("failed to create rootfs dir: %w", err) + } + + err = m.RootFS.Walk(".", func(path string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if d.IsDir() { + return nil + } + target := filepath.Join(rootfsPath, path) + if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil { + return err + } + return m.RootFS.CopyFile(path, target, 0600) + }) + if err != nil { + return fmt.Errorf("failed to extract rootfs: %w", err) + } + + cmd := ExecCommand("runc", "run", "-b", tempDir, "borg-container") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() +} diff --git a/pkg/tim/keyserver_test.go b/pkg/tim/keyserver_test.go new file mode 100644 index 0000000..1f70a66 --- /dev/null +++ b/pkg/tim/keyserver_test.go @@ -0,0 +1,150 @@ +package tim + +import ( + "context" + "path/filepath" + "testing" + + "github.com/Snider/Borg/pkg/datanode" + "github.com/Snider/Enchantrix/pkg/keyserver" + "github.com/Snider/Enchantrix/pkg/keystore" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestKeyServer(t *testing.T) *keyserver.Server { + t.Helper() + dir := t.TempDir() + store, err := keystore.Create(filepath.Join(dir, "keys.trix"), "test-master") + require.NoError(t, err) + t.Cleanup(func() { store.Close() }) + return keyserver.NewServer(store) +} + +func TestToFromSigilKS(t *testing.T) { + ctx := context.Background() + ks := newTestKeyServer(t) + + keyID, err := ks.ImportPassword(ctx, "test-password", "tim-key") + require.NoError(t, err) + + // Build a TIM + m, err := New() + require.NoError(t, err) + m.RootFS.AddData("hello.txt", []byte("Hello from the keyserver!")) + + // Encrypt via keyserver + stim, err := m.ToSigilKS(ctx, keyID, ks) + require.NoError(t, err) + assert.NotEmpty(t, stim) + + // Decrypt via keyserver + restored, err := FromSigilKS(stim, keyID, ks) + require.NoError(t, err) + + // Verify config round-trips + assert.Equal(t, m.Config, restored.Config) + + // Verify rootfs round-trips + f, err := restored.RootFS.Open("hello.txt") + require.NoError(t, err) + defer f.Close() + info, _ := f.Stat() + buf := make([]byte, info.Size()) + f.Read(buf) + assert.Equal(t, "Hello from the keyserver!", string(buf)) +} + +func TestToSigilKSMatchesOldFormat(t *testing.T) { + ctx := context.Background() + ks := newTestKeyServer(t) + + password := "compat-test" + keyID, err := ks.ImportPassword(ctx, password, "compat") + require.NoError(t, err) + + m, err := New() + require.NoError(t, err) + m.RootFS.AddData("file.txt", []byte("content")) + + // Encrypt with keyserver + stimKS, err := m.ToSigilKS(ctx, keyID, ks) + require.NoError(t, err) + + // Decrypt with old password API (should work — same key derivation) + restored, err := FromSigil(stimKS, password) + require.NoError(t, err) + assert.Equal(t, m.Config, restored.Config) +} + +func TestOldFormatDecryptsWithKS(t *testing.T) { + ctx := context.Background() + ks := newTestKeyServer(t) + + password := "compat-test-2" + keyID, err := ks.ImportPassword(ctx, password, "compat2") + require.NoError(t, err) + + m, err := New() + require.NoError(t, err) + m.RootFS.AddData("data.bin", []byte{0xDE, 0xAD, 0xBE, 0xEF}) + + // Encrypt with old API + stim, err := m.ToSigil(password) + require.NoError(t, err) + + // Decrypt with keyserver + restored, err := FromSigilKS(stim, keyID, ks) + require.NoError(t, err) + assert.Equal(t, m.Config, restored.Config) +} + +func TestCacheKS(t *testing.T) { + ctx := context.Background() + ks := newTestKeyServer(t) + dir := t.TempDir() + + keyID, err := ks.ImportPassword(ctx, "cache-pass", "cache-key") + require.NoError(t, err) + + cache, err := NewCacheKS(dir, keyID, ks) + require.NoError(t, err) + + // Create TIM + m, err := New() + require.NoError(t, err) + m.RootFS.AddData("cached.txt", []byte("cached content")) + + // Store + err = cache.Store(ctx, "test-tim", m) + require.NoError(t, err) + + // Load + loaded, err := cache.Load(ctx, "test-tim") + require.NoError(t, err) + + assert.Equal(t, m.Config, loaded.Config) + + f, err := loaded.RootFS.Open("cached.txt") + require.NoError(t, err) + defer f.Close() + info, _ := f.Stat() + buf := make([]byte, info.Size()) + f.Read(buf) + assert.Equal(t, "cached content", string(buf)) +} + +func TestToSigilKSNilConfig(t *testing.T) { + ctx := context.Background() + ks := newTestKeyServer(t) + + keyID, _ := ks.ImportPassword(ctx, "pass", "key") + + m := &TerminalIsolationMatrix{ + Config: nil, + RootFS: datanode.New(), + } + + _, err := m.ToSigilKS(ctx, keyID, ks) + assert.ErrorIs(t, err, ErrConfigIsNil) +}