152 lines
4.3 KiB
Go
152 lines
4.3 KiB
Go
|
|
package stmf
|
||
|
|
|
||
|
|
import (
|
||
|
|
"crypto/ecdh"
|
||
|
|
"crypto/sha256"
|
||
|
|
"encoding/base64"
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
|
||
|
|
"github.com/Snider/Enchantrix/pkg/enchantrix"
|
||
|
|
"github.com/Snider/Enchantrix/pkg/trix"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Decrypt decrypts a STMF payload using the server's private key.
|
||
|
|
// It extracts the ephemeral public key from the header, performs ECDH,
|
||
|
|
// and decrypts with ChaCha20-Poly1305.
|
||
|
|
func Decrypt(stmfData []byte, serverPrivateKey []byte) (*FormData, error) {
|
||
|
|
// Load server's private key
|
||
|
|
serverPriv, err := LoadPrivateKey(serverPrivateKey)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
return DecryptWithKey(stmfData, serverPriv)
|
||
|
|
}
|
||
|
|
|
||
|
|
// DecryptBase64 decrypts a base64-encoded STMF payload
|
||
|
|
func DecryptBase64(encoded string, serverPrivateKey []byte) (*FormData, error) {
|
||
|
|
data, err := base64.StdEncoding.DecodeString(encoded)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("%w: invalid base64: %v", ErrInvalidPayload, err)
|
||
|
|
}
|
||
|
|
return Decrypt(data, serverPrivateKey)
|
||
|
|
}
|
||
|
|
|
||
|
|
// DecryptWithKey decrypts a STMF payload using a pre-loaded private key
|
||
|
|
func DecryptWithKey(stmfData []byte, serverPrivateKey *ecdh.PrivateKey) (*FormData, error) {
|
||
|
|
// Decode the trix container
|
||
|
|
t, err := trix.Decode(stmfData, Magic, nil)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("%w: %v", ErrInvalidMagic, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Extract ephemeral public key from header
|
||
|
|
ephemeralPKBase64, ok := t.Header["ephemeral_pk"].(string)
|
||
|
|
if !ok {
|
||
|
|
return nil, fmt.Errorf("%w: missing ephemeral_pk in header", ErrInvalidPayload)
|
||
|
|
}
|
||
|
|
|
||
|
|
ephemeralPKBytes, err := base64.StdEncoding.DecodeString(ephemeralPKBase64)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("%w: invalid ephemeral_pk base64: %v", ErrInvalidPayload, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Load ephemeral public key
|
||
|
|
ephemeralPub, err := LoadPublicKey(ephemeralPKBytes)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("%w: invalid ephemeral public key: %v", ErrInvalidPayload, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Perform ECDH key exchange (server private * ephemeral public = shared secret)
|
||
|
|
sharedSecret, err := serverPrivateKey.ECDH(ephemeralPub)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("ECDH failed: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Derive symmetric key using SHA-256 (same as encryption)
|
||
|
|
symmetricKey := sha256.Sum256(sharedSecret)
|
||
|
|
|
||
|
|
// Create ChaCha20-Poly1305 sigil
|
||
|
|
sigil, err := enchantrix.NewChaChaPolySigil(symmetricKey[:])
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to create sigil: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Decrypt the payload
|
||
|
|
decrypted, err := sigil.Out(t.Payload)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("%w: %v", ErrDecryptionFailed, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Unmarshal form data
|
||
|
|
var formData FormData
|
||
|
|
if err := json.Unmarshal(decrypted, &formData); err != nil {
|
||
|
|
return nil, fmt.Errorf("%w: invalid JSON payload: %v", ErrInvalidPayload, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return &formData, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// DecryptToMap is a convenience function that returns the form data as a simple map
|
||
|
|
func DecryptToMap(stmfData []byte, serverPrivateKey []byte) (map[string]string, error) {
|
||
|
|
formData, err := Decrypt(stmfData, serverPrivateKey)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return formData.ToMap(), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// DecryptBase64ToMap decrypts base64 and returns a map
|
||
|
|
func DecryptBase64ToMap(encoded string, serverPrivateKey []byte) (map[string]string, error) {
|
||
|
|
formData, err := DecryptBase64(encoded, serverPrivateKey)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return formData.ToMap(), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// ValidatePayload checks if the data is a valid STMF container without decrypting
|
||
|
|
func ValidatePayload(stmfData []byte) error {
|
||
|
|
t, err := trix.Decode(stmfData, Magic, nil)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("%w: %v", ErrInvalidMagic, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Check required header fields
|
||
|
|
if _, ok := t.Header["ephemeral_pk"].(string); !ok {
|
||
|
|
return fmt.Errorf("%w: missing ephemeral_pk", ErrInvalidPayload)
|
||
|
|
}
|
||
|
|
|
||
|
|
if _, ok := t.Header["algorithm"].(string); !ok {
|
||
|
|
return fmt.Errorf("%w: missing algorithm", ErrInvalidPayload)
|
||
|
|
}
|
||
|
|
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetPayloadInfo extracts metadata from a STMF payload without decrypting
|
||
|
|
func GetPayloadInfo(stmfData []byte) (*Header, error) {
|
||
|
|
t, err := trix.Decode(stmfData, Magic, nil)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("%w: %v", ErrInvalidMagic, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
header := &Header{}
|
||
|
|
|
||
|
|
if v, ok := t.Header["version"].(string); ok {
|
||
|
|
header.Version = v
|
||
|
|
}
|
||
|
|
if v, ok := t.Header["algorithm"].(string); ok {
|
||
|
|
header.Algorithm = v
|
||
|
|
}
|
||
|
|
if v, ok := t.Header["ephemeral_pk"].(string); ok {
|
||
|
|
header.EphemeralPK = v
|
||
|
|
}
|
||
|
|
if v, ok := t.Header["nonce"].(string); ok {
|
||
|
|
header.Nonce = v
|
||
|
|
}
|
||
|
|
|
||
|
|
return header, nil
|
||
|
|
}
|