Borg/pkg/smsg/keyserver.go
Claude cb64050704
feat: add keyserver-aware crypto APIs for TIM, SMSG, and STMF
Adds *KS variant functions that delegate crypto operations to the
Enchantrix keyserver — key material never leaves the keyserver boundary.

- tim: ToSigilKS, FromSigilKS, CacheKS, RunEncryptedKS
- smsg: EncryptKS, DecryptKS, EncryptV3KS, DecryptV3KS
- stmf: DecryptKS, GenerateKeyPairKS

All variants are backward-compatible with existing password-based APIs.
Adds testify dependency for integration tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-05 21:30:44 +00:00

309 lines
8.2 KiB
Go

package smsg
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/Snider/Enchantrix/pkg/enchantrix"
"github.com/Snider/Enchantrix/pkg/keyserver"
"github.com/Snider/Enchantrix/pkg/trix"
)
// EncryptKS encrypts msg into a V1 SMSG container, delegating the encryption
// itself to the keyserver. Only keyID is handed over; the key material never
// leaves the keyserver.
//
// A zero msg.Timestamp is set to the current Unix time (msg is modified in
// place). ErrEmptyMessage is returned when the message has neither a body nor
// attachments; marshalling and keyserver failures are returned wrapped.
func EncryptKS(ctx context.Context, msg *Message, keyID string, ks keyserver.KeyServer) ([]byte, error) {
	hasContent := msg.Body != "" || len(msg.Attachments) > 0
	if !hasContent {
		return nil, ErrEmptyMessage
	}
	if msg.Timestamp == 0 {
		msg.Timestamp = time.Now().Unix()
	}

	plaintext, err := json.Marshal(msg)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal message: %w", err)
	}

	// The keyserver performs the encryption under keyID.
	ciphertext, err := ks.EncryptSMSG(ctx, keyID, plaintext)
	if err != nil {
		return nil, fmt.Errorf("keyserver encrypt: %w", err)
	}

	// V1 containers carry no "format" or "compression" header entries.
	header := map[string]interface{}{
		"version":   Version,
		"algorithm": "chacha20poly1305",
	}
	container := &trix.Trix{Header: header, Payload: ciphertext}
	return trix.Encode(container, Magic, nil)
}
// DecryptKS decodes an SMSG container and decrypts its payload through the
// keyserver under keyID; the key material never leaves the keyserver.
// V1 (plain JSON message) and V2 payloads are supported. The payload is
// decompressed first when the header's "compression" entry names gzip or zstd.
//
// Errors: ErrInvalidMagic (wrapped) when the container cannot be decoded,
// ErrDecryptionFailed when the keyserver decryption fails (the keyserver's
// own error is not propagated), wrapped decompression errors, and
// ErrInvalidPayload (wrapped) when a V1 payload is not a valid Message.
func DecryptKS(ctx context.Context, data []byte, keyID string, ks keyserver.KeyServer) (*Message, error) {
	container, err := trix.Decode(data, Magic, nil)
	if err != nil {
		return nil, fmt.Errorf("%w: %v", ErrInvalidMagic, err)
	}

	plaintext, err := ks.DecryptSMSG(ctx, keyID, container.Payload)
	if err != nil {
		return nil, ErrDecryptionFailed
	}

	// Missing or non-string header entries read as "".
	format, _ := container.Header["format"].(string)
	compression, _ := container.Header["compression"].(string)

	switch compression {
	case CompressionGzip:
		if plaintext, err = gzipDecompress(plaintext); err != nil {
			return nil, fmt.Errorf("gzip decompression failed: %w", err)
		}
	case CompressionZstd:
		if plaintext, err = zstdDecompress(plaintext); err != nil {
			return nil, fmt.Errorf("zstd decompression failed: %w", err)
		}
	}

	if format == FormatV2 {
		return parseV2Payload(plaintext)
	}

	msg := new(Message)
	if err := json.Unmarshal(plaintext, msg); err != nil {
		return nil, fmt.Errorf("%w: invalid message format", ErrInvalidPayload)
	}
	return msg, nil
}
// V3ParamsKS holds the caller-supplied inputs for the keyserver-backed V3
// streaming functions (EncryptV3KS and DecryptV3KS).
type V3ParamsKS struct {
	// License and Fingerprint are passed to the keyserver's DeriveStreamKey,
	// together with each rolling period, to derive the stream keys. License
	// must be non-empty; the V3 functions return ErrLicenseRequired otherwise.
	License, Fingerprint string

	// Cadence selects the rolling-period granularity. EncryptV3KS falls back
	// to CadenceDaily when it is empty. DecryptV3KS prefers the cadence stored
	// in the container header and uses this value only when the header has
	// none.
	Cadence Cadence
}
// EncryptV3KS encrypts msg into a V3 (rolling-key streaming) SMSG container.
//
// A fresh random content-encryption key (CEK) encrypts the zstd-compressed
// JSON message. The CEK is then wrapped via the keyserver twice: once with the
// stream key for the current rolling period and once with the key for the
// next period. Stream keys are derived inside the keyserver from
// params.License, the period and params.Fingerprint. Only their IDs are
// handled here, and they are deleted from the keyserver before returning.
//
// A zero msg.Timestamp and, when manifest is non-nil, a zero
// manifest.IssuedAt are filled in. Both arguments are modified in place.
// params.Cadence defaults to CadenceDaily.
//
// Errors: ErrLicenseRequired, ErrEmptyMessage, or a wrapped error from CEK
// generation, stream-key derivation, CEK wrapping, marshalling, compression
// or encryption.
func EncryptV3KS(ctx context.Context, msg *Message, params V3ParamsKS, manifest *Manifest, ks keyserver.KeyServer) ([]byte, error) {
	if params.License == "" {
		return nil, ErrLicenseRequired
	}
	if msg.Body == "" && len(msg.Attachments) == 0 {
		return nil, ErrEmptyMessage
	}

	// Read the clock once so the message timestamp, the rolling key periods
	// and the manifest issue time all describe the same instant, even when
	// this call straddles a period boundary.
	now := time.Now()
	if msg.Timestamp == 0 {
		msg.Timestamp = now.Unix()
	}

	cadence := params.Cadence
	if cadence == "" {
		cadence = CadenceDaily
	}
	current, next := GetRollingPeriods(cadence, now.UTC())

	// Random per-message CEK. Unlike the stream keys, it does exist in this
	// process's memory.
	cek, err := GenerateCEK()
	if err != nil {
		return nil, fmt.Errorf("failed to generate CEK: %w", err)
	}
	// Best-effort scrub of the plaintext CEK once this function is done with it.
	defer func() {
		for i := range cek {
			cek[i] = 0
		}
	}()

	// Derive stream keys inside the keyserver; only their IDs come back.
	// NOTE(review): cleanup reuses ctx, so DeleteKey may fail if ctx is
	// already cancelled. Confirm whether the keyserver expires keys itself.
	currentKeyID, err := ks.DeriveStreamKey(ctx, params.License, current, params.Fingerprint)
	if err != nil {
		return nil, fmt.Errorf("failed to derive current stream key: %w", err)
	}
	defer ks.DeleteKey(ctx, currentKeyID)

	nextKeyID, err := ks.DeriveStreamKey(ctx, params.License, next, params.Fingerprint)
	if err != nil {
		return nil, fmt.Errorf("failed to derive next stream key: %w", err)
	}
	defer ks.DeleteKey(ctx, nextKeyID)

	// Wrap the CEK for both the current and the next period; the decrypt
	// side tries the wrapped key of either period.
	wrappedCurrent, err := ks.WrapCEK(ctx, currentKeyID, cek)
	if err != nil {
		return nil, fmt.Errorf("failed to wrap CEK for current period: %w", err)
	}
	wrappedNext, err := ks.WrapCEK(ctx, nextKeyID, cek)
	if err != nil {
		return nil, fmt.Errorf("failed to wrap CEK for next period: %w", err)
	}

	// Encrypt the compressed message with the CEK.
	payload, err := json.Marshal(msg)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal message: %w", err)
	}
	compressed, err := zstdCompress(payload)
	if err != nil {
		return nil, fmt.Errorf("compression failed: %w", err)
	}
	sigil, err := enchantrix.NewChaChaPolySigil(cek)
	if err != nil {
		return nil, fmt.Errorf("failed to create sigil: %w", err)
	}
	encrypted, err := sigil.In(compressed)
	if err != nil {
		return nil, fmt.Errorf("encryption failed: %w", err)
	}

	headerMap := map[string]interface{}{
		"version":     Version,
		"algorithm":   "chacha20poly1305",
		"format":      FormatV3,
		"compression": CompressionZstd,
		"keyMethod":   KeyMethodLTHNRolling,
		"cadence":     string(cadence),
		"wrappedKeys": []WrappedKey{
			{Date: current, Wrapped: wrappedCurrent},
			{Date: next, Wrapped: wrappedNext},
		},
	}
	if manifest != nil {
		if manifest.IssuedAt == 0 {
			manifest.IssuedAt = now.Unix()
		}
		headerMap["manifest"] = manifest
	}

	t := &trix.Trix{
		Header:  headerMap,
		Payload: encrypted,
	}
	return trix.Encode(t, Magic, nil)
}
// DecryptV3KS decodes a V3 (rolling-key streaming) SMSG container and
// decrypts it. The content-encryption key (CEK) is unwrapped with stream keys
// that the keyserver derives from params.License and params.Fingerprint for
// the current and next rolling periods. Stream key material never leaves the
// keyserver.
//
// The cadence recorded in the container header takes precedence; otherwise
// params.Cadence is used, falling back to CadenceDaily. The decrypted payload
// is decompressed according to the header's compression field: zstd and gzip
// are supported, and any other value is treated as uncompressed (consistent
// with DecryptKS).
//
// Once the header has been parsed, it is returned alongside any error so that
// callers can still inspect it when decryption fails.
//
// Errors:
//   - ErrLicenseRequired when params.License is empty.
//   - A wrapped error if the container or header cannot be decoded, or the
//     format is not V3.
//   - The CEK-unwrapping error (ErrNoValidKey) when no wrapped key can be
//     opened.
//   - ErrDecryptionFailed when the payload cannot be decrypted with the CEK.
//   - Wrapped decompression or message-parsing errors.
func DecryptV3KS(ctx context.Context, data []byte, params V3ParamsKS, ks keyserver.KeyServer) (*Message, *Header, error) {
	if params.License == "" {
		return nil, nil, ErrLicenseRequired
	}

	t, err := trix.Decode(data, Magic, nil)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to decode container: %w", err)
	}

	// Round-trip the generic header map through JSON to obtain a typed Header.
	headerJSON, err := json.Marshal(t.Header)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to marshal header: %w", err)
	}
	var header Header
	if err := json.Unmarshal(headerJSON, &header); err != nil {
		return nil, nil, fmt.Errorf("failed to parse header: %w", err)
	}
	if header.Format != FormatV3 {
		return nil, nil, fmt.Errorf("expected v3 format, got: %s", header.Format)
	}

	cadence := header.Cadence
	if cadence == "" && params.Cadence != "" {
		cadence = params.Cadence
	}
	if cadence == "" {
		cadence = CadenceDaily
	}

	cek, err := tryUnwrapCEKKS(ctx, header.WrappedKeys, params, cadence, ks)
	if err != nil {
		return nil, &header, err
	}

	sigil, err := enchantrix.NewChaChaPolySigil(cek)
	if err != nil {
		return nil, &header, fmt.Errorf("failed to create sigil: %w", err)
	}
	compressed, err := sigil.Out(t.Payload)
	if err != nil {
		return nil, &header, ErrDecryptionFailed
	}

	var decompressed []byte
	switch header.Compression {
	case CompressionZstd:
		decompressed, err = zstdDecompress(compressed)
		if err != nil {
			return nil, &header, fmt.Errorf("decompression failed: %w", err)
		}
	case CompressionGzip:
		// Previously gzip payloads fell through undecompressed and failed
		// JSON parsing with a misleading error.
		decompressed, err = gzipDecompress(compressed)
		if err != nil {
			return nil, &header, fmt.Errorf("decompression failed: %w", err)
		}
	default:
		decompressed = compressed
	}

	var msg Message
	if err := json.Unmarshal(decompressed, &msg); err != nil {
		return nil, &header, fmt.Errorf("failed to parse message: %w", err)
	}
	return &msg, &header, nil
}
// tryUnwrapCEKKS recovers the CEK from a V3 header's wrapped keys.
//
// It computes the current and next rolling periods for cadence from the
// current UTC time and looks up the wrapped key for each, in that order. For
// each wrapped key it finds, it:
//   - derives that period's stream key inside the keyserver,
//   - attempts to unwrap the CEK,
//   - deletes the derived stream key again.
//
// The first successful unwrap is returned.
//
// When no entry can be unwrapped, ErrNoValidKey is returned. If deriving a
// stream key itself failed, that cause is attached to the returned error,
// which still matches ErrNoValidKey via errors.Is, instead of being silently
// dropped.
func tryUnwrapCEKKS(ctx context.Context, wrappedKeys []WrappedKey, params V3ParamsKS, cadence Cadence, ks keyserver.KeyServer) ([]byte, error) {
	current, next := GetRollingPeriods(cadence, time.Now().UTC())

	keysByPeriod := make(map[string]string, len(wrappedKeys))
	for _, wk := range wrappedKeys {
		keysByPeriod[wk.Date] = wk.Wrapped
	}

	var deriveErr error
	for _, period := range []string{current, next} {
		wrapped, ok := keysByPeriod[period]
		if !ok {
			continue
		}
		streamKeyID, err := ks.DeriveStreamKey(ctx, params.License, period, params.Fingerprint)
		if err != nil {
			// Keep the cause so it is not mistaken for a plain key mismatch.
			deriveErr = err
			continue
		}
		cek, err := ks.UnwrapCEK(ctx, streamKeyID, wrapped)
		ks.DeleteKey(ctx, streamKeyID)
		if err == nil {
			return cek, nil
		}
	}

	if deriveErr != nil {
		return nil, fmt.Errorf("%w: %v", ErrNoValidKey, deriveErr)
	}
	return nil, ErrNoValidKey
}