feat: add keyserver-aware crypto APIs for TIM, SMSG, and STMF
Adds *KS variant functions that delegate crypto operations to the Enchantrix keyserver — key material never leaves the keyserver boundary. - tim: ToSigilKS, FromSigilKS, CacheKS, RunEncryptedKS - smsg: EncryptKS, DecryptKS, EncryptV3KS, DecryptV3KS - stmf: DecryptKS, GenerateKeyPairKS All variants are backward-compatible with existing password-based APIs. Adds testify dependency for integration tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
a77024aad4
commit
cb64050704
8 changed files with 948 additions and 2 deletions
6
go.mod
6
go.mod
|
|
@ -11,6 +11,7 @@ require (
|
|||
github.com/mattn/go-isatty v0.0.20
|
||||
github.com/schollz/progressbar/v3 v3.18.0
|
||||
github.com/spf13/cobra v1.10.1
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/ulikunitz/xz v0.5.15
|
||||
github.com/wailsapp/wails/v2 v2.11.0
|
||||
golang.org/x/mod v0.30.0
|
||||
|
|
@ -25,6 +26,7 @@ require (
|
|||
github.com/bep/debounce v1.2.1 // indirect
|
||||
github.com/cloudflare/circl v1.6.1 // indirect
|
||||
github.com/cyphar/filepath-securejoin v0.4.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
|
||||
github.com/go-git/go-billy/v5 v5.6.2 // indirect
|
||||
|
|
@ -49,6 +51,7 @@ require (
|
|||
github.com/pjbgf/sha1cd v0.3.2 // indirect
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/samber/lo v1.49.1 // indirect
|
||||
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
|
||||
|
|
@ -65,4 +68,7 @@ require (
|
|||
golang.org/x/term v0.37.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
gopkg.in/warnings.v0 v0.1.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
replace github.com/Snider/Enchantrix => ../Enchantrix
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -5,8 +5,6 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
|
|||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw=
|
||||
github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
|
||||
github.com/Snider/Enchantrix v0.0.2 h1:ExZQiBhfS/p/AHFTKhY80TOd+BXZjK95EzByAEgwvjs=
|
||||
github.com/Snider/Enchantrix v0.0.2/go.mod h1:CtFcLAvnDT1KcuF1JBb/DJj0KplY8jHryO06KzQ1hsQ=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
|
||||
|
|
|
|||
309
pkg/smsg/keyserver.go
Normal file
309
pkg/smsg/keyserver.go
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
package smsg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Snider/Enchantrix/pkg/enchantrix"
|
||||
"github.com/Snider/Enchantrix/pkg/keyserver"
|
||||
"github.com/Snider/Enchantrix/pkg/trix"
|
||||
)
|
||||
|
||||
// EncryptKS encrypts a message using the keyserver (V1 format).
|
||||
// The key material never leaves the keyserver — only the key ID is passed.
|
||||
func EncryptKS(ctx context.Context, msg *Message, keyID string, ks keyserver.KeyServer) ([]byte, error) {
|
||||
if msg.Body == "" && len(msg.Attachments) == 0 {
|
||||
return nil, ErrEmptyMessage
|
||||
}
|
||||
|
||||
if msg.Timestamp == 0 {
|
||||
msg.Timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Keyserver encrypts — key never leaves
|
||||
encrypted, err := ks.EncryptSMSG(ctx, keyID, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("keyserver encrypt: %w", err)
|
||||
}
|
||||
|
||||
t := &trix.Trix{
|
||||
Header: map[string]interface{}{
|
||||
"version": Version,
|
||||
"algorithm": "chacha20poly1305",
|
||||
},
|
||||
Payload: encrypted,
|
||||
}
|
||||
|
||||
return trix.Encode(t, Magic, nil)
|
||||
}
|
||||
|
||||
// DecryptKS decrypts an SMSG container using the keyserver.
|
||||
// Handles both V1 and V2 formats. The key material never leaves the keyserver.
|
||||
func DecryptKS(ctx context.Context, data []byte, keyID string, ks keyserver.KeyServer) (*Message, error) {
|
||||
t, err := trix.Decode(data, Magic, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidMagic, err)
|
||||
}
|
||||
|
||||
// Keyserver decrypts — key never leaves
|
||||
decrypted, err := ks.DecryptSMSG(ctx, keyID, t.Payload)
|
||||
if err != nil {
|
||||
return nil, ErrDecryptionFailed
|
||||
}
|
||||
|
||||
format := ""
|
||||
compression := ""
|
||||
if f, ok := t.Header["format"].(string); ok {
|
||||
format = f
|
||||
}
|
||||
if c, ok := t.Header["compression"].(string); ok {
|
||||
compression = c
|
||||
}
|
||||
|
||||
switch compression {
|
||||
case CompressionGzip:
|
||||
decompressed, err := gzipDecompress(decrypted)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gzip decompression failed: %w", err)
|
||||
}
|
||||
decrypted = decompressed
|
||||
case CompressionZstd:
|
||||
decompressed, err := zstdDecompress(decrypted)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("zstd decompression failed: %w", err)
|
||||
}
|
||||
decrypted = decompressed
|
||||
}
|
||||
|
||||
if format == FormatV2 {
|
||||
return parseV2Payload(decrypted)
|
||||
}
|
||||
|
||||
var msg Message
|
||||
if err := json.Unmarshal(decrypted, &msg); err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid message format", ErrInvalidPayload)
|
||||
}
|
||||
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
// V3ParamsKS contains keyserver-aware parameters for V3 streaming.
|
||||
type V3ParamsKS struct {
|
||||
License string
|
||||
Fingerprint string
|
||||
Cadence Cadence
|
||||
}
|
||||
|
||||
// EncryptV3KS encrypts using V3 streaming format. Stream keys are derived and
|
||||
// CEK wrapping is done via keyserver — stream key material never leaves.
|
||||
// The CEK itself is a random per-message key used for content encryption.
|
||||
func EncryptV3KS(ctx context.Context, msg *Message, params V3ParamsKS, manifest *Manifest, ks keyserver.KeyServer) ([]byte, error) {
|
||||
if params.License == "" {
|
||||
return nil, ErrLicenseRequired
|
||||
}
|
||||
if msg.Body == "" && len(msg.Attachments) == 0 {
|
||||
return nil, ErrEmptyMessage
|
||||
}
|
||||
|
||||
if msg.Timestamp == 0 {
|
||||
msg.Timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
cadence := params.Cadence
|
||||
if cadence == "" {
|
||||
cadence = CadenceDaily
|
||||
}
|
||||
|
||||
current, next := GetRollingPeriods(cadence, time.Now().UTC())
|
||||
|
||||
// Generate random CEK (this is a per-message ephemeral key)
|
||||
cek, err := GenerateCEK()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Derive stream keys via keyserver — stream key material never leaves
|
||||
currentKeyID, err := ks.DeriveStreamKey(ctx, params.License, current, params.Fingerprint)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to derive current stream key: %w", err)
|
||||
}
|
||||
defer ks.DeleteKey(ctx, currentKeyID)
|
||||
|
||||
nextKeyID, err := ks.DeriveStreamKey(ctx, params.License, next, params.Fingerprint)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to derive next stream key: %w", err)
|
||||
}
|
||||
defer ks.DeleteKey(ctx, nextKeyID)
|
||||
|
||||
// Wrap CEK with stream keys via keyserver
|
||||
wrappedCurrent, err := ks.WrapCEK(ctx, currentKeyID, cek)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to wrap CEK for current period: %w", err)
|
||||
}
|
||||
|
||||
wrappedNext, err := ks.WrapCEK(ctx, nextKeyID, cek)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to wrap CEK for next period: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt content with CEK
|
||||
payload, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
compressed, err := zstdCompress(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compression failed: %w", err)
|
||||
}
|
||||
|
||||
sigil, err := enchantrix.NewChaChaPolySigil(cek)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create sigil: %w", err)
|
||||
}
|
||||
|
||||
encrypted, err := sigil.In(compressed)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encryption failed: %w", err)
|
||||
}
|
||||
|
||||
// Build V3 header
|
||||
headerMap := map[string]interface{}{
|
||||
"version": Version,
|
||||
"algorithm": "chacha20poly1305",
|
||||
"format": FormatV3,
|
||||
"compression": CompressionZstd,
|
||||
"keyMethod": KeyMethodLTHNRolling,
|
||||
"cadence": string(cadence),
|
||||
"wrappedKeys": []WrappedKey{
|
||||
{Date: current, Wrapped: wrappedCurrent},
|
||||
{Date: next, Wrapped: wrappedNext},
|
||||
},
|
||||
}
|
||||
|
||||
if manifest != nil {
|
||||
if manifest.IssuedAt == 0 {
|
||||
manifest.IssuedAt = time.Now().Unix()
|
||||
}
|
||||
headerMap["manifest"] = manifest
|
||||
}
|
||||
|
||||
t := &trix.Trix{
|
||||
Header: headerMap,
|
||||
Payload: encrypted,
|
||||
}
|
||||
|
||||
return trix.Encode(t, Magic, nil)
|
||||
}
|
||||
|
||||
// DecryptV3KS decrypts a V3 streaming message. Stream keys are derived via
|
||||
// keyserver for CEK unwrapping — stream key material never leaves.
|
||||
func DecryptV3KS(ctx context.Context, data []byte, params V3ParamsKS, ks keyserver.KeyServer) (*Message, *Header, error) {
|
||||
if params.License == "" {
|
||||
return nil, nil, ErrLicenseRequired
|
||||
}
|
||||
|
||||
t, err := trix.Decode(data, Magic, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to decode container: %w", err)
|
||||
}
|
||||
|
||||
headerJSON, err := json.Marshal(t.Header)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to marshal header: %w", err)
|
||||
}
|
||||
|
||||
var header Header
|
||||
if err := json.Unmarshal(headerJSON, &header); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to parse header: %w", err)
|
||||
}
|
||||
|
||||
if header.Format != FormatV3 {
|
||||
return nil, nil, fmt.Errorf("expected v3 format, got: %s", header.Format)
|
||||
}
|
||||
|
||||
cadence := header.Cadence
|
||||
if cadence == "" && params.Cadence != "" {
|
||||
cadence = params.Cadence
|
||||
}
|
||||
if cadence == "" {
|
||||
cadence = CadenceDaily
|
||||
}
|
||||
|
||||
// Unwrap CEK using keyserver-derived stream keys
|
||||
cek, err := tryUnwrapCEKKS(ctx, header.WrappedKeys, params, cadence, ks)
|
||||
if err != nil {
|
||||
return nil, &header, err
|
||||
}
|
||||
|
||||
// Decrypt payload with CEK
|
||||
sigil, err := enchantrix.NewChaChaPolySigil(cek)
|
||||
if err != nil {
|
||||
return nil, &header, fmt.Errorf("failed to create sigil: %w", err)
|
||||
}
|
||||
|
||||
compressed, err := sigil.Out(t.Payload)
|
||||
if err != nil {
|
||||
return nil, &header, ErrDecryptionFailed
|
||||
}
|
||||
|
||||
var decompressed []byte
|
||||
if header.Compression == CompressionZstd {
|
||||
decompressed, err = zstdDecompress(compressed)
|
||||
if err != nil {
|
||||
return nil, &header, fmt.Errorf("decompression failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
decompressed = compressed
|
||||
}
|
||||
|
||||
var msg Message
|
||||
if err := json.Unmarshal(decompressed, &msg); err != nil {
|
||||
return nil, &header, fmt.Errorf("failed to parse message: %w", err)
|
||||
}
|
||||
|
||||
return &msg, &header, nil
|
||||
}
|
||||
|
||||
// tryUnwrapCEKKS tries to unwrap CEK using keyserver-derived stream keys.
|
||||
func tryUnwrapCEKKS(ctx context.Context, wrappedKeys []WrappedKey, params V3ParamsKS, cadence Cadence, ks keyserver.KeyServer) ([]byte, error) {
|
||||
current, next := GetRollingPeriods(cadence, time.Now().UTC())
|
||||
|
||||
keysByPeriod := make(map[string]string)
|
||||
for _, wk := range wrappedKeys {
|
||||
keysByPeriod[wk.Date] = wk.Wrapped
|
||||
}
|
||||
|
||||
// Try current period
|
||||
if wrapped, ok := keysByPeriod[current]; ok {
|
||||
streamKeyID, err := ks.DeriveStreamKey(ctx, params.License, current, params.Fingerprint)
|
||||
if err == nil {
|
||||
cek, unwrapErr := ks.UnwrapCEK(ctx, streamKeyID, wrapped)
|
||||
ks.DeleteKey(ctx, streamKeyID)
|
||||
if unwrapErr == nil {
|
||||
return cek, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try next period
|
||||
if wrapped, ok := keysByPeriod[next]; ok {
|
||||
streamKeyID, err := ks.DeriveStreamKey(ctx, params.License, next, params.Fingerprint)
|
||||
if err == nil {
|
||||
cek, unwrapErr := ks.UnwrapCEK(ctx, streamKeyID, wrapped)
|
||||
ks.DeleteKey(ctx, streamKeyID)
|
||||
if unwrapErr == nil {
|
||||
return cek, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrNoValidKey
|
||||
}
|
||||
148
pkg/smsg/keyserver_test.go
Normal file
148
pkg/smsg/keyserver_test.go
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
package smsg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/Snider/Enchantrix/pkg/keyserver"
|
||||
"github.com/Snider/Enchantrix/pkg/keystore"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestSMSGKeyServer(t *testing.T) *keyserver.Server {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
store, err := keystore.Create(filepath.Join(dir, "keys.trix"), "test-master")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { store.Close() })
|
||||
return keyserver.NewServer(store)
|
||||
}
|
||||
|
||||
func TestEncryptDecryptKSRoundTrip(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestSMSGKeyServer(t)
|
||||
|
||||
keyID, err := ks.ImportPassword(ctx, "smsg-password", "smsg-key")
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := NewMessage("Hello, keyserver SMSG!")
|
||||
msg.WithSubject("Test").WithFrom("alice")
|
||||
|
||||
encrypted, err := EncryptKS(ctx, msg, keyID, ks)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, encrypted)
|
||||
|
||||
decrypted, err := DecryptKS(ctx, encrypted, keyID, ks)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Hello, keyserver SMSG!", decrypted.Body)
|
||||
assert.Equal(t, "Test", decrypted.Subject)
|
||||
assert.Equal(t, "alice", decrypted.From)
|
||||
}
|
||||
|
||||
func TestKSCompatWithOldAPI(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestSMSGKeyServer(t)
|
||||
|
||||
password := "compat-password"
|
||||
keyID, err := ks.ImportPassword(ctx, password, "compat")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Encrypt with old API
|
||||
msg := NewMessage("compatibility test")
|
||||
encrypted, err := Encrypt(msg, password)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decrypt with keyserver
|
||||
decrypted, err := DecryptKS(ctx, encrypted, keyID, ks)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "compatibility test", decrypted.Body)
|
||||
}
|
||||
|
||||
func TestKSEncryptOldDecrypt(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestSMSGKeyServer(t)
|
||||
|
||||
password := "compat-password-2"
|
||||
keyID, err := ks.ImportPassword(ctx, password, "compat2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Encrypt with keyserver
|
||||
msg := NewMessage("keyserver encrypted")
|
||||
encrypted, err := EncryptKS(ctx, msg, keyID, ks)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decrypt with old API
|
||||
decrypted, err := Decrypt(encrypted, password)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "keyserver encrypted", decrypted.Body)
|
||||
}
|
||||
|
||||
func TestEncryptKSEmptyMessage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestSMSGKeyServer(t)
|
||||
|
||||
keyID, _ := ks.ImportPassword(ctx, "pass", "key")
|
||||
|
||||
msg := &Message{}
|
||||
_, err := EncryptKS(ctx, msg, keyID, ks)
|
||||
assert.ErrorIs(t, err, ErrEmptyMessage)
|
||||
}
|
||||
|
||||
func TestV3KSRoundTrip(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestSMSGKeyServer(t)
|
||||
|
||||
params := V3ParamsKS{
|
||||
License: "test-license-123",
|
||||
Fingerprint: "device-fp-456",
|
||||
Cadence: CadenceDaily,
|
||||
}
|
||||
|
||||
msg := NewMessage("V3 keyserver streaming test")
|
||||
|
||||
encrypted, err := EncryptV3KS(ctx, msg, params, nil, ks)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, encrypted)
|
||||
|
||||
decrypted, header, err := DecryptV3KS(ctx, encrypted, params, ks)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "V3 keyserver streaming test", decrypted.Body)
|
||||
assert.Equal(t, FormatV3, header.Format)
|
||||
assert.Equal(t, KeyMethodLTHNRolling, header.KeyMethod)
|
||||
}
|
||||
|
||||
func TestV3KSWithManifest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestSMSGKeyServer(t)
|
||||
|
||||
params := V3ParamsKS{
|
||||
License: "license-xyz",
|
||||
Fingerprint: "fp-abc",
|
||||
}
|
||||
|
||||
manifest := NewManifest("Test Track")
|
||||
manifest.Artist = "Test Artist"
|
||||
|
||||
msg := NewMessage("manifest test")
|
||||
|
||||
encrypted, err := EncryptV3KS(ctx, msg, params, manifest, ks)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, header, err := DecryptV3KS(ctx, encrypted, params, ks)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, header.Manifest)
|
||||
assert.Equal(t, "Test Track", header.Manifest.Title)
|
||||
}
|
||||
|
||||
func TestV3KSNoLicense(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestSMSGKeyServer(t)
|
||||
|
||||
params := V3ParamsKS{Fingerprint: "fp"}
|
||||
msg := NewMessage("test")
|
||||
|
||||
_, err := EncryptV3KS(ctx, msg, params, nil, ks)
|
||||
assert.ErrorIs(t, err, ErrLicenseRequired)
|
||||
}
|
||||
43
pkg/stmf/keyserver.go
Normal file
43
pkg/stmf/keyserver.go
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
package stmf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/Snider/Enchantrix/pkg/keyserver"
|
||||
"github.com/Snider/Enchantrix/pkg/keystore"
|
||||
)
|
||||
|
||||
// DecryptKS decrypts a STMF payload using the keyserver. The server's X25519
|
||||
// private key never leaves the keyserver — ECDH and decryption happen internally.
|
||||
func DecryptKS(ctx context.Context, stmfData []byte, keyID string, ks keyserver.KeyServer) (*FormData, error) {
|
||||
// Keyserver performs ECDH + decrypt internally — private key never leaves
|
||||
plaintext, err := ks.DecryptSTMF(ctx, keyID, stmfData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrDecryptionFailed, err)
|
||||
}
|
||||
|
||||
var formData FormData
|
||||
if err := json.Unmarshal(plaintext, &formData); err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid JSON payload: %v", ErrInvalidPayload, err)
|
||||
}
|
||||
|
||||
return &formData, nil
|
||||
}
|
||||
|
||||
// GenerateKeyPairKS generates an X25519 keypair in the keyserver and returns
|
||||
// the public key for distribution. The private key stays in the keyserver.
|
||||
func GenerateKeyPairKS(ctx context.Context, label string, ks keyserver.KeyServer) (publicKey []byte, keyID string, err error) {
|
||||
keyID, err = ks.GenerateKey(ctx, keystore.X25519, label)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to generate keypair: %w", err)
|
||||
}
|
||||
|
||||
publicKey, err = ks.GetPublicKey(ctx, keyID)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
return publicKey, keyID, nil
|
||||
}
|
||||
121
pkg/stmf/keyserver_test.go
Normal file
121
pkg/stmf/keyserver_test.go
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
package stmf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/Snider/Enchantrix/pkg/keyserver"
|
||||
"github.com/Snider/Enchantrix/pkg/keystore"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestSTMFKeyServer(t *testing.T) *keyserver.Server {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
store, err := keystore.Create(filepath.Join(dir, "keys.trix"), "test-master")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { store.Close() })
|
||||
return keyserver.NewServer(store)
|
||||
}
|
||||
|
||||
func TestGenerateKeyPairKS(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestSTMFKeyServer(t)
|
||||
|
||||
pubKey, keyID, err := GenerateKeyPairKS(ctx, "stmf-server", ks)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pubKey, 32)
|
||||
assert.NotEmpty(t, keyID)
|
||||
|
||||
// Public key should be retrievable
|
||||
pubKey2, err := ks.GetPublicKey(ctx, keyID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pubKey, pubKey2)
|
||||
}
|
||||
|
||||
func TestSTMFDecryptKS(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestSTMFKeyServer(t)
|
||||
|
||||
// Generate server keypair via keyserver
|
||||
pubKey, keyID, err := GenerateKeyPairKS(ctx, "stmf-server", ks)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Client encrypts with public key (no keyserver needed)
|
||||
formData := NewFormData().
|
||||
AddField("username", "alice").
|
||||
AddFieldWithType("password", "secret123", "password")
|
||||
|
||||
encrypted, err := Encrypt(formData, pubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Server decrypts via keyserver — private key never leaves
|
||||
decrypted, err := DecryptKS(ctx, encrypted, keyID, ks)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "alice", decrypted.Get("username"))
|
||||
assert.Equal(t, "secret123", decrypted.Get("password"))
|
||||
}
|
||||
|
||||
func TestSTMFKeyRotation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestSTMFKeyServer(t)
|
||||
|
||||
// Generate first keypair
|
||||
pubKey1, keyID1, err := GenerateKeyPairKS(ctx, "server-v1", ks)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate second keypair (rotation)
|
||||
pubKey2, keyID2, err := GenerateKeyPairKS(ctx, "server-v2", ks)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, pubKey1, pubKey2)
|
||||
|
||||
// Encrypt with old key
|
||||
form1 := NewFormData().AddField("version", "1")
|
||||
enc1, err := Encrypt(form1, pubKey1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Encrypt with new key
|
||||
form2 := NewFormData().AddField("version", "2")
|
||||
enc2, err := Encrypt(form2, pubKey2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Old key still decrypts old forms
|
||||
dec1, err := DecryptKS(ctx, enc1, keyID1, ks)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "1", dec1.Get("version"))
|
||||
|
||||
// New key decrypts new forms
|
||||
dec2, err := DecryptKS(ctx, enc2, keyID2, ks)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "2", dec2.Get("version"))
|
||||
|
||||
// New key cannot decrypt old forms (different ECDH shared secret)
|
||||
_, err = DecryptKS(ctx, enc1, keyID2, ks)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSTMFPrivateKeyNeverLeaves(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestSTMFKeyServer(t)
|
||||
|
||||
_, keyID, err := GenerateKeyPairKS(ctx, "secure-key", ks)
|
||||
require.NoError(t, err)
|
||||
|
||||
// ListKeys should not expose key data
|
||||
keys, err := ks.ListKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
for _, k := range keys {
|
||||
assert.Nil(t, k.KeyData, "ListKeys must not expose key data")
|
||||
}
|
||||
|
||||
// GetPublicKey only returns public component
|
||||
pubKey, err := ks.GetPublicKey(ctx, keyID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pubKey, 32) // public key only
|
||||
|
||||
// There's no GetPrivateKey method on the interface — by design
|
||||
}
|
||||
171
pkg/tim/keyserver.go
Normal file
171
pkg/tim/keyserver.go
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
package tim
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Enchantrix/pkg/keyserver"
|
||||
"github.com/Snider/Enchantrix/pkg/trix"
|
||||
)
|
||||
|
||||
// ToSigilKS encrypts the TIM using the keyserver. The key material never
|
||||
// leaves the keyserver — only the key ID is passed.
|
||||
// The output format matches ToSigil: a Trix container with "STIM" magic.
|
||||
func (m *TerminalIsolationMatrix) ToSigilKS(ctx context.Context, keyID string, ks keyserver.KeyServer) ([]byte, error) {
|
||||
if m.Config == nil {
|
||||
return nil, ErrConfigIsNil
|
||||
}
|
||||
|
||||
rootfsTar, err := m.RootFS.ToTar()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to serialize rootfs: %w", err)
|
||||
}
|
||||
|
||||
// Keyserver encrypts config+rootfs atomically — key never leaves
|
||||
payload, err := ks.EncryptTIM(ctx, keyID, m.Config, rootfsTar)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("keyserver encrypt TIM: %w", err)
|
||||
}
|
||||
|
||||
// Parse encrypted config size from payload for header metadata
|
||||
configSize := binary.BigEndian.Uint32(payload[:4])
|
||||
rootfsSize := len(payload) - 4 - int(configSize)
|
||||
|
||||
// Wrap in Trix container with same header format as ToSigil
|
||||
t := &trix.Trix{
|
||||
Header: map[string]interface{}{
|
||||
"encryption_algorithm": "chacha20poly1305",
|
||||
"tim": true,
|
||||
"config_size": configSize,
|
||||
"rootfs_size": rootfsSize,
|
||||
"version": "1.0",
|
||||
},
|
||||
Payload: payload,
|
||||
}
|
||||
|
||||
return trix.Encode(t, "STIM", nil)
|
||||
}
|
||||
|
||||
// FromSigilKS decrypts and deserializes a .stim file using the keyserver.
|
||||
// The key material never leaves the keyserver.
|
||||
func FromSigilKS(data []byte, keyID string, ks keyserver.KeyServer) (*TerminalIsolationMatrix, error) {
|
||||
return FromSigilKSCtx(context.Background(), data, keyID, ks)
|
||||
}
|
||||
|
||||
// FromSigilKSCtx decrypts and deserializes a .stim file using the keyserver
|
||||
// with a context.
|
||||
func FromSigilKSCtx(ctx context.Context, data []byte, keyID string, ks keyserver.KeyServer) (*TerminalIsolationMatrix, error) {
|
||||
t, err := trix.Decode(data, "STIM", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode stim: %w", err)
|
||||
}
|
||||
|
||||
// Keyserver decrypts config+rootfs atomically — key never leaves
|
||||
config, rootfsTar, err := ks.DecryptTIM(ctx, keyID, t.Payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("keyserver decrypt TIM: %w", err)
|
||||
}
|
||||
|
||||
// Reconstruct DataNode from decrypted rootfs tar
|
||||
rootfs, err := datanode.FromTar(rootfsTar)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse rootfs: %w", err)
|
||||
}
|
||||
|
||||
return &TerminalIsolationMatrix{
|
||||
Config: config,
|
||||
RootFS: rootfs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CacheKS provides encrypted TIM storage using a keyserver.
|
||||
// The key material never leaves the keyserver — only the key ID is referenced.
|
||||
type CacheKS struct {
|
||||
Dir string
|
||||
KeyID string
|
||||
KS keyserver.KeyServer
|
||||
}
|
||||
|
||||
// NewCacheKS creates a keyserver-backed TIM cache.
|
||||
func NewCacheKS(dir string, keyID string, ks keyserver.KeyServer) (*CacheKS, error) {
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CacheKS{Dir: dir, KeyID: keyID, KS: ks}, nil
|
||||
}
|
||||
|
||||
// Store encrypts and saves a TIM to the cache via keyserver.
|
||||
func (c *CacheKS) Store(ctx context.Context, name string, m *TerminalIsolationMatrix) error {
|
||||
data, err := m.ToSigilKS(ctx, c.KeyID, c.KS)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
path := filepath.Join(c.Dir, name+".stim")
|
||||
return os.WriteFile(path, data, 0600)
|
||||
}
|
||||
|
||||
// Load retrieves and decrypts a TIM from the cache via keyserver.
|
||||
func (c *CacheKS) Load(ctx context.Context, name string) (*TerminalIsolationMatrix, error) {
|
||||
path := filepath.Join(c.Dir, name+".stim")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return FromSigilKSCtx(ctx, data, c.KeyID, c.KS)
|
||||
}
|
||||
|
||||
// RunEncryptedKS runs an encrypted .stim file, decrypting via keyserver.
|
||||
// The key material never leaves the keyserver.
|
||||
func RunEncryptedKS(ctx context.Context, stimPath, keyID string, ks keyserver.KeyServer) error {
|
||||
data, err := os.ReadFile(stimPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read stim file: %w", err)
|
||||
}
|
||||
|
||||
m, err := FromSigilKSCtx(ctx, data, keyID, ks)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt stim: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp("", "borg-run-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temporary directory: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "config.json"), m.Config, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write config: %w", err)
|
||||
}
|
||||
|
||||
rootfsPath := filepath.Join(tempDir, "rootfs")
|
||||
if err := os.MkdirAll(rootfsPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create rootfs dir: %w", err)
|
||||
}
|
||||
|
||||
err = m.RootFS.Walk(".", func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
target := filepath.Join(rootfsPath, path)
|
||||
if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
return m.RootFS.CopyFile(path, target, 0600)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract rootfs: %w", err)
|
||||
}
|
||||
|
||||
cmd := ExecCommand("runc", "run", "-b", tempDir, "borg-container")
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
150
pkg/tim/keyserver_test.go
Normal file
150
pkg/tim/keyserver_test.go
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
package tim
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Enchantrix/pkg/keyserver"
|
||||
"github.com/Snider/Enchantrix/pkg/keystore"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestKeyServer(t *testing.T) *keyserver.Server {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
store, err := keystore.Create(filepath.Join(dir, "keys.trix"), "test-master")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { store.Close() })
|
||||
return keyserver.NewServer(store)
|
||||
}
|
||||
|
||||
func TestToFromSigilKS(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestKeyServer(t)
|
||||
|
||||
keyID, err := ks.ImportPassword(ctx, "test-password", "tim-key")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Build a TIM
|
||||
m, err := New()
|
||||
require.NoError(t, err)
|
||||
m.RootFS.AddData("hello.txt", []byte("Hello from the keyserver!"))
|
||||
|
||||
// Encrypt via keyserver
|
||||
stim, err := m.ToSigilKS(ctx, keyID, ks)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, stim)
|
||||
|
||||
// Decrypt via keyserver
|
||||
restored, err := FromSigilKS(stim, keyID, ks)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify config round-trips
|
||||
assert.Equal(t, m.Config, restored.Config)
|
||||
|
||||
// Verify rootfs round-trips
|
||||
f, err := restored.RootFS.Open("hello.txt")
|
||||
require.NoError(t, err)
|
||||
defer f.Close()
|
||||
info, _ := f.Stat()
|
||||
buf := make([]byte, info.Size())
|
||||
f.Read(buf)
|
||||
assert.Equal(t, "Hello from the keyserver!", string(buf))
|
||||
}
|
||||
|
||||
func TestToSigilKSMatchesOldFormat(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestKeyServer(t)
|
||||
|
||||
password := "compat-test"
|
||||
keyID, err := ks.ImportPassword(ctx, password, "compat")
|
||||
require.NoError(t, err)
|
||||
|
||||
m, err := New()
|
||||
require.NoError(t, err)
|
||||
m.RootFS.AddData("file.txt", []byte("content"))
|
||||
|
||||
// Encrypt with keyserver
|
||||
stimKS, err := m.ToSigilKS(ctx, keyID, ks)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decrypt with old password API (should work — same key derivation)
|
||||
restored, err := FromSigil(stimKS, password)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, m.Config, restored.Config)
|
||||
}
|
||||
|
||||
func TestOldFormatDecryptsWithKS(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestKeyServer(t)
|
||||
|
||||
password := "compat-test-2"
|
||||
keyID, err := ks.ImportPassword(ctx, password, "compat2")
|
||||
require.NoError(t, err)
|
||||
|
||||
m, err := New()
|
||||
require.NoError(t, err)
|
||||
m.RootFS.AddData("data.bin", []byte{0xDE, 0xAD, 0xBE, 0xEF})
|
||||
|
||||
// Encrypt with old API
|
||||
stim, err := m.ToSigil(password)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decrypt with keyserver
|
||||
restored, err := FromSigilKS(stim, keyID, ks)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, m.Config, restored.Config)
|
||||
}
|
||||
|
||||
func TestCacheKS(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestKeyServer(t)
|
||||
dir := t.TempDir()
|
||||
|
||||
keyID, err := ks.ImportPassword(ctx, "cache-pass", "cache-key")
|
||||
require.NoError(t, err)
|
||||
|
||||
cache, err := NewCacheKS(dir, keyID, ks)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create TIM
|
||||
m, err := New()
|
||||
require.NoError(t, err)
|
||||
m.RootFS.AddData("cached.txt", []byte("cached content"))
|
||||
|
||||
// Store
|
||||
err = cache.Store(ctx, "test-tim", m)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load
|
||||
loaded, err := cache.Load(ctx, "test-tim")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, m.Config, loaded.Config)
|
||||
|
||||
f, err := loaded.RootFS.Open("cached.txt")
|
||||
require.NoError(t, err)
|
||||
defer f.Close()
|
||||
info, _ := f.Stat()
|
||||
buf := make([]byte, info.Size())
|
||||
f.Read(buf)
|
||||
assert.Equal(t, "cached content", string(buf))
|
||||
}
|
||||
|
||||
func TestToSigilKSNilConfig(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ks := newTestKeyServer(t)
|
||||
|
||||
keyID, _ := ks.ImportPassword(ctx, "pass", "key")
|
||||
|
||||
m := &TerminalIsolationMatrix{
|
||||
Config: nil,
|
||||
RootFS: datanode.New(),
|
||||
}
|
||||
|
||||
_, err := m.ToSigilKS(ctx, keyID, ks)
|
||||
assert.ErrorIs(t, err, ErrConfigIsNil)
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue