feat: extract crypto/security packages from core/go
ChaCha20-Poly1305, AES-256-GCM, Argon2 key derivation, OpenPGP challenge-response auth, and trust tier policy engine. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
commit
8498ecf890
30 changed files with 3746 additions and 0 deletions
455
auth/auth.go
Normal file
455
auth/auth.go
Normal file
|
|
@ -0,0 +1,455 @@
|
|||
// Package auth implements OpenPGP challenge-response authentication with
|
||||
// support for both online (HTTP) and air-gapped (file-based) transport.
|
||||
//
|
||||
// Ported from dAppServer's mod-auth/lethean.service.ts.
|
||||
//
|
||||
// Authentication Flow (Online):
|
||||
//
|
||||
// 1. Client sends public key to server
|
||||
// 2. Server generates a random nonce, encrypts it with client's public key
|
||||
// 3. Client decrypts the nonce and signs it with their private key
|
||||
// 4. Server verifies the signature, creates a session token
|
||||
//
|
||||
// Authentication Flow (Air-Gapped / Courier):
|
||||
//
|
||||
// Same crypto but challenge/response are exchanged via files on a Medium.
|
||||
//
|
||||
// Storage Layout (via Medium):
|
||||
//
|
||||
// users/
|
||||
// {userID}.pub PGP public key (armored)
|
||||
// {userID}.key PGP private key (armored, password-encrypted)
|
||||
// {userID}.rev Revocation certificate (placeholder)
|
||||
// {userID}.json User metadata (encrypted with user's public key)
|
||||
// {userID}.lthn LTHN password hash
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
coreerr "forge.lthn.ai/core/go/pkg/framework/core"
|
||||
|
||||
"forge.lthn.ai/core/go-crypt/crypt/lthn"
|
||||
"forge.lthn.ai/core/go-crypt/crypt/pgp"
|
||||
"forge.lthn.ai/core/go/pkg/io"
|
||||
)
|
||||
|
||||
// Default durations for challenge and session lifetimes.
|
||||
const (
|
||||
DefaultChallengeTTL = 5 * time.Minute
|
||||
DefaultSessionTTL = 24 * time.Hour
|
||||
nonceBytes = 32
|
||||
)
|
||||
|
||||
// protectedUsers lists usernames that cannot be deleted.
|
||||
// The "server" user holds the server keypair; deleting it would
|
||||
// permanently destroy all joining data and require a full rebuild.
|
||||
var protectedUsers = map[string]bool{
|
||||
"server": true,
|
||||
}
|
||||
|
||||
// User represents a registered user with PGP credentials.
|
||||
type User struct {
|
||||
PublicKey string `json:"public_key"`
|
||||
KeyID string `json:"key_id"`
|
||||
Fingerprint string `json:"fingerprint"`
|
||||
PasswordHash string `json:"password_hash"` // LTHN hash
|
||||
Created time.Time `json:"created"`
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
}
|
||||
|
||||
// Challenge is a PGP-encrypted nonce sent to a client during authentication.
|
||||
type Challenge struct {
|
||||
Nonce []byte `json:"nonce"`
|
||||
Encrypted string `json:"encrypted"` // PGP-encrypted nonce (armored)
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// Session represents an authenticated session.
|
||||
type Session struct {
|
||||
Token string `json:"token"`
|
||||
UserID string `json:"user_id"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// Option configures an Authenticator.
|
||||
type Option func(*Authenticator)
|
||||
|
||||
// WithChallengeTTL sets the lifetime of a challenge before it expires.
|
||||
func WithChallengeTTL(d time.Duration) Option {
|
||||
return func(a *Authenticator) {
|
||||
a.challengeTTL = d
|
||||
}
|
||||
}
|
||||
|
||||
// WithSessionTTL sets the lifetime of a session before it expires.
|
||||
func WithSessionTTL(d time.Duration) Option {
|
||||
return func(a *Authenticator) {
|
||||
a.sessionTTL = d
|
||||
}
|
||||
}
|
||||
|
||||
// Authenticator manages PGP-based challenge-response authentication.
|
||||
// All user data and keys are persisted through an io.Medium, which may
|
||||
// be backed by disk, memory (MockMedium), or any other storage backend.
|
||||
type Authenticator struct {
|
||||
medium io.Medium
|
||||
sessions map[string]*Session
|
||||
challenges map[string]*Challenge // userID -> pending challenge
|
||||
mu sync.RWMutex
|
||||
challengeTTL time.Duration
|
||||
sessionTTL time.Duration
|
||||
}
|
||||
|
||||
// New creates an Authenticator that persists user data via the given Medium.
|
||||
func New(m io.Medium, opts ...Option) *Authenticator {
|
||||
a := &Authenticator{
|
||||
medium: m,
|
||||
sessions: make(map[string]*Session),
|
||||
challenges: make(map[string]*Challenge),
|
||||
challengeTTL: DefaultChallengeTTL,
|
||||
sessionTTL: DefaultSessionTTL,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(a)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// userPath returns the storage path for a user artifact.
|
||||
func userPath(userID, ext string) string {
|
||||
return "users/" + userID + ext
|
||||
}
|
||||
|
||||
// Register creates a new user account. It hashes the username with LTHN to
|
||||
// produce a userID, generates a PGP keypair (protected by the given password),
|
||||
// and persists the public key, private key, revocation placeholder, password
|
||||
// hash, and encrypted metadata via the Medium.
|
||||
func (a *Authenticator) Register(username, password string) (*User, error) {
|
||||
const op = "auth.Register"
|
||||
|
||||
userID := lthn.Hash(username)
|
||||
|
||||
// Check if user already exists
|
||||
if a.medium.IsFile(userPath(userID, ".pub")) {
|
||||
return nil, coreerr.E(op, "user already exists", nil)
|
||||
}
|
||||
|
||||
// Ensure users directory exists
|
||||
if err := a.medium.EnsureDir("users"); err != nil {
|
||||
return nil, coreerr.E(op, "failed to create users directory", err)
|
||||
}
|
||||
|
||||
// Generate PGP keypair
|
||||
kp, err := pgp.CreateKeyPair(userID, userID+"@auth.local", password)
|
||||
if err != nil {
|
||||
return nil, coreerr.E(op, "failed to create PGP keypair", err)
|
||||
}
|
||||
|
||||
// Store public key
|
||||
if err := a.medium.Write(userPath(userID, ".pub"), kp.PublicKey); err != nil {
|
||||
return nil, coreerr.E(op, "failed to write public key", err)
|
||||
}
|
||||
|
||||
// Store private key (already encrypted by PGP if password is non-empty)
|
||||
if err := a.medium.Write(userPath(userID, ".key"), kp.PrivateKey); err != nil {
|
||||
return nil, coreerr.E(op, "failed to write private key", err)
|
||||
}
|
||||
|
||||
// Store revocation certificate placeholder
|
||||
if err := a.medium.Write(userPath(userID, ".rev"), "REVOCATION_PLACEHOLDER"); err != nil {
|
||||
return nil, coreerr.E(op, "failed to write revocation certificate", err)
|
||||
}
|
||||
|
||||
// Store LTHN password hash
|
||||
passwordHash := lthn.Hash(password)
|
||||
if err := a.medium.Write(userPath(userID, ".lthn"), passwordHash); err != nil {
|
||||
return nil, coreerr.E(op, "failed to write password hash", err)
|
||||
}
|
||||
|
||||
// Build user metadata
|
||||
now := time.Now()
|
||||
user := &User{
|
||||
PublicKey: kp.PublicKey,
|
||||
KeyID: userID,
|
||||
Fingerprint: lthn.Hash(kp.PublicKey),
|
||||
PasswordHash: passwordHash,
|
||||
Created: now,
|
||||
LastLogin: time.Time{},
|
||||
}
|
||||
|
||||
// Encrypt metadata with the user's public key and store
|
||||
metaJSON, err := json.Marshal(user)
|
||||
if err != nil {
|
||||
return nil, coreerr.E(op, "failed to marshal user metadata", err)
|
||||
}
|
||||
|
||||
encMeta, err := pgp.Encrypt(metaJSON, kp.PublicKey)
|
||||
if err != nil {
|
||||
return nil, coreerr.E(op, "failed to encrypt user metadata", err)
|
||||
}
|
||||
|
||||
if err := a.medium.Write(userPath(userID, ".json"), string(encMeta)); err != nil {
|
||||
return nil, coreerr.E(op, "failed to write user metadata", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// CreateChallenge generates a cryptographic challenge for the given user.
|
||||
// A random nonce is created and encrypted with the user's PGP public key.
|
||||
// The client must decrypt the nonce and sign it to prove key ownership.
|
||||
func (a *Authenticator) CreateChallenge(userID string) (*Challenge, error) {
|
||||
const op = "auth.CreateChallenge"
|
||||
|
||||
// Read user's public key
|
||||
pubKey, err := a.medium.Read(userPath(userID, ".pub"))
|
||||
if err != nil {
|
||||
return nil, coreerr.E(op, "user not found", err)
|
||||
}
|
||||
|
||||
// Generate random nonce
|
||||
nonce := make([]byte, nonceBytes)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, coreerr.E(op, "failed to generate nonce", err)
|
||||
}
|
||||
|
||||
// Encrypt nonce with user's public key
|
||||
encrypted, err := pgp.Encrypt(nonce, pubKey)
|
||||
if err != nil {
|
||||
return nil, coreerr.E(op, "failed to encrypt nonce", err)
|
||||
}
|
||||
|
||||
challenge := &Challenge{
|
||||
Nonce: nonce,
|
||||
Encrypted: string(encrypted),
|
||||
ExpiresAt: time.Now().Add(a.challengeTTL),
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
a.challenges[userID] = challenge
|
||||
a.mu.Unlock()
|
||||
|
||||
return challenge, nil
|
||||
}
|
||||
|
||||
// ValidateResponse verifies a signed nonce from the client. The client must
|
||||
// have decrypted the challenge nonce and signed it with their private key.
|
||||
// On success, a new session is created and returned.
|
||||
func (a *Authenticator) ValidateResponse(userID string, signedNonce []byte) (*Session, error) {
|
||||
const op = "auth.ValidateResponse"
|
||||
|
||||
a.mu.Lock()
|
||||
challenge, exists := a.challenges[userID]
|
||||
if exists {
|
||||
delete(a.challenges, userID)
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
if !exists {
|
||||
return nil, coreerr.E(op, "no pending challenge for user", nil)
|
||||
}
|
||||
|
||||
// Check challenge expiry
|
||||
if time.Now().After(challenge.ExpiresAt) {
|
||||
return nil, coreerr.E(op, "challenge expired", nil)
|
||||
}
|
||||
|
||||
// Read user's public key
|
||||
pubKey, err := a.medium.Read(userPath(userID, ".pub"))
|
||||
if err != nil {
|
||||
return nil, coreerr.E(op, "user not found", err)
|
||||
}
|
||||
|
||||
// Verify signature over the original nonce
|
||||
if err := pgp.Verify(challenge.Nonce, signedNonce, pubKey); err != nil {
|
||||
return nil, coreerr.E(op, "signature verification failed", err)
|
||||
}
|
||||
|
||||
return a.createSession(userID)
|
||||
}
|
||||
|
||||
// ValidateSession checks whether a token maps to a valid, non-expired session.
|
||||
func (a *Authenticator) ValidateSession(token string) (*Session, error) {
|
||||
const op = "auth.ValidateSession"
|
||||
|
||||
a.mu.RLock()
|
||||
session, exists := a.sessions[token]
|
||||
a.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, coreerr.E(op, "session not found", nil)
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
a.mu.Lock()
|
||||
delete(a.sessions, token)
|
||||
a.mu.Unlock()
|
||||
return nil, coreerr.E(op, "session expired", nil)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// RefreshSession extends the expiry of an existing valid session.
|
||||
func (a *Authenticator) RefreshSession(token string) (*Session, error) {
|
||||
const op = "auth.RefreshSession"
|
||||
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
session, exists := a.sessions[token]
|
||||
if !exists {
|
||||
return nil, coreerr.E(op, "session not found", nil)
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
delete(a.sessions, token)
|
||||
return nil, coreerr.E(op, "session expired", nil)
|
||||
}
|
||||
|
||||
session.ExpiresAt = time.Now().Add(a.sessionTTL)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// RevokeSession removes a session, invalidating the token immediately.
|
||||
func (a *Authenticator) RevokeSession(token string) error {
|
||||
const op = "auth.RevokeSession"
|
||||
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if _, exists := a.sessions[token]; !exists {
|
||||
return coreerr.E(op, "session not found", nil)
|
||||
}
|
||||
|
||||
delete(a.sessions, token)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser removes a user and all associated keys from storage.
|
||||
// The "server" user is protected and cannot be deleted (mirroring the
|
||||
// original TypeScript implementation's safeguard).
|
||||
func (a *Authenticator) DeleteUser(userID string) error {
|
||||
const op = "auth.DeleteUser"
|
||||
|
||||
// Protect special users
|
||||
if protectedUsers[userID] {
|
||||
return coreerr.E(op, "cannot delete protected user", nil)
|
||||
}
|
||||
|
||||
// Check user exists
|
||||
if !a.medium.IsFile(userPath(userID, ".pub")) {
|
||||
return coreerr.E(op, "user not found", nil)
|
||||
}
|
||||
|
||||
// Remove all artifacts
|
||||
extensions := []string{".pub", ".key", ".rev", ".json", ".lthn"}
|
||||
for _, ext := range extensions {
|
||||
p := userPath(userID, ext)
|
||||
if a.medium.IsFile(p) {
|
||||
if err := a.medium.Delete(p); err != nil {
|
||||
return coreerr.E(op, "failed to delete "+ext, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Revoke any active sessions for this user
|
||||
a.mu.Lock()
|
||||
for token, session := range a.sessions {
|
||||
if session.UserID == userID {
|
||||
delete(a.sessions, token)
|
||||
}
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Login performs password-based authentication as a convenience method.
|
||||
// It verifies the password against the stored LTHN hash and, on success,
|
||||
// creates a new session. This bypasses the PGP challenge-response flow.
|
||||
func (a *Authenticator) Login(userID, password string) (*Session, error) {
|
||||
const op = "auth.Login"
|
||||
|
||||
// Read stored password hash
|
||||
storedHash, err := a.medium.Read(userPath(userID, ".lthn"))
|
||||
if err != nil {
|
||||
return nil, coreerr.E(op, "user not found", err)
|
||||
}
|
||||
|
||||
// Verify password
|
||||
if !lthn.Verify(password, storedHash) {
|
||||
return nil, coreerr.E(op, "invalid password", nil)
|
||||
}
|
||||
|
||||
return a.createSession(userID)
|
||||
}
|
||||
|
||||
// WriteChallengeFile writes an encrypted challenge to a file for air-gapped
|
||||
// (courier) transport. The challenge is created and then its encrypted nonce
|
||||
// is written to the specified path on the Medium.
|
||||
func (a *Authenticator) WriteChallengeFile(userID, path string) error {
|
||||
const op = "auth.WriteChallengeFile"
|
||||
|
||||
challenge, err := a.CreateChallenge(userID)
|
||||
if err != nil {
|
||||
return coreerr.E(op, "failed to create challenge", err)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(challenge)
|
||||
if err != nil {
|
||||
return coreerr.E(op, "failed to marshal challenge", err)
|
||||
}
|
||||
|
||||
if err := a.medium.Write(path, string(data)); err != nil {
|
||||
return coreerr.E(op, "failed to write challenge file", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadResponseFile reads a signed response from a file and validates it,
|
||||
// completing the air-gapped authentication flow. The file must contain the
|
||||
// raw PGP signature bytes (armored).
|
||||
func (a *Authenticator) ReadResponseFile(userID, path string) (*Session, error) {
|
||||
const op = "auth.ReadResponseFile"
|
||||
|
||||
content, err := a.medium.Read(path)
|
||||
if err != nil {
|
||||
return nil, coreerr.E(op, "failed to read response file", err)
|
||||
}
|
||||
|
||||
session, err := a.ValidateResponse(userID, []byte(content))
|
||||
if err != nil {
|
||||
return nil, coreerr.E(op, "failed to validate response", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// createSession generates a cryptographically random session token and
|
||||
// stores the session in the in-memory session map.
|
||||
func (a *Authenticator) createSession(userID string) (*Session, error) {
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
return nil, fmt.Errorf("auth: failed to generate session token: %w", err)
|
||||
}
|
||||
|
||||
session := &Session{
|
||||
Token: hex.EncodeToString(tokenBytes),
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().Add(a.sessionTTL),
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
a.sessions[session.Token] = session
|
||||
a.mu.Unlock()
|
||||
|
||||
return session, nil
|
||||
}
|
||||
581
auth/auth_test.go
Normal file
581
auth/auth_test.go
Normal file
|
|
@ -0,0 +1,581 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"forge.lthn.ai/core/go-crypt/crypt/lthn"
|
||||
"forge.lthn.ai/core/go-crypt/crypt/pgp"
|
||||
"forge.lthn.ai/core/go/pkg/io"
|
||||
)
|
||||
|
||||
// helper creates a fresh Authenticator backed by MockMedium.
|
||||
func newTestAuth(opts ...Option) (*Authenticator, *io.MockMedium) {
|
||||
m := io.NewMockMedium()
|
||||
a := New(m, opts...)
|
||||
return a, m
|
||||
}
|
||||
|
||||
// --- Register ---
|
||||
|
||||
func TestRegister_Good(t *testing.T) {
|
||||
a, m := newTestAuth()
|
||||
|
||||
user, err := a.Register("alice", "hunter2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
|
||||
userID := lthn.Hash("alice")
|
||||
|
||||
// Verify public key is stored
|
||||
assert.True(t, m.IsFile(userPath(userID, ".pub")))
|
||||
assert.True(t, m.IsFile(userPath(userID, ".key")))
|
||||
assert.True(t, m.IsFile(userPath(userID, ".rev")))
|
||||
assert.True(t, m.IsFile(userPath(userID, ".json")))
|
||||
assert.True(t, m.IsFile(userPath(userID, ".lthn")))
|
||||
|
||||
// Verify user fields
|
||||
assert.NotEmpty(t, user.PublicKey)
|
||||
assert.Equal(t, userID, user.KeyID)
|
||||
assert.NotEmpty(t, user.Fingerprint)
|
||||
assert.Equal(t, lthn.Hash("hunter2"), user.PasswordHash)
|
||||
assert.False(t, user.Created.IsZero())
|
||||
}
|
||||
|
||||
func TestRegister_Bad(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
// Register first time succeeds
|
||||
_, err := a.Register("bob", "pass1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Duplicate registration should fail
|
||||
_, err = a.Register("bob", "pass2")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "user already exists")
|
||||
}
|
||||
|
||||
func TestRegister_Ugly(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
// Empty username/password should still work (PGP allows it)
|
||||
user, err := a.Register("", "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
}
|
||||
|
||||
// --- CreateChallenge ---
|
||||
|
||||
func TestCreateChallenge_Good(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
user, err := a.Register("charlie", "pass")
|
||||
require.NoError(t, err)
|
||||
|
||||
challenge, err := a.CreateChallenge(user.KeyID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, challenge)
|
||||
|
||||
assert.Len(t, challenge.Nonce, nonceBytes)
|
||||
assert.NotEmpty(t, challenge.Encrypted)
|
||||
assert.True(t, challenge.ExpiresAt.After(time.Now()))
|
||||
}
|
||||
|
||||
func TestCreateChallenge_Bad(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
// Challenge for non-existent user
|
||||
_, err := a.CreateChallenge("nonexistent-user-id")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "user not found")
|
||||
}
|
||||
|
||||
func TestCreateChallenge_Ugly(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
// Empty userID
|
||||
_, err := a.CreateChallenge("")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// --- ValidateResponse (full challenge-response flow) ---
|
||||
|
||||
func TestValidateResponse_Good(t *testing.T) {
|
||||
a, m := newTestAuth()
|
||||
|
||||
// Register user
|
||||
_, err := a.Register("dave", "password123")
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := lthn.Hash("dave")
|
||||
|
||||
// Create challenge
|
||||
challenge, err := a.CreateChallenge(userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Client-side: decrypt nonce, then sign it
|
||||
privKey, err := m.Read(userPath(userID, ".key"))
|
||||
require.NoError(t, err)
|
||||
|
||||
decryptedNonce, err := pgp.Decrypt([]byte(challenge.Encrypted), privKey, "password123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, challenge.Nonce, decryptedNonce)
|
||||
|
||||
signedNonce, err := pgp.Sign(decryptedNonce, privKey, "password123")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate response
|
||||
session, err := a.ValidateResponse(userID, signedNonce)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, session)
|
||||
|
||||
assert.NotEmpty(t, session.Token)
|
||||
assert.Equal(t, userID, session.UserID)
|
||||
assert.True(t, session.ExpiresAt.After(time.Now()))
|
||||
}
|
||||
|
||||
func TestValidateResponse_Bad(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
_, err := a.Register("eve", "pass")
|
||||
require.NoError(t, err)
|
||||
userID := lthn.Hash("eve")
|
||||
|
||||
// No pending challenge
|
||||
_, err = a.ValidateResponse(userID, []byte("fake-signature"))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no pending challenge")
|
||||
}
|
||||
|
||||
func TestValidateResponse_Ugly(t *testing.T) {
|
||||
a, m := newTestAuth(WithChallengeTTL(1 * time.Millisecond))
|
||||
|
||||
_, err := a.Register("frank", "pass")
|
||||
require.NoError(t, err)
|
||||
userID := lthn.Hash("frank")
|
||||
|
||||
// Create challenge and let it expire
|
||||
challenge, err := a.CreateChallenge(userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
// Sign with valid key but expired challenge
|
||||
privKey, err := m.Read(userPath(userID, ".key"))
|
||||
require.NoError(t, err)
|
||||
|
||||
signedNonce, err := pgp.Sign(challenge.Nonce, privKey, "pass")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = a.ValidateResponse(userID, signedNonce)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "challenge expired")
|
||||
}
|
||||
|
||||
// --- ValidateSession ---
|
||||
|
||||
func TestValidateSession_Good(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
_, err := a.Register("grace", "pass")
|
||||
require.NoError(t, err)
|
||||
userID := lthn.Hash("grace")
|
||||
|
||||
session, err := a.Login(userID, "pass")
|
||||
require.NoError(t, err)
|
||||
|
||||
validated, err := a.ValidateSession(session.Token)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, session.Token, validated.Token)
|
||||
assert.Equal(t, userID, validated.UserID)
|
||||
}
|
||||
|
||||
func TestValidateSession_Bad(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
_, err := a.ValidateSession("nonexistent-token")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "session not found")
|
||||
}
|
||||
|
||||
func TestValidateSession_Ugly(t *testing.T) {
|
||||
a, _ := newTestAuth(WithSessionTTL(1 * time.Millisecond))
|
||||
|
||||
_, err := a.Register("heidi", "pass")
|
||||
require.NoError(t, err)
|
||||
userID := lthn.Hash("heidi")
|
||||
|
||||
session, err := a.Login(userID, "pass")
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
_, err = a.ValidateSession(session.Token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "session expired")
|
||||
}
|
||||
|
||||
// --- RefreshSession ---
|
||||
|
||||
func TestRefreshSession_Good(t *testing.T) {
|
||||
a, _ := newTestAuth(WithSessionTTL(1 * time.Hour))
|
||||
|
||||
_, err := a.Register("ivan", "pass")
|
||||
require.NoError(t, err)
|
||||
userID := lthn.Hash("ivan")
|
||||
|
||||
session, err := a.Login(userID, "pass")
|
||||
require.NoError(t, err)
|
||||
|
||||
originalExpiry := session.ExpiresAt
|
||||
|
||||
// Small delay to ensure time moves forward
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
|
||||
refreshed, err := a.RefreshSession(session.Token)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, refreshed.ExpiresAt.After(originalExpiry))
|
||||
}
|
||||
|
||||
func TestRefreshSession_Bad(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
_, err := a.RefreshSession("nonexistent-token")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "session not found")
|
||||
}
|
||||
|
||||
func TestRefreshSession_Ugly(t *testing.T) {
|
||||
a, _ := newTestAuth(WithSessionTTL(1 * time.Millisecond))
|
||||
|
||||
_, err := a.Register("judy", "pass")
|
||||
require.NoError(t, err)
|
||||
userID := lthn.Hash("judy")
|
||||
|
||||
session, err := a.Login(userID, "pass")
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
_, err = a.RefreshSession(session.Token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "session expired")
|
||||
}
|
||||
|
||||
// --- RevokeSession ---
|
||||
|
||||
func TestRevokeSession_Good(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
_, err := a.Register("karl", "pass")
|
||||
require.NoError(t, err)
|
||||
userID := lthn.Hash("karl")
|
||||
|
||||
session, err := a.Login(userID, "pass")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = a.RevokeSession(session.Token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token should no longer be valid
|
||||
_, err = a.ValidateSession(session.Token)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRevokeSession_Bad(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
err := a.RevokeSession("nonexistent-token")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "session not found")
|
||||
}
|
||||
|
||||
func TestRevokeSession_Ugly(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
// Revoke empty token
|
||||
err := a.RevokeSession("")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// --- DeleteUser ---
|
||||
|
||||
func TestDeleteUser_Good(t *testing.T) {
|
||||
a, m := newTestAuth()
|
||||
|
||||
_, err := a.Register("larry", "pass")
|
||||
require.NoError(t, err)
|
||||
userID := lthn.Hash("larry")
|
||||
|
||||
// Also create a session that should be cleaned up
|
||||
_, err = a.Login(userID, "pass")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = a.DeleteUser(userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// All files should be gone
|
||||
assert.False(t, m.IsFile(userPath(userID, ".pub")))
|
||||
assert.False(t, m.IsFile(userPath(userID, ".key")))
|
||||
assert.False(t, m.IsFile(userPath(userID, ".rev")))
|
||||
assert.False(t, m.IsFile(userPath(userID, ".json")))
|
||||
assert.False(t, m.IsFile(userPath(userID, ".lthn")))
|
||||
|
||||
// Session should be gone
|
||||
a.mu.RLock()
|
||||
sessionCount := 0
|
||||
for _, s := range a.sessions {
|
||||
if s.UserID == userID {
|
||||
sessionCount++
|
||||
}
|
||||
}
|
||||
a.mu.RUnlock()
|
||||
assert.Equal(t, 0, sessionCount)
|
||||
}
|
||||
|
||||
func TestDeleteUser_Bad(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
// Protected user "server" cannot be deleted
|
||||
err := a.DeleteUser("server")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cannot delete protected user")
|
||||
}
|
||||
|
||||
func TestDeleteUser_Ugly(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
// Non-existent user
|
||||
err := a.DeleteUser("nonexistent-user-id")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "user not found")
|
||||
}
|
||||
|
||||
// --- Login ---
|
||||
|
||||
func TestLogin_Good(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
_, err := a.Register("mallory", "secret")
|
||||
require.NoError(t, err)
|
||||
userID := lthn.Hash("mallory")
|
||||
|
||||
session, err := a.Login(userID, "secret")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, session)
|
||||
|
||||
assert.NotEmpty(t, session.Token)
|
||||
assert.Equal(t, userID, session.UserID)
|
||||
assert.True(t, session.ExpiresAt.After(time.Now()))
|
||||
}
|
||||
|
||||
func TestLogin_Bad(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
_, err := a.Register("nancy", "correct-password")
|
||||
require.NoError(t, err)
|
||||
userID := lthn.Hash("nancy")
|
||||
|
||||
// Wrong password
|
||||
_, err = a.Login(userID, "wrong-password")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid password")
|
||||
}
|
||||
|
||||
func TestLogin_Ugly(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
// Login for non-existent user
|
||||
_, err := a.Login("nonexistent-user-id", "pass")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "user not found")
|
||||
}
|
||||
|
||||
// --- WriteChallengeFile / ReadResponseFile (Air-Gapped) ---
|
||||
|
||||
func TestAirGappedFlow_Good(t *testing.T) {
|
||||
a, m := newTestAuth()
|
||||
|
||||
_, err := a.Register("oscar", "airgap-pass")
|
||||
require.NoError(t, err)
|
||||
userID := lthn.Hash("oscar")
|
||||
|
||||
// Write challenge to file
|
||||
challengePath := "transfer/challenge.json"
|
||||
err = a.WriteChallengeFile(userID, challengePath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, m.IsFile(challengePath))
|
||||
|
||||
// Read challenge file to get the encrypted nonce (simulating courier)
|
||||
challengeData, err := m.Read(challengePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
var challenge Challenge
|
||||
err = json.Unmarshal([]byte(challengeData), &challenge)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Client-side: decrypt nonce and sign it
|
||||
privKey, err := m.Read(userPath(userID, ".key"))
|
||||
require.NoError(t, err)
|
||||
|
||||
decryptedNonce, err := pgp.Decrypt([]byte(challenge.Encrypted), privKey, "airgap-pass")
|
||||
require.NoError(t, err)
|
||||
|
||||
signedNonce, err := pgp.Sign(decryptedNonce, privKey, "airgap-pass")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write signed response to file
|
||||
responsePath := "transfer/response.sig"
|
||||
err = m.Write(responsePath, string(signedNonce))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Server reads response file
|
||||
session, err := a.ReadResponseFile(userID, responsePath)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, session)
|
||||
|
||||
assert.NotEmpty(t, session.Token)
|
||||
assert.Equal(t, userID, session.UserID)
|
||||
}
|
||||
|
||||
func TestWriteChallengeFile_Bad(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
// Challenge for non-existent user
|
||||
err := a.WriteChallengeFile("nonexistent-user", "challenge.json")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestReadResponseFile_Bad(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
// Response file does not exist
|
||||
_, err := a.ReadResponseFile("some-user", "nonexistent-file.sig")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestReadResponseFile_Ugly(t *testing.T) {
|
||||
a, m := newTestAuth()
|
||||
|
||||
_, err := a.Register("peggy", "pass")
|
||||
require.NoError(t, err)
|
||||
userID := lthn.Hash("peggy")
|
||||
|
||||
// Create a challenge
|
||||
_, err = a.CreateChallenge(userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write garbage to response file
|
||||
responsePath := "transfer/bad-response.sig"
|
||||
err = m.Write(responsePath, "not-a-valid-signature")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = a.ReadResponseFile(userID, responsePath)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// --- Options ---
|
||||
|
||||
func TestWithChallengeTTL_Good(t *testing.T) {
|
||||
ttl := 30 * time.Second
|
||||
a, _ := newTestAuth(WithChallengeTTL(ttl))
|
||||
assert.Equal(t, ttl, a.challengeTTL)
|
||||
}
|
||||
|
||||
func TestWithSessionTTL_Good(t *testing.T) {
|
||||
ttl := 2 * time.Hour
|
||||
a, _ := newTestAuth(WithSessionTTL(ttl))
|
||||
assert.Equal(t, ttl, a.sessionTTL)
|
||||
}
|
||||
|
||||
// --- Full Round-Trip (Online Flow) ---
|
||||
|
||||
func TestFullRoundTrip_Good(t *testing.T) {
|
||||
a, m := newTestAuth()
|
||||
|
||||
// 1. Register
|
||||
user, err := a.Register("quinn", "roundtrip-pass")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
|
||||
userID := lthn.Hash("quinn")
|
||||
|
||||
// 2. Create challenge
|
||||
challenge, err := a.CreateChallenge(userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 3. Client decrypts + signs
|
||||
privKey, err := m.Read(userPath(userID, ".key"))
|
||||
require.NoError(t, err)
|
||||
|
||||
nonce, err := pgp.Decrypt([]byte(challenge.Encrypted), privKey, "roundtrip-pass")
|
||||
require.NoError(t, err)
|
||||
|
||||
sig, err := pgp.Sign(nonce, privKey, "roundtrip-pass")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 4. Server validates, issues session
|
||||
session, err := a.ValidateResponse(userID, sig)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, session)
|
||||
|
||||
// 5. Validate session
|
||||
validated, err := a.ValidateSession(session.Token)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, session.Token, validated.Token)
|
||||
|
||||
// 6. Refresh session
|
||||
refreshed, err := a.RefreshSession(session.Token)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, session.Token, refreshed.Token)
|
||||
|
||||
// 7. Revoke session
|
||||
err = a.RevokeSession(session.Token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 8. Session should be invalid now
|
||||
_, err = a.ValidateSession(session.Token)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// --- Concurrent Access ---
|
||||
|
||||
func TestConcurrentSessions_Good(t *testing.T) {
|
||||
a, _ := newTestAuth()
|
||||
|
||||
_, err := a.Register("ruth", "pass")
|
||||
require.NoError(t, err)
|
||||
userID := lthn.Hash("ruth")
|
||||
|
||||
// Create multiple sessions concurrently
|
||||
const n = 10
|
||||
sessions := make(chan *Session, n)
|
||||
errs := make(chan error, n)
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
s, err := a.Login(userID, "pass")
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
sessions <- s
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
select {
|
||||
case s := <-sessions:
|
||||
require.NotNil(t, s)
|
||||
// Validate each session
|
||||
_, err := a.ValidateSession(s.Token)
|
||||
assert.NoError(t, err)
|
||||
case err := <-errs:
|
||||
t.Fatalf("concurrent login failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
50
crypt/chachapoly/chachapoly.go
Normal file
50
crypt/chachapoly/chachapoly.go
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
package chachapoly
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// Encrypt encrypts data using ChaCha20-Poly1305.
|
||||
func Encrypt(plaintext []byte, key []byte) ([]byte, error) {
|
||||
aead, err := chacha20poly1305.NewX(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonce := make([]byte, aead.NonceSize(), aead.NonceSize()+len(plaintext)+aead.Overhead())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return aead.Seal(nonce, nonce, plaintext, nil), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts data using ChaCha20-Poly1305.
|
||||
func Decrypt(ciphertext []byte, key []byte) ([]byte, error) {
|
||||
aead, err := chacha20poly1305.NewX(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
minLen := aead.NonceSize() + aead.Overhead()
|
||||
if len(ciphertext) < minLen {
|
||||
return nil, fmt.Errorf("ciphertext too short: got %d bytes, need at least %d bytes", len(ciphertext), minLen)
|
||||
}
|
||||
|
||||
nonce, ciphertext := ciphertext[:aead.NonceSize()], ciphertext[aead.NonceSize():]
|
||||
|
||||
decrypted, err := aead.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(decrypted) == 0 {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
return decrypted, nil
|
||||
}
|
||||
114
crypt/chachapoly/chachapoly_test.go
Normal file
114
crypt/chachapoly/chachapoly_test.go
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
package chachapoly
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// mockReader is a reader that returns an error.
|
||||
type mockReader struct{}
|
||||
|
||||
func (r *mockReader) Read(p []byte) (n int, err error) {
|
||||
return 0, errors.New("read error")
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = 1
|
||||
}
|
||||
|
||||
plaintext := []byte("Hello, world!")
|
||||
ciphertext, err := Encrypt(plaintext, key)
|
||||
assert.NoError(t, err)
|
||||
|
||||
decrypted, err := Decrypt(ciphertext, key)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
}
|
||||
|
||||
func TestEncryptInvalidKeySize(t *testing.T) {
|
||||
key := make([]byte, 16) // Wrong size
|
||||
plaintext := []byte("test")
|
||||
_, err := Encrypt(plaintext, key)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecryptWithWrongKey(t *testing.T) {
|
||||
key1 := make([]byte, 32)
|
||||
key2 := make([]byte, 32)
|
||||
key2[0] = 1 // Different key
|
||||
|
||||
plaintext := []byte("secret")
|
||||
ciphertext, err := Encrypt(plaintext, key1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = Decrypt(ciphertext, key2)
|
||||
assert.Error(t, err) // Should fail authentication
|
||||
}
|
||||
|
||||
func TestDecryptTamperedCiphertext(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
plaintext := []byte("secret")
|
||||
ciphertext, err := Encrypt(plaintext, key)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Tamper with the ciphertext
|
||||
ciphertext[0] ^= 0xff
|
||||
|
||||
_, err = Decrypt(ciphertext, key)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestEncryptEmptyPlaintext(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
plaintext := []byte("")
|
||||
ciphertext, err := Encrypt(plaintext, key)
|
||||
assert.NoError(t, err)
|
||||
|
||||
decrypted, err := Decrypt(ciphertext, key)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
}
|
||||
|
||||
func TestDecryptShortCiphertext(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
shortCiphertext := []byte("short")
|
||||
|
||||
_, err := Decrypt(shortCiphertext, key)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "too short")
|
||||
}
|
||||
|
||||
func TestCiphertextDiffersFromPlaintext(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
plaintext := []byte("Hello, world!")
|
||||
ciphertext, err := Encrypt(plaintext, key)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, plaintext, ciphertext)
|
||||
}
|
||||
|
||||
func TestEncryptNonceError(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
plaintext := []byte("test")
|
||||
|
||||
// Replace the rand.Reader with our mock reader
|
||||
oldReader := rand.Reader
|
||||
rand.Reader = &mockReader{}
|
||||
defer func() { rand.Reader = oldReader }()
|
||||
|
||||
_, err := Encrypt(plaintext, key)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecryptInvalidKeySize(t *testing.T) {
|
||||
key := make([]byte, 16) // Wrong size
|
||||
ciphertext := []byte("test")
|
||||
_, err := Decrypt(ciphertext, key)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
55
crypt/checksum.go
Normal file
55
crypt/checksum.go
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
package crypt
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
core "forge.lthn.ai/core/go/pkg/framework/core"
|
||||
)
|
||||
|
||||
// SHA256File computes the SHA-256 checksum of a file and returns it as a hex string.
|
||||
func SHA256File(path string) (string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", core.E("crypt.SHA256File", "failed to open file", err)
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
return "", core.E("crypt.SHA256File", "failed to read file", err)
|
||||
}
|
||||
|
||||
return hex.EncodeToString(h.Sum(nil)), nil
|
||||
}
|
||||
|
||||
// SHA512File computes the SHA-512 checksum of a file and returns it as a hex string.
|
||||
func SHA512File(path string) (string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", core.E("crypt.SHA512File", "failed to open file", err)
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
h := sha512.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
return "", core.E("crypt.SHA512File", "failed to read file", err)
|
||||
}
|
||||
|
||||
return hex.EncodeToString(h.Sum(nil)), nil
|
||||
}
|
||||
|
||||
// SHA256Sum computes the SHA-256 checksum of data and returns it as a hex string.
|
||||
func SHA256Sum(data []byte) string {
|
||||
h := sha256.Sum256(data)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// SHA512Sum computes the SHA-512 checksum of data and returns it as a hex string.
|
||||
func SHA512Sum(data []byte) string {
|
||||
h := sha512.Sum512(data)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
23
crypt/checksum_test.go
Normal file
23
crypt/checksum_test.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
package crypt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSHA256Sum_Good(t *testing.T) {
|
||||
data := []byte("hello")
|
||||
expected := "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
|
||||
|
||||
result := SHA256Sum(data)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestSHA512Sum_Good(t *testing.T) {
|
||||
data := []byte("hello")
|
||||
expected := "9b71d224bd62f3785d96d46ad3ea3d73319bfbc2890caadae2dff72519673ca72323c3d99ba5c11d7c7acc6e14b8c5da0c4663475c2e5c3adef46f73bcdec043"
|
||||
|
||||
result := SHA512Sum(data)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
90
crypt/crypt.go
Normal file
90
crypt/crypt.go
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
package crypt
|
||||
|
||||
import (
|
||||
core "forge.lthn.ai/core/go/pkg/framework/core"
|
||||
)
|
||||
|
||||
// Encrypt encrypts data with a passphrase using ChaCha20-Poly1305.
|
||||
// A random salt is generated and prepended to the output.
|
||||
// Format: salt (16 bytes) + nonce (24 bytes) + ciphertext.
|
||||
func Encrypt(plaintext, passphrase []byte) ([]byte, error) {
|
||||
salt, err := generateSalt(argon2SaltLen)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.Encrypt", "failed to generate salt", err)
|
||||
}
|
||||
|
||||
key := DeriveKey(passphrase, salt, argon2KeyLen)
|
||||
|
||||
encrypted, err := ChaCha20Encrypt(plaintext, key)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.Encrypt", "failed to encrypt", err)
|
||||
}
|
||||
|
||||
// Prepend salt to the encrypted data (which already has nonce prepended)
|
||||
result := make([]byte, 0, len(salt)+len(encrypted))
|
||||
result = append(result, salt...)
|
||||
result = append(result, encrypted...)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts data encrypted with Encrypt.
|
||||
// Expects format: salt (16 bytes) + nonce (24 bytes) + ciphertext.
|
||||
func Decrypt(ciphertext, passphrase []byte) ([]byte, error) {
|
||||
if len(ciphertext) < argon2SaltLen {
|
||||
return nil, core.E("crypt.Decrypt", "ciphertext too short", nil)
|
||||
}
|
||||
|
||||
salt := ciphertext[:argon2SaltLen]
|
||||
encrypted := ciphertext[argon2SaltLen:]
|
||||
|
||||
key := DeriveKey(passphrase, salt, argon2KeyLen)
|
||||
|
||||
plaintext, err := ChaCha20Decrypt(encrypted, key)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.Decrypt", "failed to decrypt", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// EncryptAES encrypts data using AES-256-GCM with a passphrase.
|
||||
// A random salt is generated and prepended to the output.
|
||||
// Format: salt (16 bytes) + nonce (12 bytes) + ciphertext.
|
||||
func EncryptAES(plaintext, passphrase []byte) ([]byte, error) {
|
||||
salt, err := generateSalt(argon2SaltLen)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.EncryptAES", "failed to generate salt", err)
|
||||
}
|
||||
|
||||
key := DeriveKey(passphrase, salt, argon2KeyLen)
|
||||
|
||||
encrypted, err := AESGCMEncrypt(plaintext, key)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.EncryptAES", "failed to encrypt", err)
|
||||
}
|
||||
|
||||
result := make([]byte, 0, len(salt)+len(encrypted))
|
||||
result = append(result, salt...)
|
||||
result = append(result, encrypted...)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecryptAES decrypts data encrypted with EncryptAES.
|
||||
// Expects format: salt (16 bytes) + nonce (12 bytes) + ciphertext.
|
||||
func DecryptAES(ciphertext, passphrase []byte) ([]byte, error) {
|
||||
if len(ciphertext) < argon2SaltLen {
|
||||
return nil, core.E("crypt.DecryptAES", "ciphertext too short", nil)
|
||||
}
|
||||
|
||||
salt := ciphertext[:argon2SaltLen]
|
||||
encrypted := ciphertext[argon2SaltLen:]
|
||||
|
||||
key := DeriveKey(passphrase, salt, argon2KeyLen)
|
||||
|
||||
plaintext, err := AESGCMDecrypt(encrypted, key)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.DecryptAES", "failed to decrypt", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
45
crypt/crypt_test.go
Normal file
45
crypt/crypt_test.go
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
package crypt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEncryptDecrypt_Good(t *testing.T) {
|
||||
plaintext := []byte("hello, world!")
|
||||
passphrase := []byte("correct-horse-battery-staple")
|
||||
|
||||
encrypted, err := Encrypt(plaintext, passphrase)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, plaintext, encrypted)
|
||||
|
||||
decrypted, err := Decrypt(encrypted, passphrase)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt_Bad(t *testing.T) {
|
||||
plaintext := []byte("secret data")
|
||||
passphrase := []byte("correct-passphrase")
|
||||
wrongPassphrase := []byte("wrong-passphrase")
|
||||
|
||||
encrypted, err := Encrypt(plaintext, passphrase)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = Decrypt(encrypted, wrongPassphrase)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestEncryptDecryptAES_Good(t *testing.T) {
|
||||
plaintext := []byte("hello, AES world!")
|
||||
passphrase := []byte("my-secure-passphrase")
|
||||
|
||||
encrypted, err := EncryptAES(plaintext, passphrase)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, plaintext, encrypted)
|
||||
|
||||
decrypted, err := DecryptAES(encrypted, passphrase)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
}
|
||||
89
crypt/hash.go
Normal file
89
crypt/hash.go
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
package crypt
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
core "forge.lthn.ai/core/go/pkg/framework/core"
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// HashPassword hashes a password using Argon2id with default parameters.
|
||||
// Returns a string in the format: $argon2id$v=19$m=65536,t=3,p=4$<base64salt>$<base64hash>
|
||||
func HashPassword(password string) (string, error) {
|
||||
salt, err := generateSalt(argon2SaltLen)
|
||||
if err != nil {
|
||||
return "", core.E("crypt.HashPassword", "failed to generate salt", err)
|
||||
}
|
||||
|
||||
hash := argon2.IDKey([]byte(password), salt, argon2Time, argon2Memory, argon2Parallelism, argon2KeyLen)
|
||||
|
||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
|
||||
encoded := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, argon2Memory, argon2Time, argon2Parallelism,
|
||||
b64Salt, b64Hash)
|
||||
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// VerifyPassword verifies a password against an Argon2id hash string.
|
||||
// The hash must be in the format produced by HashPassword.
|
||||
func VerifyPassword(password, hash string) (bool, error) {
|
||||
parts := strings.Split(hash, "$")
|
||||
if len(parts) != 6 {
|
||||
return false, core.E("crypt.VerifyPassword", "invalid hash format", nil)
|
||||
}
|
||||
|
||||
var version int
|
||||
if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
|
||||
return false, core.E("crypt.VerifyPassword", "failed to parse version", err)
|
||||
}
|
||||
|
||||
var memory uint32
|
||||
var time uint32
|
||||
var parallelism uint8
|
||||
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, ¶llelism); err != nil {
|
||||
return false, core.E("crypt.VerifyPassword", "failed to parse parameters", err)
|
||||
}
|
||||
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return false, core.E("crypt.VerifyPassword", "failed to decode salt", err)
|
||||
}
|
||||
|
||||
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return false, core.E("crypt.VerifyPassword", "failed to decode hash", err)
|
||||
}
|
||||
|
||||
computedHash := argon2.IDKey([]byte(password), salt, time, memory, parallelism, uint32(len(expectedHash)))
|
||||
|
||||
return subtle.ConstantTimeCompare(computedHash, expectedHash) == 1, nil
|
||||
}
|
||||
|
||||
// HashBcrypt hashes a password using bcrypt with the given cost.
|
||||
// Cost must be between bcrypt.MinCost and bcrypt.MaxCost.
|
||||
func HashBcrypt(password string, cost int) (string, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), cost)
|
||||
if err != nil {
|
||||
return "", core.E("crypt.HashBcrypt", "failed to hash password", err)
|
||||
}
|
||||
return string(hash), nil
|
||||
}
|
||||
|
||||
// VerifyBcrypt verifies a password against a bcrypt hash.
|
||||
func VerifyBcrypt(password, hash string) (bool, error) {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
if err == bcrypt.ErrMismatchedHashAndPassword {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, core.E("crypt.VerifyBcrypt", "failed to verify password", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
50
crypt/hash_test.go
Normal file
50
crypt/hash_test.go
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
package crypt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func TestHashPassword_Good(t *testing.T) {
|
||||
password := "my-secure-password"
|
||||
|
||||
hash, err := HashPassword(password)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, hash)
|
||||
assert.Contains(t, hash, "$argon2id$")
|
||||
|
||||
match, err := VerifyPassword(password, hash)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, match)
|
||||
}
|
||||
|
||||
func TestVerifyPassword_Bad(t *testing.T) {
|
||||
password := "my-secure-password"
|
||||
wrongPassword := "wrong-password"
|
||||
|
||||
hash, err := HashPassword(password)
|
||||
assert.NoError(t, err)
|
||||
|
||||
match, err := VerifyPassword(wrongPassword, hash)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, match)
|
||||
}
|
||||
|
||||
func TestHashBcrypt_Good(t *testing.T) {
|
||||
password := "bcrypt-test-password"
|
||||
|
||||
hash, err := HashBcrypt(password, bcrypt.DefaultCost)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, hash)
|
||||
|
||||
match, err := VerifyBcrypt(password, hash)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, match)
|
||||
|
||||
// Wrong password should not match
|
||||
match, err = VerifyBcrypt("wrong-password", hash)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, match)
|
||||
}
|
||||
30
crypt/hmac.go
Normal file
30
crypt/hmac.go
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
package crypt
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"hash"
|
||||
)
|
||||
|
||||
// HMACSHA256 computes the HMAC-SHA256 of a message using the given key.
|
||||
func HMACSHA256(message, key []byte) []byte {
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(message)
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
// HMACSHA512 computes the HMAC-SHA512 of a message using the given key.
|
||||
func HMACSHA512(message, key []byte) []byte {
|
||||
mac := hmac.New(sha512.New, key)
|
||||
mac.Write(message)
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
// VerifyHMAC verifies an HMAC using constant-time comparison.
|
||||
// hashFunc should be sha256.New, sha512.New, etc.
|
||||
func VerifyHMAC(message, key, mac []byte, hashFunc func() hash.Hash) bool {
|
||||
expected := hmac.New(hashFunc, key)
|
||||
expected.Write(message)
|
||||
return hmac.Equal(mac, expected.Sum(nil))
|
||||
}
|
||||
40
crypt/hmac_test.go
Normal file
40
crypt/hmac_test.go
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
package crypt
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHMACSHA256_Good(t *testing.T) {
|
||||
// RFC 4231 Test Case 2
|
||||
key := []byte("Jefe")
|
||||
message := []byte("what do ya want for nothing?")
|
||||
expected := "5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843"
|
||||
|
||||
mac := HMACSHA256(message, key)
|
||||
assert.Equal(t, expected, hex.EncodeToString(mac))
|
||||
}
|
||||
|
||||
func TestVerifyHMAC_Good(t *testing.T) {
|
||||
key := []byte("secret-key")
|
||||
message := []byte("test message")
|
||||
|
||||
mac := HMACSHA256(message, key)
|
||||
|
||||
valid := VerifyHMAC(message, key, mac, sha256.New)
|
||||
assert.True(t, valid)
|
||||
}
|
||||
|
||||
func TestVerifyHMAC_Bad(t *testing.T) {
|
||||
key := []byte("secret-key")
|
||||
message := []byte("test message")
|
||||
tampered := []byte("tampered message")
|
||||
|
||||
mac := HMACSHA256(message, key)
|
||||
|
||||
valid := VerifyHMAC(tampered, key, mac, sha256.New)
|
||||
assert.False(t, valid)
|
||||
}
|
||||
60
crypt/kdf.go
Normal file
60
crypt/kdf.go
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
// Package crypt provides cryptographic utilities including encryption,
|
||||
// hashing, key derivation, HMAC, and checksum functions.
|
||||
package crypt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"io"
|
||||
|
||||
core "forge.lthn.ai/core/go/pkg/framework/core"
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
)
|
||||
|
||||
// Argon2id default parameters.
|
||||
const (
|
||||
argon2Memory = 64 * 1024 // 64 MB
|
||||
argon2Time = 3
|
||||
argon2Parallelism = 4
|
||||
argon2KeyLen = 32
|
||||
argon2SaltLen = 16
|
||||
)
|
||||
|
||||
// DeriveKey derives a key from a passphrase using Argon2id with default parameters.
|
||||
// The salt must be argon2SaltLen bytes. keyLen specifies the desired key length.
|
||||
func DeriveKey(passphrase, salt []byte, keyLen uint32) []byte {
|
||||
return argon2.IDKey(passphrase, salt, argon2Time, argon2Memory, argon2Parallelism, keyLen)
|
||||
}
|
||||
|
||||
// DeriveKeyScrypt derives a key from a passphrase using scrypt.
|
||||
// Uses recommended parameters: N=32768, r=8, p=1.
|
||||
func DeriveKeyScrypt(passphrase, salt []byte, keyLen int) ([]byte, error) {
|
||||
key, err := scrypt.Key(passphrase, salt, 32768, 8, 1, keyLen)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.DeriveKeyScrypt", "failed to derive key", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// HKDF derives a key using HKDF-SHA256.
|
||||
// secret is the input keying material, salt is optional (can be nil),
|
||||
// info is optional context, and keyLen is the desired output length.
|
||||
func HKDF(secret, salt, info []byte, keyLen int) ([]byte, error) {
|
||||
reader := hkdf.New(sha256.New, secret, salt, info)
|
||||
key := make([]byte, keyLen)
|
||||
if _, err := io.ReadFull(reader, key); err != nil {
|
||||
return nil, core.E("crypt.HKDF", "failed to derive key", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// generateSalt creates a random salt of the given length.
|
||||
func generateSalt(length int) ([]byte, error) {
|
||||
salt := make([]byte, length)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return nil, core.E("crypt.generateSalt", "failed to generate random salt", err)
|
||||
}
|
||||
return salt, nil
|
||||
}
|
||||
56
crypt/kdf_test.go
Normal file
56
crypt/kdf_test.go
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
package crypt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDeriveKey_Good(t *testing.T) {
|
||||
passphrase := []byte("test-passphrase")
|
||||
salt := []byte("1234567890123456") // 16 bytes
|
||||
|
||||
key1 := DeriveKey(passphrase, salt, 32)
|
||||
key2 := DeriveKey(passphrase, salt, 32)
|
||||
|
||||
assert.Len(t, key1, 32)
|
||||
assert.Equal(t, key1, key2, "same inputs should produce same output")
|
||||
|
||||
// Different passphrase should produce different key
|
||||
key3 := DeriveKey([]byte("different-passphrase"), salt, 32)
|
||||
assert.NotEqual(t, key1, key3)
|
||||
}
|
||||
|
||||
func TestDeriveKeyScrypt_Good(t *testing.T) {
|
||||
passphrase := []byte("test-passphrase")
|
||||
salt := []byte("1234567890123456")
|
||||
|
||||
key, err := DeriveKeyScrypt(passphrase, salt, 32)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, key, 32)
|
||||
|
||||
// Deterministic
|
||||
key2, err := DeriveKeyScrypt(passphrase, salt, 32)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, key, key2)
|
||||
}
|
||||
|
||||
func TestHKDF_Good(t *testing.T) {
|
||||
secret := []byte("input-keying-material")
|
||||
salt := []byte("optional-salt")
|
||||
info := []byte("context-info")
|
||||
|
||||
key1, err := HKDF(secret, salt, info, 32)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, key1, 32)
|
||||
|
||||
// Deterministic
|
||||
key2, err := HKDF(secret, salt, info, 32)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, key1, key2)
|
||||
|
||||
// Different info should produce different key
|
||||
key3, err := HKDF(secret, salt, []byte("different-info"), 32)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, key1, key3)
|
||||
}
|
||||
94
crypt/lthn/lthn.go
Normal file
94
crypt/lthn/lthn.go
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
// Package lthn implements the LTHN quasi-salted hash algorithm (RFC-0004).
|
||||
//
|
||||
// LTHN produces deterministic, verifiable hashes without requiring separate salt
|
||||
// storage. The salt is derived from the input itself through:
|
||||
// 1. Reversing the input string
|
||||
// 2. Applying "leet speak" style character substitutions
|
||||
//
|
||||
// The final hash is: SHA256(input || derived_salt)
|
||||
//
|
||||
// This is suitable for content identifiers, cache keys, and deduplication.
|
||||
// NOT suitable for password hashing - use bcrypt, Argon2, or scrypt instead.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// hash := lthn.Hash("hello")
|
||||
// valid := lthn.Verify("hello", hash) // true
|
||||
package lthn
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// keyMap defines the character substitutions for quasi-salt derivation.
|
||||
// These are inspired by "leet speak" conventions for letter-number substitution.
|
||||
// The mapping is bidirectional for most characters but NOT fully symmetric.
|
||||
var keyMap = map[rune]rune{
|
||||
'o': '0', // letter O -> zero
|
||||
'l': '1', // letter L -> one
|
||||
'e': '3', // letter E -> three
|
||||
'a': '4', // letter A -> four
|
||||
's': 'z', // letter S -> Z
|
||||
't': '7', // letter T -> seven
|
||||
'0': 'o', // zero -> letter O
|
||||
'1': 'l', // one -> letter L
|
||||
'3': 'e', // three -> letter E
|
||||
'4': 'a', // four -> letter A
|
||||
'7': 't', // seven -> letter T
|
||||
}
|
||||
|
||||
// SetKeyMap replaces the default character substitution map.
|
||||
// Use this to customize the quasi-salt derivation for specific applications.
|
||||
// Changes affect all subsequent Hash and Verify calls.
|
||||
func SetKeyMap(newKeyMap map[rune]rune) {
|
||||
keyMap = newKeyMap
|
||||
}
|
||||
|
||||
// GetKeyMap returns the current character substitution map.
|
||||
func GetKeyMap() map[rune]rune {
|
||||
return keyMap
|
||||
}
|
||||
|
||||
// Hash computes the LTHN hash of the input string.
|
||||
//
|
||||
// The algorithm:
|
||||
// 1. Derive a quasi-salt by reversing the input and applying character substitutions
|
||||
// 2. Concatenate: input + salt
|
||||
// 3. Compute SHA-256 of the concatenated string
|
||||
// 4. Return the hex-encoded digest (64 characters, lowercase)
|
||||
//
|
||||
// The same input always produces the same hash, enabling verification
|
||||
// without storing a separate salt value.
|
||||
func Hash(input string) string {
|
||||
salt := createSalt(input)
|
||||
hash := sha256.Sum256([]byte(input + salt))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// createSalt derives a quasi-salt by reversing the input and applying substitutions.
|
||||
// For example: "hello" -> reversed "olleh" -> substituted "011eh"
|
||||
func createSalt(input string) string {
|
||||
if input == "" {
|
||||
return ""
|
||||
}
|
||||
runes := []rune(input)
|
||||
salt := make([]rune, len(runes))
|
||||
for i := 0; i < len(runes); i++ {
|
||||
char := runes[len(runes)-1-i]
|
||||
if replacement, ok := keyMap[char]; ok {
|
||||
salt[i] = replacement
|
||||
} else {
|
||||
salt[i] = char
|
||||
}
|
||||
}
|
||||
return string(salt)
|
||||
}
|
||||
|
||||
// Verify checks if an input string produces the given hash.
|
||||
// Returns true if Hash(input) equals the provided hash value.
|
||||
// Uses direct string comparison - for security-critical applications,
|
||||
// consider using constant-time comparison.
|
||||
func Verify(input string, hash string) bool {
|
||||
return Hash(input) == hash
|
||||
}
|
||||
66
crypt/lthn/lthn_test.go
Normal file
66
crypt/lthn/lthn_test.go
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
package lthn
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHash(t *testing.T) {
|
||||
hash := Hash("hello")
|
||||
assert.NotEmpty(t, hash)
|
||||
}
|
||||
|
||||
func TestVerify(t *testing.T) {
|
||||
hash := Hash("hello")
|
||||
assert.True(t, Verify("hello", hash))
|
||||
assert.False(t, Verify("world", hash))
|
||||
}
|
||||
|
||||
func TestCreateSalt_Good(t *testing.T) {
|
||||
// "hello" reversed: "olleh" -> "0113h"
|
||||
expected := "0113h"
|
||||
actual := createSalt("hello")
|
||||
assert.Equal(t, expected, actual, "Salt should be correctly created for 'hello'")
|
||||
}
|
||||
|
||||
func TestCreateSalt_Bad(t *testing.T) {
|
||||
// Test with an empty string
|
||||
expected := ""
|
||||
actual := createSalt("")
|
||||
assert.Equal(t, expected, actual, "Salt for an empty string should be empty")
|
||||
}
|
||||
|
||||
func TestCreateSalt_Ugly(t *testing.T) {
|
||||
// Test with characters not in the keyMap
|
||||
input := "world123"
|
||||
// "world123" reversed: "321dlrow" -> "e2ld1r0w"
|
||||
expected := "e2ld1r0w"
|
||||
actual := createSalt(input)
|
||||
assert.Equal(t, expected, actual, "Salt should handle characters not in the keyMap")
|
||||
|
||||
// Test with only characters in the keyMap
|
||||
input = "oleta"
|
||||
// "oleta" reversed: "atelo" -> "47310"
|
||||
expected = "47310"
|
||||
actual = createSalt(input)
|
||||
assert.Equal(t, expected, actual, "Salt should correctly handle strings with only keyMap characters")
|
||||
}
|
||||
|
||||
var testKeyMapMu sync.Mutex
|
||||
|
||||
func TestSetKeyMap(t *testing.T) {
|
||||
testKeyMapMu.Lock()
|
||||
originalKeyMap := GetKeyMap()
|
||||
t.Cleanup(func() {
|
||||
SetKeyMap(originalKeyMap)
|
||||
testKeyMapMu.Unlock()
|
||||
})
|
||||
|
||||
newKeyMap := map[rune]rune{
|
||||
'a': 'b',
|
||||
}
|
||||
SetKeyMap(newKeyMap)
|
||||
assert.Equal(t, newKeyMap, GetKeyMap())
|
||||
}
|
||||
191
crypt/openpgp/service.go
Normal file
191
crypt/openpgp/service.go
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
package openpgp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
goio "io"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/go-crypto/openpgp"
|
||||
"github.com/ProtonMail/go-crypto/openpgp/armor"
|
||||
"github.com/ProtonMail/go-crypto/openpgp/packet"
|
||||
core "forge.lthn.ai/core/go/pkg/framework/core"
|
||||
)
|
||||
|
||||
// Service implements the core.Crypt interface using OpenPGP.
|
||||
type Service struct {
|
||||
core *core.Core
|
||||
}
|
||||
|
||||
// New creates a new OpenPGP service instance.
|
||||
func New(c *core.Core) (any, error) {
|
||||
return &Service{core: c}, nil
|
||||
}
|
||||
|
||||
// CreateKeyPair generates a new RSA-4096 PGP keypair.
|
||||
// Returns the armored private key string.
|
||||
func (s *Service) CreateKeyPair(name, passphrase string) (string, error) {
|
||||
config := &packet.Config{
|
||||
Algorithm: packet.PubKeyAlgoRSA,
|
||||
RSABits: 4096,
|
||||
DefaultHash: crypto.SHA256,
|
||||
DefaultCipher: packet.CipherAES256,
|
||||
}
|
||||
|
||||
entity, err := openpgp.NewEntity(name, "Workspace Key", "", config)
|
||||
if err != nil {
|
||||
return "", core.E("openpgp.CreateKeyPair", "failed to create entity", err)
|
||||
}
|
||||
|
||||
// Encrypt private key if passphrase is provided
|
||||
if passphrase != "" {
|
||||
err = entity.PrivateKey.Encrypt([]byte(passphrase))
|
||||
if err != nil {
|
||||
return "", core.E("openpgp.CreateKeyPair", "failed to encrypt private key", err)
|
||||
}
|
||||
for _, subkey := range entity.Subkeys {
|
||||
err = subkey.PrivateKey.Encrypt([]byte(passphrase))
|
||||
if err != nil {
|
||||
return "", core.E("openpgp.CreateKeyPair", "failed to encrypt subkey", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
w, err := armor.Encode(&buf, openpgp.PrivateKeyType, nil)
|
||||
if err != nil {
|
||||
return "", core.E("openpgp.CreateKeyPair", "failed to create armor encoder", err)
|
||||
}
|
||||
|
||||
// Manual serialization to avoid panic from re-signing encrypted keys
|
||||
err = s.serializeEntity(w, entity)
|
||||
if err != nil {
|
||||
w.Close()
|
||||
return "", core.E("openpgp.CreateKeyPair", "failed to serialize private key", err)
|
||||
}
|
||||
w.Close()
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// serializeEntity manually serializes an OpenPGP entity to avoid re-signing.
|
||||
func (s *Service) serializeEntity(w goio.Writer, e *openpgp.Entity) error {
|
||||
err := e.PrivateKey.Serialize(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, ident := range e.Identities {
|
||||
err = ident.UserId.Serialize(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = ident.SelfSignature.Serialize(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, subkey := range e.Subkeys {
|
||||
err = subkey.PrivateKey.Serialize(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = subkey.Sig.Serialize(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncryptPGP encrypts data for a recipient identified by their public key (armored string in recipientPath).
|
||||
// The encrypted data is written to the provided writer and also returned as an armored string.
|
||||
func (s *Service) EncryptPGP(writer goio.Writer, recipientPath, data string, opts ...any) (string, error) {
|
||||
entityList, err := openpgp.ReadArmoredKeyRing(strings.NewReader(recipientPath))
|
||||
if err != nil {
|
||||
return "", core.E("openpgp.EncryptPGP", "failed to read recipient key", err)
|
||||
}
|
||||
|
||||
var armoredBuf bytes.Buffer
|
||||
armoredWriter, err := armor.Encode(&armoredBuf, "PGP MESSAGE", nil)
|
||||
if err != nil {
|
||||
return "", core.E("openpgp.EncryptPGP", "failed to create armor encoder", err)
|
||||
}
|
||||
|
||||
// MultiWriter to write to both the provided writer and our armored buffer
|
||||
mw := goio.MultiWriter(writer, armoredWriter)
|
||||
|
||||
w, err := openpgp.Encrypt(mw, entityList, nil, nil, nil)
|
||||
if err != nil {
|
||||
armoredWriter.Close()
|
||||
return "", core.E("openpgp.EncryptPGP", "failed to start encryption", err)
|
||||
}
|
||||
|
||||
_, err = goio.WriteString(w, data)
|
||||
if err != nil {
|
||||
w.Close()
|
||||
armoredWriter.Close()
|
||||
return "", core.E("openpgp.EncryptPGP", "failed to write data", err)
|
||||
}
|
||||
|
||||
w.Close()
|
||||
armoredWriter.Close()
|
||||
|
||||
return armoredBuf.String(), nil
|
||||
}
|
||||
|
||||
// DecryptPGP decrypts a PGP message using the provided armored private key and passphrase.
|
||||
func (s *Service) DecryptPGP(privateKey, message, passphrase string, opts ...any) (string, error) {
|
||||
entityList, err := openpgp.ReadArmoredKeyRing(strings.NewReader(privateKey))
|
||||
if err != nil {
|
||||
return "", core.E("openpgp.DecryptPGP", "failed to read private key", err)
|
||||
}
|
||||
|
||||
entity := entityList[0]
|
||||
if entity.PrivateKey.Encrypted {
|
||||
err = entity.PrivateKey.Decrypt([]byte(passphrase))
|
||||
if err != nil {
|
||||
return "", core.E("openpgp.DecryptPGP", "failed to decrypt private key", err)
|
||||
}
|
||||
for _, subkey := range entity.Subkeys {
|
||||
_ = subkey.PrivateKey.Decrypt([]byte(passphrase))
|
||||
}
|
||||
}
|
||||
|
||||
// Decrypt armored message
|
||||
block, err := armor.Decode(strings.NewReader(message))
|
||||
if err != nil {
|
||||
return "", core.E("openpgp.DecryptPGP", "failed to decode armored message", err)
|
||||
}
|
||||
|
||||
md, err := openpgp.ReadMessage(block.Body, entityList, nil, nil)
|
||||
if err != nil {
|
||||
return "", core.E("openpgp.DecryptPGP", "failed to read message", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
_, err = goio.Copy(&buf, md.UnverifiedBody)
|
||||
if err != nil {
|
||||
return "", core.E("openpgp.DecryptPGP", "failed to read decrypted body", err)
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// HandleIPCEvents handles PGP-related IPC messages.
|
||||
func (s *Service) HandleIPCEvents(c *core.Core, msg core.Message) error {
|
||||
switch m := msg.(type) {
|
||||
case map[string]any:
|
||||
action, _ := m["action"].(string)
|
||||
switch action {
|
||||
case "openpgp.create_key_pair":
|
||||
name, _ := m["name"].(string)
|
||||
passphrase, _ := m["passphrase"].(string)
|
||||
_, err := s.CreateKeyPair(name, passphrase)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure Service implements core.Crypt.
|
||||
var _ core.Crypt = (*Service)(nil)
|
||||
43
crypt/openpgp/service_test.go
Normal file
43
crypt/openpgp/service_test.go
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
package openpgp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
core "forge.lthn.ai/core/go/pkg/framework/core"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCreateKeyPair(t *testing.T) {
|
||||
c, _ := core.New()
|
||||
s := &Service{core: c}
|
||||
|
||||
privKey, err := s.CreateKeyPair("test user", "password123")
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, privKey)
|
||||
assert.Contains(t, privKey, "-----BEGIN PGP PRIVATE KEY BLOCK-----")
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
c, _ := core.New()
|
||||
s := &Service{core: c}
|
||||
|
||||
passphrase := "secret"
|
||||
privKey, err := s.CreateKeyPair("test user", passphrase)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// In this simple test, the public key is also in the armored private key string
|
||||
// (openpgp.ReadArmoredKeyRing reads both)
|
||||
publicKey := privKey
|
||||
|
||||
data := "hello openpgp"
|
||||
var buf bytes.Buffer
|
||||
armored, err := s.EncryptPGP(&buf, publicKey, data)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, armored)
|
||||
assert.NotEmpty(t, buf.String())
|
||||
|
||||
decrypted, err := s.DecryptPGP(privKey, armored, passphrase)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, data, decrypted)
|
||||
}
|
||||
230
crypt/pgp/pgp.go
Normal file
230
crypt/pgp/pgp.go
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
// Package pgp provides OpenPGP key generation, encryption, decryption,
|
||||
// signing, and verification using the ProtonMail go-crypto library.
|
||||
//
|
||||
// Ported from Enchantrix (github.com/Snider/Enchantrix/pkg/crypt/std/pgp).
|
||||
package pgp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ProtonMail/go-crypto/openpgp"
|
||||
"github.com/ProtonMail/go-crypto/openpgp/armor"
|
||||
"github.com/ProtonMail/go-crypto/openpgp/packet"
|
||||
)
|
||||
|
||||
// KeyPair holds armored PGP public and private keys.
|
||||
type KeyPair struct {
|
||||
PublicKey string
|
||||
PrivateKey string
|
||||
}
|
||||
|
||||
// CreateKeyPair generates a new PGP key pair for the given identity.
|
||||
// If password is non-empty, the private key is encrypted with it.
|
||||
// Returns a KeyPair with armored public and private keys.
|
||||
func CreateKeyPair(name, email, password string) (*KeyPair, error) {
|
||||
entity, err := openpgp.NewEntity(name, "", email, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to create entity: %w", err)
|
||||
}
|
||||
|
||||
// Sign all the identities
|
||||
for _, id := range entity.Identities {
|
||||
_ = id.SelfSignature.SignUserId(id.UserId.Id, entity.PrimaryKey, entity.PrivateKey, nil)
|
||||
}
|
||||
|
||||
// Encrypt private key with password if provided
|
||||
if password != "" {
|
||||
err = entity.PrivateKey.Encrypt([]byte(password))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to encrypt private key: %w", err)
|
||||
}
|
||||
for _, subkey := range entity.Subkeys {
|
||||
err = subkey.PrivateKey.Encrypt([]byte(password))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to encrypt subkey: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize public key
|
||||
pubKeyBuf := new(bytes.Buffer)
|
||||
pubKeyWriter, err := armor.Encode(pubKeyBuf, openpgp.PublicKeyType, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to create armored public key writer: %w", err)
|
||||
}
|
||||
if err := entity.Serialize(pubKeyWriter); err != nil {
|
||||
pubKeyWriter.Close()
|
||||
return nil, fmt.Errorf("pgp: failed to serialize public key: %w", err)
|
||||
}
|
||||
pubKeyWriter.Close()
|
||||
|
||||
// Serialize private key
|
||||
privKeyBuf := new(bytes.Buffer)
|
||||
privKeyWriter, err := armor.Encode(privKeyBuf, openpgp.PrivateKeyType, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to create armored private key writer: %w", err)
|
||||
}
|
||||
if password != "" {
|
||||
// Manual serialization to avoid re-signing encrypted keys
|
||||
if err := serializeEncryptedEntity(privKeyWriter, entity); err != nil {
|
||||
privKeyWriter.Close()
|
||||
return nil, fmt.Errorf("pgp: failed to serialize private key: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := entity.SerializePrivate(privKeyWriter, nil); err != nil {
|
||||
privKeyWriter.Close()
|
||||
return nil, fmt.Errorf("pgp: failed to serialize private key: %w", err)
|
||||
}
|
||||
}
|
||||
privKeyWriter.Close()
|
||||
|
||||
return &KeyPair{
|
||||
PublicKey: pubKeyBuf.String(),
|
||||
PrivateKey: privKeyBuf.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// serializeEncryptedEntity manually serializes an entity with encrypted private keys
|
||||
// to avoid the panic from re-signing encrypted keys.
|
||||
func serializeEncryptedEntity(w io.Writer, e *openpgp.Entity) error {
|
||||
if err := e.PrivateKey.Serialize(w); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, ident := range e.Identities {
|
||||
if err := ident.UserId.Serialize(w); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ident.SelfSignature.Serialize(w); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, subkey := range e.Subkeys {
|
||||
if err := subkey.PrivateKey.Serialize(w); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := subkey.Sig.Serialize(w); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts data for the recipient identified by their armored public key.
|
||||
// Returns the encrypted data as armored PGP output.
|
||||
func Encrypt(data []byte, publicKeyArmor string) ([]byte, error) {
|
||||
keyring, err := openpgp.ReadArmoredKeyRing(bytes.NewReader([]byte(publicKeyArmor)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to read public key ring: %w", err)
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
armoredWriter, err := armor.Encode(buf, "PGP MESSAGE", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to create armor encoder: %w", err)
|
||||
}
|
||||
|
||||
w, err := openpgp.Encrypt(armoredWriter, keyring, nil, nil, nil)
|
||||
if err != nil {
|
||||
armoredWriter.Close()
|
||||
return nil, fmt.Errorf("pgp: failed to create encryption writer: %w", err)
|
||||
}
|
||||
|
||||
if _, err := w.Write(data); err != nil {
|
||||
w.Close()
|
||||
armoredWriter.Close()
|
||||
return nil, fmt.Errorf("pgp: failed to write data: %w", err)
|
||||
}
|
||||
w.Close()
|
||||
armoredWriter.Close()
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts armored PGP data using the given armored private key.
|
||||
// If the private key is encrypted, the password is used to decrypt it first.
|
||||
func Decrypt(data []byte, privateKeyArmor, password string) ([]byte, error) {
|
||||
keyring, err := openpgp.ReadArmoredKeyRing(bytes.NewReader([]byte(privateKeyArmor)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to read private key ring: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt the private key if it is encrypted
|
||||
for _, entity := range keyring {
|
||||
if entity.PrivateKey != nil && entity.PrivateKey.Encrypted {
|
||||
if err := entity.PrivateKey.Decrypt([]byte(password)); err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to decrypt private key: %w", err)
|
||||
}
|
||||
}
|
||||
for _, subkey := range entity.Subkeys {
|
||||
if subkey.PrivateKey != nil && subkey.PrivateKey.Encrypted {
|
||||
_ = subkey.PrivateKey.Decrypt([]byte(password))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decode armored message
|
||||
block, err := armor.Decode(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to decode armored message: %w", err)
|
||||
}
|
||||
|
||||
md, err := openpgp.ReadMessage(block.Body, keyring, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to read message: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := io.ReadAll(md.UnverifiedBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to read plaintext: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// Sign creates an armored detached signature for the given data using
|
||||
// the armored private key. If the key is encrypted, the password is used
|
||||
// to decrypt it first.
|
||||
func Sign(data []byte, privateKeyArmor, password string) ([]byte, error) {
|
||||
keyring, err := openpgp.ReadArmoredKeyRing(bytes.NewReader([]byte(privateKeyArmor)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to read private key ring: %w", err)
|
||||
}
|
||||
|
||||
signer := keyring[0]
|
||||
if signer.PrivateKey == nil {
|
||||
return nil, fmt.Errorf("pgp: private key not found in keyring")
|
||||
}
|
||||
|
||||
if signer.PrivateKey.Encrypted {
|
||||
if err := signer.PrivateKey.Decrypt([]byte(password)); err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to decrypt private key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
config := &packet.Config{}
|
||||
err = openpgp.ArmoredDetachSign(buf, signer, bytes.NewReader(data), config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgp: failed to sign message: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// Verify verifies an armored detached signature against the given data
|
||||
// and armored public key. Returns nil if the signature is valid.
|
||||
func Verify(data, signature []byte, publicKeyArmor string) error {
|
||||
keyring, err := openpgp.ReadArmoredKeyRing(bytes.NewReader([]byte(publicKeyArmor)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("pgp: failed to read public key ring: %w", err)
|
||||
}
|
||||
|
||||
_, err = openpgp.CheckArmoredDetachedSignature(keyring, bytes.NewReader(data), bytes.NewReader(signature), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pgp: signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
164
crypt/pgp/pgp_test.go
Normal file
164
crypt/pgp/pgp_test.go
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
package pgp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateKeyPair_Good(t *testing.T) {
|
||||
kp, err := CreateKeyPair("Test User", "test@example.com", "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, kp)
|
||||
assert.Contains(t, kp.PublicKey, "-----BEGIN PGP PUBLIC KEY BLOCK-----")
|
||||
assert.Contains(t, kp.PrivateKey, "-----BEGIN PGP PRIVATE KEY BLOCK-----")
|
||||
}
|
||||
|
||||
func TestCreateKeyPair_Bad(t *testing.T) {
|
||||
// Empty name still works (openpgp allows it), but test with password
|
||||
kp, err := CreateKeyPair("Secure User", "secure@example.com", "strong-password")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, kp)
|
||||
assert.Contains(t, kp.PublicKey, "-----BEGIN PGP PUBLIC KEY BLOCK-----")
|
||||
assert.Contains(t, kp.PrivateKey, "-----BEGIN PGP PRIVATE KEY BLOCK-----")
|
||||
}
|
||||
|
||||
func TestCreateKeyPair_Ugly(t *testing.T) {
|
||||
// Minimal identity
|
||||
kp, err := CreateKeyPair("", "", "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, kp)
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt_Good(t *testing.T) {
|
||||
kp, err := CreateKeyPair("Test User", "test@example.com", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
plaintext := []byte("hello, OpenPGP!")
|
||||
ciphertext, err := Encrypt(plaintext, kp.PublicKey)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, ciphertext)
|
||||
assert.Contains(t, string(ciphertext), "-----BEGIN PGP MESSAGE-----")
|
||||
|
||||
decrypted, err := Decrypt(ciphertext, kp.PrivateKey, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt_Bad(t *testing.T) {
|
||||
kp1, err := CreateKeyPair("User One", "one@example.com", "")
|
||||
require.NoError(t, err)
|
||||
kp2, err := CreateKeyPair("User Two", "two@example.com", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
plaintext := []byte("secret data")
|
||||
ciphertext, err := Encrypt(plaintext, kp1.PublicKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decrypting with wrong key should fail
|
||||
_, err = Decrypt(ciphertext, kp2.PrivateKey, "")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt_Ugly(t *testing.T) {
|
||||
// Invalid public key for encryption
|
||||
_, err := Encrypt([]byte("data"), "not-a-pgp-key")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Invalid private key for decryption
|
||||
_, err = Decrypt([]byte("data"), "not-a-pgp-key", "")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestEncryptDecryptWithPassword_Good(t *testing.T) {
|
||||
password := "my-secret-passphrase"
|
||||
kp, err := CreateKeyPair("Secure User", "secure@example.com", password)
|
||||
require.NoError(t, err)
|
||||
|
||||
plaintext := []byte("encrypted with password-protected key")
|
||||
ciphertext, err := Encrypt(plaintext, kp.PublicKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted, err := Decrypt(ciphertext, kp.PrivateKey, password)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
}
|
||||
|
||||
func TestSignVerify_Good(t *testing.T) {
|
||||
kp, err := CreateKeyPair("Signer", "signer@example.com", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
data := []byte("message to sign")
|
||||
signature, err := Sign(data, kp.PrivateKey, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, signature)
|
||||
assert.Contains(t, string(signature), "-----BEGIN PGP SIGNATURE-----")
|
||||
|
||||
err = Verify(data, signature, kp.PublicKey)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSignVerify_Bad(t *testing.T) {
|
||||
kp, err := CreateKeyPair("Signer", "signer@example.com", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
data := []byte("original message")
|
||||
signature, err := Sign(data, kp.PrivateKey, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify with tampered data should fail
|
||||
err = Verify([]byte("tampered message"), signature, kp.PublicKey)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSignVerify_Ugly(t *testing.T) {
|
||||
// Invalid key for signing
|
||||
_, err := Sign([]byte("data"), "not-a-key", "")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Invalid key for verification
|
||||
kp, err := CreateKeyPair("Signer", "signer@example.com", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
data := []byte("message")
|
||||
sig, err := Sign(data, kp.PrivateKey, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = Verify(data, sig, "not-a-key")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSignVerifyWithPassword_Good(t *testing.T) {
|
||||
password := "signing-password"
|
||||
kp, err := CreateKeyPair("Signer", "signer@example.com", password)
|
||||
require.NoError(t, err)
|
||||
|
||||
data := []byte("signed with password-protected key")
|
||||
signature, err := Sign(data, kp.PrivateKey, password)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = Verify(data, signature, kp.PublicKey)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestFullRoundTrip_Good(t *testing.T) {
|
||||
// Generate keys, encrypt, decrypt, sign, and verify - full round trip
|
||||
kp, err := CreateKeyPair("Full Test", "full@example.com", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
original := []byte("full round-trip test data")
|
||||
|
||||
// Encrypt then decrypt
|
||||
ciphertext, err := Encrypt(original, kp.PublicKey)
|
||||
require.NoError(t, err)
|
||||
decrypted, err := Decrypt(ciphertext, kp.PrivateKey, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, original, decrypted)
|
||||
|
||||
// Sign then verify
|
||||
signature, err := Sign(original, kp.PrivateKey, "")
|
||||
require.NoError(t, err)
|
||||
err = Verify(original, signature, kp.PublicKey)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
91
crypt/rsa/rsa.go
Normal file
91
crypt/rsa/rsa.go
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
package rsa
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Service provides RSA functionality.
|
||||
type Service struct{}
|
||||
|
||||
// NewService creates and returns a new Service instance for performing RSA-related operations.
|
||||
func NewService() *Service {
|
||||
return &Service{}
|
||||
}
|
||||
|
||||
// GenerateKeyPair creates a new RSA key pair.
|
||||
func (s *Service) GenerateKeyPair(bits int) (publicKey, privateKey []byte, err error) {
|
||||
if bits < 2048 {
|
||||
return nil, nil, fmt.Errorf("rsa: key size too small: %d (minimum 2048)", bits)
|
||||
}
|
||||
privKey, err := rsa.GenerateKey(rand.Reader, bits)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
|
||||
privKeyBytes := x509.MarshalPKCS1PrivateKey(privKey)
|
||||
privKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: privKeyBytes,
|
||||
})
|
||||
|
||||
pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||
}
|
||||
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: pubKeyBytes,
|
||||
})
|
||||
|
||||
return pubKeyPEM, privKeyPEM, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts data with a public key.
|
||||
func (s *Service) Encrypt(publicKey, data, label []byte) ([]byte, error) {
|
||||
block, _ := pem.Decode(publicKey)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("failed to decode public key")
|
||||
}
|
||||
|
||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
|
||||
rsaPub, ok := pub.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("not an RSA public key")
|
||||
}
|
||||
|
||||
ciphertext, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, rsaPub, data, label)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt data: %w", err)
|
||||
}
|
||||
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts data with a private key.
|
||||
func (s *Service) Decrypt(privateKey, ciphertext, label []byte) ([]byte, error) {
|
||||
block, _ := pem.Decode(privateKey)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("failed to decode private key")
|
||||
}
|
||||
|
||||
priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, priv, ciphertext, label)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt data: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
101
crypt/rsa/rsa_test.go
Normal file
101
crypt/rsa/rsa_test.go
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
package rsa
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// mockReader is a reader that returns an error.
|
||||
type mockReader struct{}
|
||||
|
||||
func (r *mockReader) Read(p []byte) (n int, err error) {
|
||||
return 0, errors.New("read error")
|
||||
}
|
||||
|
||||
func TestRSA_Good(t *testing.T) {
|
||||
s := NewService()
|
||||
|
||||
// Generate a new key pair
|
||||
pubKey, privKey, err := s.GenerateKeyPair(2048)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, pubKey)
|
||||
assert.NotEmpty(t, privKey)
|
||||
|
||||
// Encrypt and decrypt a message
|
||||
message := []byte("Hello, World!")
|
||||
ciphertext, err := s.Encrypt(pubKey, message, nil)
|
||||
assert.NoError(t, err)
|
||||
plaintext, err := s.Decrypt(privKey, ciphertext, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, message, plaintext)
|
||||
}
|
||||
|
||||
func TestRSA_Bad(t *testing.T) {
|
||||
s := NewService()
|
||||
|
||||
// Decrypt with wrong key
|
||||
pubKey, _, err := s.GenerateKeyPair(2048)
|
||||
assert.NoError(t, err)
|
||||
_, otherPrivKey, err := s.GenerateKeyPair(2048)
|
||||
assert.NoError(t, err)
|
||||
message := []byte("Hello, World!")
|
||||
ciphertext, err := s.Encrypt(pubKey, message, nil)
|
||||
assert.NoError(t, err)
|
||||
_, err = s.Decrypt(otherPrivKey, ciphertext, nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Key size too small
|
||||
_, _, err = s.GenerateKeyPair(512)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRSA_Ugly(t *testing.T) {
|
||||
s := NewService()
|
||||
|
||||
// Malformed keys and messages
|
||||
_, err := s.Encrypt([]byte("not-a-key"), []byte("message"), nil)
|
||||
assert.Error(t, err)
|
||||
_, err = s.Decrypt([]byte("not-a-key"), []byte("message"), nil)
|
||||
assert.Error(t, err)
|
||||
_, err = s.Encrypt([]byte("-----BEGIN PUBLIC KEY-----\nMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAJ/6j/y7/r/9/z/8/f/+/v7+/v7+/v7+\nv/7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4=\n-----END PUBLIC KEY-----"), []byte("message"), nil)
|
||||
assert.Error(t, err)
|
||||
_, err = s.Decrypt([]byte("-----BEGIN RSA PRIVATE KEY-----\nMIIBOQIBAAJBAL/6j/y7/r/9/z/8/f/+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CAwEAAQJB\nAL/6j/y7/r/9/z/8/f/+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CgYEA/f8/vLv+v/3/P/z9//7+/v7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4C\ngYEA/f8/vLv+v/3/P/z9//7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CgYEA/f8/vLv+v/3/P/z9//7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/4CgYEA/f8/vLv+v/3/P/z9//7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CgYEA/f8/vLv+v/3/P/z9//7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/4=\n-----END RSA PRIVATE KEY-----"), []byte("message"), nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Key generation failure
|
||||
oldReader := rand.Reader
|
||||
rand.Reader = &mockReader{}
|
||||
t.Cleanup(func() { rand.Reader = oldReader })
|
||||
_, _, err = s.GenerateKeyPair(2048)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Encrypt with non-RSA key
|
||||
rand.Reader = oldReader // Restore reader for this test
|
||||
ecdsaPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
assert.NoError(t, err)
|
||||
ecdsaPubKeyBytes, err := x509.MarshalPKIXPublicKey(&ecdsaPrivKey.PublicKey)
|
||||
assert.NoError(t, err)
|
||||
ecdsaPubKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: ecdsaPubKeyBytes,
|
||||
})
|
||||
_, err = s.Encrypt(ecdsaPubKeyPEM, []byte("message"), nil)
|
||||
assert.Error(t, err)
|
||||
rand.Reader = &mockReader{} // Set it back for the next test
|
||||
|
||||
// Encrypt message too long
|
||||
rand.Reader = oldReader // Restore reader for this test
|
||||
pubKey, _, err := s.GenerateKeyPair(2048)
|
||||
assert.NoError(t, err)
|
||||
message := make([]byte, 2048)
|
||||
_, err = s.Encrypt(pubKey, message, nil)
|
||||
assert.Error(t, err)
|
||||
rand.Reader = &mockReader{} // Set it back
|
||||
}
|
||||
100
crypt/symmetric.go
Normal file
100
crypt/symmetric.go
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
package crypt
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
|
||||
core "forge.lthn.ai/core/go/pkg/framework/core"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// ChaCha20Encrypt encrypts plaintext using ChaCha20-Poly1305.
|
||||
// The key must be 32 bytes. The nonce is randomly generated and prepended
|
||||
// to the ciphertext.
|
||||
func ChaCha20Encrypt(plaintext, key []byte) ([]byte, error) {
|
||||
aead, err := chacha20poly1305.NewX(key)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.ChaCha20Encrypt", "failed to create cipher", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, core.E("crypt.ChaCha20Encrypt", "failed to generate nonce", err)
|
||||
}
|
||||
|
||||
ciphertext := aead.Seal(nonce, nonce, plaintext, nil)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// ChaCha20Decrypt decrypts ciphertext encrypted with ChaCha20Encrypt.
|
||||
// The key must be 32 bytes. Expects the nonce prepended to the ciphertext.
|
||||
func ChaCha20Decrypt(ciphertext, key []byte) ([]byte, error) {
|
||||
aead, err := chacha20poly1305.NewX(key)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.ChaCha20Decrypt", "failed to create cipher", err)
|
||||
}
|
||||
|
||||
nonceSize := aead.NonceSize()
|
||||
if len(ciphertext) < nonceSize {
|
||||
return nil, core.E("crypt.ChaCha20Decrypt", "ciphertext too short", nil)
|
||||
}
|
||||
|
||||
nonce, encrypted := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||
plaintext, err := aead.Open(nil, nonce, encrypted, nil)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.ChaCha20Decrypt", "failed to decrypt", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// AESGCMEncrypt encrypts plaintext using AES-256-GCM.
|
||||
// The key must be 32 bytes. The nonce is randomly generated and prepended
|
||||
// to the ciphertext.
|
||||
func AESGCMEncrypt(plaintext, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.AESGCMEncrypt", "failed to create cipher", err)
|
||||
}
|
||||
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.AESGCMEncrypt", "failed to create GCM", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, core.E("crypt.AESGCMEncrypt", "failed to generate nonce", err)
|
||||
}
|
||||
|
||||
ciphertext := aead.Seal(nonce, nonce, plaintext, nil)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// AESGCMDecrypt decrypts ciphertext encrypted with AESGCMEncrypt.
|
||||
// The key must be 32 bytes. Expects the nonce prepended to the ciphertext.
|
||||
func AESGCMDecrypt(ciphertext, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.AESGCMDecrypt", "failed to create cipher", err)
|
||||
}
|
||||
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.AESGCMDecrypt", "failed to create GCM", err)
|
||||
}
|
||||
|
||||
nonceSize := aead.NonceSize()
|
||||
if len(ciphertext) < nonceSize {
|
||||
return nil, core.E("crypt.AESGCMDecrypt", "ciphertext too short", nil)
|
||||
}
|
||||
|
||||
nonce, encrypted := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||
plaintext, err := aead.Open(nil, nonce, encrypted, nil)
|
||||
if err != nil {
|
||||
return nil, core.E("crypt.AESGCMDecrypt", "failed to decrypt", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
55
crypt/symmetric_test.go
Normal file
55
crypt/symmetric_test.go
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
package crypt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestChaCha20_Good(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
assert.NoError(t, err)
|
||||
|
||||
plaintext := []byte("ChaCha20-Poly1305 test data")
|
||||
|
||||
encrypted, err := ChaCha20Encrypt(plaintext, key)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, plaintext, encrypted)
|
||||
|
||||
decrypted, err := ChaCha20Decrypt(encrypted, key)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
}
|
||||
|
||||
func TestChaCha20_Bad(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
wrongKey := make([]byte, 32)
|
||||
_, _ = rand.Read(key)
|
||||
_, _ = rand.Read(wrongKey)
|
||||
|
||||
plaintext := []byte("secret message")
|
||||
|
||||
encrypted, err := ChaCha20Encrypt(plaintext, key)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = ChaCha20Decrypt(encrypted, wrongKey)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAESGCM_Good(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
assert.NoError(t, err)
|
||||
|
||||
plaintext := []byte("AES-256-GCM test data")
|
||||
|
||||
encrypted, err := AESGCMEncrypt(plaintext, key)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, plaintext, encrypted)
|
||||
|
||||
decrypted, err := AESGCMDecrypt(encrypted, key)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
}
|
||||
20
go.mod
Normal file
20
go.mod
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
module forge.lthn.ai/core/go-crypt
|
||||
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
forge.lthn.ai/core/go v0.0.0
|
||||
github.com/ProtonMail/go-crypto v1.3.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
golang.org/x/crypto v0.48.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/cloudflare/circl v1.6.3 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
replace forge.lthn.ai/core/go => ../go
|
||||
18
go.sum
Normal file
18
go.sum
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw=
|
||||
github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
|
||||
github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8=
|
||||
github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
238
trust/policy.go
Normal file
238
trust/policy.go
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
package trust
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Policy defines the access rules for a given trust tier.
|
||||
type Policy struct {
|
||||
// Tier is the trust level this policy applies to.
|
||||
Tier Tier
|
||||
// Allowed lists the capabilities granted at this tier.
|
||||
Allowed []Capability
|
||||
// RequiresApproval lists capabilities that need human/higher-tier approval.
|
||||
RequiresApproval []Capability
|
||||
// Denied lists explicitly denied capabilities.
|
||||
Denied []Capability
|
||||
}
|
||||
|
||||
// PolicyEngine evaluates capability requests against registered policies.
|
||||
type PolicyEngine struct {
|
||||
registry *Registry
|
||||
policies map[Tier]*Policy
|
||||
}
|
||||
|
||||
// Decision is the result of a policy evaluation.
|
||||
type Decision int
|
||||
|
||||
const (
|
||||
// Deny means the action is not permitted.
|
||||
Deny Decision = iota
|
||||
// Allow means the action is permitted.
|
||||
Allow
|
||||
// NeedsApproval means the action requires human or higher-tier approval.
|
||||
NeedsApproval
|
||||
)
|
||||
|
||||
// String returns the human-readable name of the decision.
|
||||
func (d Decision) String() string {
|
||||
switch d {
|
||||
case Deny:
|
||||
return "deny"
|
||||
case Allow:
|
||||
return "allow"
|
||||
case NeedsApproval:
|
||||
return "needs_approval"
|
||||
default:
|
||||
return fmt.Sprintf("unknown(%d)", int(d))
|
||||
}
|
||||
}
|
||||
|
||||
// EvalResult contains the outcome of a capability evaluation.
|
||||
type EvalResult struct {
|
||||
Decision Decision
|
||||
Agent string
|
||||
Cap Capability
|
||||
Reason string
|
||||
}
|
||||
|
||||
// NewPolicyEngine creates a policy engine with the given registry and default policies.
|
||||
func NewPolicyEngine(registry *Registry) *PolicyEngine {
|
||||
pe := &PolicyEngine{
|
||||
registry: registry,
|
||||
policies: make(map[Tier]*Policy),
|
||||
}
|
||||
pe.loadDefaults()
|
||||
return pe
|
||||
}
|
||||
|
||||
// Evaluate checks whether the named agent can perform the given capability.
|
||||
// If the agent has scoped repos and the capability is repo-scoped, the repo
|
||||
// parameter is checked against the agent's allowed repos.
|
||||
func (pe *PolicyEngine) Evaluate(agentName string, cap Capability, repo string) EvalResult {
|
||||
agent := pe.registry.Get(agentName)
|
||||
if agent == nil {
|
||||
return EvalResult{
|
||||
Decision: Deny,
|
||||
Agent: agentName,
|
||||
Cap: cap,
|
||||
Reason: "agent not registered",
|
||||
}
|
||||
}
|
||||
|
||||
policy, ok := pe.policies[agent.Tier]
|
||||
if !ok {
|
||||
return EvalResult{
|
||||
Decision: Deny,
|
||||
Agent: agentName,
|
||||
Cap: cap,
|
||||
Reason: fmt.Sprintf("no policy for tier %s", agent.Tier),
|
||||
}
|
||||
}
|
||||
|
||||
// Check explicit denials first.
|
||||
for _, denied := range policy.Denied {
|
||||
if denied == cap {
|
||||
return EvalResult{
|
||||
Decision: Deny,
|
||||
Agent: agentName,
|
||||
Cap: cap,
|
||||
Reason: fmt.Sprintf("capability %s is denied for tier %s", cap, agent.Tier),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if capability requires approval.
|
||||
for _, approval := range policy.RequiresApproval {
|
||||
if approval == cap {
|
||||
return EvalResult{
|
||||
Decision: NeedsApproval,
|
||||
Agent: agentName,
|
||||
Cap: cap,
|
||||
Reason: fmt.Sprintf("capability %s requires approval for tier %s", cap, agent.Tier),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if capability is allowed.
|
||||
for _, allowed := range policy.Allowed {
|
||||
if allowed == cap {
|
||||
// For repo-scoped capabilities, verify repo access.
|
||||
if isRepoScoped(cap) && len(agent.ScopedRepos) > 0 {
|
||||
if !repoAllowed(agent.ScopedRepos, repo) {
|
||||
return EvalResult{
|
||||
Decision: Deny,
|
||||
Agent: agentName,
|
||||
Cap: cap,
|
||||
Reason: fmt.Sprintf("agent %q does not have access to repo %q", agentName, repo),
|
||||
}
|
||||
}
|
||||
}
|
||||
return EvalResult{
|
||||
Decision: Allow,
|
||||
Agent: agentName,
|
||||
Cap: cap,
|
||||
Reason: fmt.Sprintf("capability %s allowed for tier %s", cap, agent.Tier),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return EvalResult{
|
||||
Decision: Deny,
|
||||
Agent: agentName,
|
||||
Cap: cap,
|
||||
Reason: fmt.Sprintf("capability %s not granted for tier %s", cap, agent.Tier),
|
||||
}
|
||||
}
|
||||
|
||||
// SetPolicy replaces the policy for a given tier.
|
||||
func (pe *PolicyEngine) SetPolicy(p Policy) error {
|
||||
if !p.Tier.Valid() {
|
||||
return fmt.Errorf("trust.SetPolicy: invalid tier %d", p.Tier)
|
||||
}
|
||||
pe.policies[p.Tier] = &p
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPolicy returns the policy for a tier, or nil if none is set.
|
||||
func (pe *PolicyEngine) GetPolicy(t Tier) *Policy {
|
||||
return pe.policies[t]
|
||||
}
|
||||
|
||||
// loadDefaults installs the default trust policies from the issue spec.
|
||||
func (pe *PolicyEngine) loadDefaults() {
|
||||
// Tier 3 — Full Trust
|
||||
pe.policies[TierFull] = &Policy{
|
||||
Tier: TierFull,
|
||||
Allowed: []Capability{
|
||||
CapPushRepo,
|
||||
CapMergePR,
|
||||
CapCreatePR,
|
||||
CapCreateIssue,
|
||||
CapCommentIssue,
|
||||
CapReadSecrets,
|
||||
CapRunPrivileged,
|
||||
CapAccessWorkspace,
|
||||
CapModifyFlows,
|
||||
},
|
||||
}
|
||||
|
||||
// Tier 2 — Verified
|
||||
pe.policies[TierVerified] = &Policy{
|
||||
Tier: TierVerified,
|
||||
Allowed: []Capability{
|
||||
CapPushRepo, // scoped to assigned repos
|
||||
CapCreatePR, // can create, not merge
|
||||
CapCreateIssue,
|
||||
CapCommentIssue,
|
||||
CapReadSecrets, // scoped to their repos
|
||||
},
|
||||
RequiresApproval: []Capability{
|
||||
CapMergePR,
|
||||
},
|
||||
Denied: []Capability{
|
||||
CapAccessWorkspace, // cannot access other agents' workspaces
|
||||
CapModifyFlows,
|
||||
CapRunPrivileged,
|
||||
},
|
||||
}
|
||||
|
||||
// Tier 1 — Untrusted
|
||||
pe.policies[TierUntrusted] = &Policy{
|
||||
Tier: TierUntrusted,
|
||||
Allowed: []Capability{
|
||||
CapCreatePR, // fork only, checked at enforcement layer
|
||||
CapCommentIssue,
|
||||
},
|
||||
Denied: []Capability{
|
||||
CapPushRepo,
|
||||
CapMergePR,
|
||||
CapCreateIssue,
|
||||
CapReadSecrets,
|
||||
CapRunPrivileged,
|
||||
CapAccessWorkspace,
|
||||
CapModifyFlows,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// isRepoScoped returns true if the capability is constrained by repo scope.
|
||||
func isRepoScoped(cap Capability) bool {
|
||||
return strings.HasPrefix(string(cap), "repo.") ||
|
||||
strings.HasPrefix(string(cap), "pr.") ||
|
||||
cap == CapReadSecrets
|
||||
}
|
||||
|
||||
// repoAllowed checks if repo is in the agent's scoped list.
|
||||
func repoAllowed(scoped []string, repo string) bool {
|
||||
if repo == "" {
|
||||
return false
|
||||
}
|
||||
for _, r := range scoped {
|
||||
if r == repo {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
268
trust/policy_test.go
Normal file
268
trust/policy_test.go
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
package trust
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestEngine(t *testing.T) *PolicyEngine {
|
||||
t.Helper()
|
||||
r := NewRegistry()
|
||||
require.NoError(t, r.Register(Agent{
|
||||
Name: "Athena",
|
||||
Tier: TierFull,
|
||||
}))
|
||||
require.NoError(t, r.Register(Agent{
|
||||
Name: "Clotho",
|
||||
Tier: TierVerified,
|
||||
ScopedRepos: []string{"host-uk/core", "host-uk/docs"},
|
||||
}))
|
||||
require.NoError(t, r.Register(Agent{
|
||||
Name: "BugSETI-001",
|
||||
Tier: TierUntrusted,
|
||||
}))
|
||||
return NewPolicyEngine(r)
|
||||
}
|
||||
|
||||
// --- Decision ---
|
||||
|
||||
func TestDecisionString_Good(t *testing.T) {
|
||||
assert.Equal(t, "deny", Deny.String())
|
||||
assert.Equal(t, "allow", Allow.String())
|
||||
assert.Equal(t, "needs_approval", NeedsApproval.String())
|
||||
}
|
||||
|
||||
func TestDecisionString_Bad_Unknown(t *testing.T) {
|
||||
assert.Contains(t, Decision(99).String(), "unknown")
|
||||
}
|
||||
|
||||
// --- Tier 3 (Full Trust) ---
|
||||
|
||||
func TestEvaluate_Good_Tier3CanDoAnything(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
|
||||
caps := []Capability{
|
||||
CapPushRepo, CapMergePR, CapCreatePR, CapCreateIssue,
|
||||
CapCommentIssue, CapReadSecrets, CapRunPrivileged,
|
||||
CapAccessWorkspace, CapModifyFlows,
|
||||
}
|
||||
for _, cap := range caps {
|
||||
result := pe.Evaluate("Athena", cap, "")
|
||||
assert.Equal(t, Allow, result.Decision, "Athena should be allowed %s", cap)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tier 2 (Verified) ---
|
||||
|
||||
func TestEvaluate_Good_Tier2CanCreatePR(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("Clotho", CapCreatePR, "host-uk/core")
|
||||
assert.Equal(t, Allow, result.Decision)
|
||||
}
|
||||
|
||||
func TestEvaluate_Good_Tier2CanPushToScopedRepo(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("Clotho", CapPushRepo, "host-uk/core")
|
||||
assert.Equal(t, Allow, result.Decision)
|
||||
}
|
||||
|
||||
func TestEvaluate_Good_Tier2NeedsApprovalToMerge(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("Clotho", CapMergePR, "host-uk/core")
|
||||
assert.Equal(t, NeedsApproval, result.Decision)
|
||||
}
|
||||
|
||||
func TestEvaluate_Good_Tier2CanCreateIssue(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("Clotho", CapCreateIssue, "")
|
||||
assert.Equal(t, Allow, result.Decision)
|
||||
}
|
||||
|
||||
func TestEvaluate_Bad_Tier2CannotAccessWorkspace(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("Clotho", CapAccessWorkspace, "")
|
||||
assert.Equal(t, Deny, result.Decision)
|
||||
}
|
||||
|
||||
func TestEvaluate_Bad_Tier2CannotModifyFlows(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("Clotho", CapModifyFlows, "")
|
||||
assert.Equal(t, Deny, result.Decision)
|
||||
}
|
||||
|
||||
func TestEvaluate_Bad_Tier2CannotRunPrivileged(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("Clotho", CapRunPrivileged, "")
|
||||
assert.Equal(t, Deny, result.Decision)
|
||||
}
|
||||
|
||||
func TestEvaluate_Bad_Tier2CannotPushToUnscopedRepo(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("Clotho", CapPushRepo, "host-uk/secret-repo")
|
||||
assert.Equal(t, Deny, result.Decision)
|
||||
assert.Contains(t, result.Reason, "does not have access")
|
||||
}
|
||||
|
||||
func TestEvaluate_Bad_Tier2RepoScopeEmptyRepo(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
// Push without specifying a repo should be denied for scoped agents.
|
||||
result := pe.Evaluate("Clotho", CapPushRepo, "")
|
||||
assert.Equal(t, Deny, result.Decision)
|
||||
}
|
||||
|
||||
// --- Tier 1 (Untrusted) ---
|
||||
|
||||
func TestEvaluate_Good_Tier1CanCreatePR(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("BugSETI-001", CapCreatePR, "")
|
||||
assert.Equal(t, Allow, result.Decision)
|
||||
}
|
||||
|
||||
func TestEvaluate_Good_Tier1CanCommentIssue(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("BugSETI-001", CapCommentIssue, "")
|
||||
assert.Equal(t, Allow, result.Decision)
|
||||
}
|
||||
|
||||
func TestEvaluate_Bad_Tier1CannotPush(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("BugSETI-001", CapPushRepo, "")
|
||||
assert.Equal(t, Deny, result.Decision)
|
||||
}
|
||||
|
||||
func TestEvaluate_Bad_Tier1CannotMerge(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("BugSETI-001", CapMergePR, "")
|
||||
assert.Equal(t, Deny, result.Decision)
|
||||
}
|
||||
|
||||
func TestEvaluate_Bad_Tier1CannotCreateIssue(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("BugSETI-001", CapCreateIssue, "")
|
||||
assert.Equal(t, Deny, result.Decision)
|
||||
}
|
||||
|
||||
func TestEvaluate_Bad_Tier1CannotReadSecrets(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("BugSETI-001", CapReadSecrets, "")
|
||||
assert.Equal(t, Deny, result.Decision)
|
||||
}
|
||||
|
||||
func TestEvaluate_Bad_Tier1CannotRunPrivileged(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("BugSETI-001", CapRunPrivileged, "")
|
||||
assert.Equal(t, Deny, result.Decision)
|
||||
}
|
||||
|
||||
// --- Edge cases ---
|
||||
|
||||
func TestEvaluate_Bad_UnknownAgent(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("Unknown", CapCreatePR, "")
|
||||
assert.Equal(t, Deny, result.Decision)
|
||||
assert.Contains(t, result.Reason, "not registered")
|
||||
}
|
||||
|
||||
func TestEvaluate_Good_EvalResultFields(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
result := pe.Evaluate("Athena", CapPushRepo, "")
|
||||
assert.Equal(t, "Athena", result.Agent)
|
||||
assert.Equal(t, CapPushRepo, result.Cap)
|
||||
assert.NotEmpty(t, result.Reason)
|
||||
}
|
||||
|
||||
// --- SetPolicy ---
|
||||
|
||||
func TestSetPolicy_Good(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
err := pe.SetPolicy(Policy{
|
||||
Tier: TierVerified,
|
||||
Allowed: []Capability{CapPushRepo, CapMergePR},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the new policy is in effect.
|
||||
result := pe.Evaluate("Clotho", CapMergePR, "host-uk/core")
|
||||
assert.Equal(t, Allow, result.Decision)
|
||||
}
|
||||
|
||||
func TestSetPolicy_Bad_InvalidTier(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
err := pe.SetPolicy(Policy{Tier: Tier(0)})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid tier")
|
||||
}
|
||||
|
||||
func TestGetPolicy_Good(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
p := pe.GetPolicy(TierFull)
|
||||
require.NotNil(t, p)
|
||||
assert.Equal(t, TierFull, p.Tier)
|
||||
}
|
||||
|
||||
func TestGetPolicy_Bad_NotFound(t *testing.T) {
|
||||
pe := newTestEngine(t)
|
||||
assert.Nil(t, pe.GetPolicy(Tier(99)))
|
||||
}
|
||||
|
||||
// --- isRepoScoped / repoAllowed helpers ---
|
||||
|
||||
func TestIsRepoScoped_Good(t *testing.T) {
|
||||
assert.True(t, isRepoScoped(CapPushRepo))
|
||||
assert.True(t, isRepoScoped(CapCreatePR))
|
||||
assert.True(t, isRepoScoped(CapMergePR))
|
||||
assert.True(t, isRepoScoped(CapReadSecrets))
|
||||
}
|
||||
|
||||
func TestIsRepoScoped_Bad_NotScoped(t *testing.T) {
|
||||
assert.False(t, isRepoScoped(CapRunPrivileged))
|
||||
assert.False(t, isRepoScoped(CapAccessWorkspace))
|
||||
assert.False(t, isRepoScoped(CapModifyFlows))
|
||||
}
|
||||
|
||||
func TestRepoAllowed_Good(t *testing.T) {
|
||||
scoped := []string{"host-uk/core", "host-uk/docs"}
|
||||
assert.True(t, repoAllowed(scoped, "host-uk/core"))
|
||||
assert.True(t, repoAllowed(scoped, "host-uk/docs"))
|
||||
}
|
||||
|
||||
func TestRepoAllowed_Bad_NotInScope(t *testing.T) {
|
||||
scoped := []string{"host-uk/core"}
|
||||
assert.False(t, repoAllowed(scoped, "host-uk/secret"))
|
||||
}
|
||||
|
||||
func TestRepoAllowed_Bad_EmptyRepo(t *testing.T) {
|
||||
scoped := []string{"host-uk/core"}
|
||||
assert.False(t, repoAllowed(scoped, ""))
|
||||
}
|
||||
|
||||
func TestRepoAllowed_Bad_EmptyScope(t *testing.T) {
|
||||
assert.False(t, repoAllowed(nil, "host-uk/core"))
|
||||
assert.False(t, repoAllowed([]string{}, "host-uk/core"))
|
||||
}
|
||||
|
||||
// --- Tier 3 ignores repo scoping ---
|
||||
|
||||
func TestEvaluate_Good_Tier3IgnoresRepoScope(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
require.NoError(t, r.Register(Agent{
|
||||
Name: "Virgil",
|
||||
Tier: TierFull,
|
||||
ScopedRepos: []string{}, // empty scope should not restrict Tier 3
|
||||
}))
|
||||
pe := NewPolicyEngine(r)
|
||||
|
||||
result := pe.Evaluate("Virgil", CapPushRepo, "any-repo")
|
||||
assert.Equal(t, Allow, result.Decision)
|
||||
}
|
||||
|
||||
// --- Default rate limits ---
|
||||
|
||||
func TestDefaultRateLimit(t *testing.T) {
|
||||
assert.Equal(t, 10, defaultRateLimit(TierUntrusted))
|
||||
assert.Equal(t, 60, defaultRateLimit(TierVerified))
|
||||
assert.Equal(t, 0, defaultRateLimit(TierFull))
|
||||
assert.Equal(t, 10, defaultRateLimit(Tier(99))) // unknown defaults to 10
|
||||
}
|
||||
165
trust/trust.go
Normal file
165
trust/trust.go
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
// Package trust implements an agent trust model with tiered access control.
|
||||
//
|
||||
// Agents are assigned trust tiers that determine their capabilities:
|
||||
//
|
||||
// - Tier 3 (Full Trust): Internal agents with full access (e.g., Athena, Virgil, Charon)
|
||||
// - Tier 2 (Verified): Partner agents with scoped access (e.g., Clotho, Hypnos)
|
||||
// - Tier 1 (Untrusted): External/community agents with minimal access
|
||||
//
|
||||
// The package provides a Registry for managing agent identities and a PolicyEngine
|
||||
// for evaluating capability requests against trust policies.
|
||||
package trust
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Tier represents an agent's trust level in the system.
|
||||
type Tier int
|
||||
|
||||
const (
|
||||
// TierUntrusted is for external/community agents with minimal access.
|
||||
TierUntrusted Tier = 1
|
||||
// TierVerified is for partner agents with scoped access.
|
||||
TierVerified Tier = 2
|
||||
// TierFull is for internal agents with full access.
|
||||
TierFull Tier = 3
|
||||
)
|
||||
|
||||
// String returns the human-readable name of the tier.
|
||||
func (t Tier) String() string {
|
||||
switch t {
|
||||
case TierUntrusted:
|
||||
return "untrusted"
|
||||
case TierVerified:
|
||||
return "verified"
|
||||
case TierFull:
|
||||
return "full"
|
||||
default:
|
||||
return fmt.Sprintf("unknown(%d)", int(t))
|
||||
}
|
||||
}
|
||||
|
||||
// Valid returns true if the tier is a recognised trust level.
|
||||
func (t Tier) Valid() bool {
|
||||
return t >= TierUntrusted && t <= TierFull
|
||||
}
|
||||
|
||||
// Capability represents a specific action an agent can perform.
|
||||
type Capability string
|
||||
|
||||
const (
|
||||
CapPushRepo Capability = "repo.push"
|
||||
CapMergePR Capability = "pr.merge"
|
||||
CapCreatePR Capability = "pr.create"
|
||||
CapCreateIssue Capability = "issue.create"
|
||||
CapCommentIssue Capability = "issue.comment"
|
||||
CapReadSecrets Capability = "secrets.read"
|
||||
CapRunPrivileged Capability = "cmd.privileged"
|
||||
CapAccessWorkspace Capability = "workspace.access"
|
||||
CapModifyFlows Capability = "flows.modify"
|
||||
)
|
||||
|
||||
// Agent represents an agent identity in the trust system.
|
||||
type Agent struct {
|
||||
// Name is the unique identifier for the agent (e.g., "Athena", "Clotho").
|
||||
Name string
|
||||
// Tier is the agent's trust level.
|
||||
Tier Tier
|
||||
// ScopedRepos limits repo access for Tier 2 agents. Empty means no repo access.
|
||||
// Tier 3 agents ignore this field (they have access to all repos).
|
||||
ScopedRepos []string
|
||||
// RateLimit is the maximum requests per minute. 0 means unlimited.
|
||||
RateLimit int
|
||||
// TokenExpiresAt is when the agent's token expires.
|
||||
TokenExpiresAt time.Time
|
||||
// CreatedAt is when the agent was registered.
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// Registry manages agent identities and their trust tiers.
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
agents map[string]*Agent
|
||||
}
|
||||
|
||||
// NewRegistry creates an empty agent registry.
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
agents: make(map[string]*Agent),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds or updates an agent in the registry.
|
||||
// Returns an error if the agent name is empty or the tier is invalid.
|
||||
func (r *Registry) Register(agent Agent) error {
|
||||
if agent.Name == "" {
|
||||
return fmt.Errorf("trust.Register: agent name is required")
|
||||
}
|
||||
if !agent.Tier.Valid() {
|
||||
return fmt.Errorf("trust.Register: invalid tier %d for agent %q", agent.Tier, agent.Name)
|
||||
}
|
||||
if agent.CreatedAt.IsZero() {
|
||||
agent.CreatedAt = time.Now()
|
||||
}
|
||||
if agent.RateLimit == 0 {
|
||||
agent.RateLimit = defaultRateLimit(agent.Tier)
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.agents[agent.Name] = &agent
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the agent with the given name, or nil if not found.
|
||||
func (r *Registry) Get(name string) *Agent {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.agents[name]
|
||||
}
|
||||
|
||||
// Remove deletes an agent from the registry.
|
||||
func (r *Registry) Remove(name string) bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if _, ok := r.agents[name]; !ok {
|
||||
return false
|
||||
}
|
||||
delete(r.agents, name)
|
||||
return true
|
||||
}
|
||||
|
||||
// List returns all registered agents. The returned slice is a snapshot.
|
||||
func (r *Registry) List() []Agent {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
out := make([]Agent, 0, len(r.agents))
|
||||
for _, a := range r.agents {
|
||||
out = append(out, *a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Len returns the number of registered agents.
|
||||
func (r *Registry) Len() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return len(r.agents)
|
||||
}
|
||||
|
||||
// defaultRateLimit returns the default rate limit for a given tier.
|
||||
func defaultRateLimit(t Tier) int {
|
||||
switch t {
|
||||
case TierUntrusted:
|
||||
return 10
|
||||
case TierVerified:
|
||||
return 60
|
||||
case TierFull:
|
||||
return 0 // unlimited
|
||||
default:
|
||||
return 10
|
||||
}
|
||||
}
|
||||
164
trust/trust_test.go
Normal file
164
trust/trust_test.go
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
package trust
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- Tier ---
|
||||
|
||||
func TestTierString_Good(t *testing.T) {
|
||||
assert.Equal(t, "untrusted", TierUntrusted.String())
|
||||
assert.Equal(t, "verified", TierVerified.String())
|
||||
assert.Equal(t, "full", TierFull.String())
|
||||
}
|
||||
|
||||
func TestTierString_Bad_Unknown(t *testing.T) {
|
||||
assert.Contains(t, Tier(99).String(), "unknown")
|
||||
}
|
||||
|
||||
func TestTierValid_Good(t *testing.T) {
|
||||
assert.True(t, TierUntrusted.Valid())
|
||||
assert.True(t, TierVerified.Valid())
|
||||
assert.True(t, TierFull.Valid())
|
||||
}
|
||||
|
||||
func TestTierValid_Bad(t *testing.T) {
|
||||
assert.False(t, Tier(0).Valid())
|
||||
assert.False(t, Tier(4).Valid())
|
||||
assert.False(t, Tier(-1).Valid())
|
||||
}
|
||||
|
||||
// --- Registry ---
|
||||
|
||||
func TestRegistryRegister_Good(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
err := r.Register(Agent{Name: "Athena", Tier: TierFull})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, r.Len())
|
||||
}
|
||||
|
||||
func TestRegistryRegister_Good_SetsDefaults(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
err := r.Register(Agent{Name: "Athena", Tier: TierFull})
|
||||
require.NoError(t, err)
|
||||
|
||||
a := r.Get("Athena")
|
||||
require.NotNil(t, a)
|
||||
assert.Equal(t, 0, a.RateLimit) // full trust = unlimited
|
||||
assert.False(t, a.CreatedAt.IsZero())
|
||||
}
|
||||
|
||||
func TestRegistryRegister_Good_TierDefaults(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
require.NoError(t, r.Register(Agent{Name: "A", Tier: TierUntrusted}))
|
||||
require.NoError(t, r.Register(Agent{Name: "B", Tier: TierVerified}))
|
||||
require.NoError(t, r.Register(Agent{Name: "C", Tier: TierFull}))
|
||||
|
||||
assert.Equal(t, 10, r.Get("A").RateLimit)
|
||||
assert.Equal(t, 60, r.Get("B").RateLimit)
|
||||
assert.Equal(t, 0, r.Get("C").RateLimit)
|
||||
}
|
||||
|
||||
func TestRegistryRegister_Good_PreservesExplicitRateLimit(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
err := r.Register(Agent{Name: "Custom", Tier: TierVerified, RateLimit: 30})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 30, r.Get("Custom").RateLimit)
|
||||
}
|
||||
|
||||
func TestRegistryRegister_Good_Update(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
require.NoError(t, r.Register(Agent{Name: "Athena", Tier: TierVerified}))
|
||||
require.NoError(t, r.Register(Agent{Name: "Athena", Tier: TierFull}))
|
||||
|
||||
assert.Equal(t, 1, r.Len())
|
||||
assert.Equal(t, TierFull, r.Get("Athena").Tier)
|
||||
}
|
||||
|
||||
func TestRegistryRegister_Bad_EmptyName(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
err := r.Register(Agent{Tier: TierFull})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "name is required")
|
||||
}
|
||||
|
||||
func TestRegistryRegister_Bad_InvalidTier(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
err := r.Register(Agent{Name: "Bad", Tier: Tier(0)})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid tier")
|
||||
}
|
||||
|
||||
func TestRegistryGet_Good(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
require.NoError(t, r.Register(Agent{Name: "Athena", Tier: TierFull}))
|
||||
a := r.Get("Athena")
|
||||
require.NotNil(t, a)
|
||||
assert.Equal(t, "Athena", a.Name)
|
||||
}
|
||||
|
||||
func TestRegistryGet_Bad_NotFound(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
assert.Nil(t, r.Get("nonexistent"))
|
||||
}
|
||||
|
||||
func TestRegistryRemove_Good(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
require.NoError(t, r.Register(Agent{Name: "Athena", Tier: TierFull}))
|
||||
assert.True(t, r.Remove("Athena"))
|
||||
assert.Equal(t, 0, r.Len())
|
||||
}
|
||||
|
||||
func TestRegistryRemove_Bad_NotFound(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
assert.False(t, r.Remove("nonexistent"))
|
||||
}
|
||||
|
||||
func TestRegistryList_Good(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
require.NoError(t, r.Register(Agent{Name: "Athena", Tier: TierFull}))
|
||||
require.NoError(t, r.Register(Agent{Name: "Clotho", Tier: TierVerified}))
|
||||
|
||||
agents := r.List()
|
||||
assert.Len(t, agents, 2)
|
||||
|
||||
names := make(map[string]bool)
|
||||
for _, a := range agents {
|
||||
names[a.Name] = true
|
||||
}
|
||||
assert.True(t, names["Athena"])
|
||||
assert.True(t, names["Clotho"])
|
||||
}
|
||||
|
||||
func TestRegistryList_Good_Empty(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
assert.Empty(t, r.List())
|
||||
}
|
||||
|
||||
func TestRegistryList_Good_Snapshot(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
require.NoError(t, r.Register(Agent{Name: "Athena", Tier: TierFull}))
|
||||
agents := r.List()
|
||||
|
||||
// Modifying the returned slice should not affect the registry.
|
||||
agents[0].Tier = TierUntrusted
|
||||
assert.Equal(t, TierFull, r.Get("Athena").Tier)
|
||||
}
|
||||
|
||||
// --- Agent ---
|
||||
|
||||
func TestAgentTokenExpiry(t *testing.T) {
|
||||
agent := Agent{
|
||||
Name: "Test",
|
||||
Tier: TierVerified,
|
||||
TokenExpiresAt: time.Now().Add(-1 * time.Hour),
|
||||
}
|
||||
assert.True(t, time.Now().After(agent.TokenExpiresAt))
|
||||
|
||||
agent.TokenExpiresAt = time.Now().Add(1 * time.Hour)
|
||||
assert.True(t, time.Now().Before(agent.TokenExpiresAt))
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue