From 8498ecf890d2264f3e1333fca2ca646c0710c1dc Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 16 Feb 2026 15:25:54 +0000 Subject: [PATCH] feat: extract crypto/security packages from core/go ChaCha20-Poly1305, AES-256-GCM, Argon2 key derivation, OpenPGP challenge-response auth, and trust tier policy engine. Co-Authored-By: Claude Opus 4.6 --- auth/auth.go | 455 ++++++++++++++++++++++ auth/auth_test.go | 581 ++++++++++++++++++++++++++++ crypt/chachapoly/chachapoly.go | 50 +++ crypt/chachapoly/chachapoly_test.go | 114 ++++++ crypt/checksum.go | 55 +++ crypt/checksum_test.go | 23 ++ crypt/crypt.go | 90 +++++ crypt/crypt_test.go | 45 +++ crypt/hash.go | 89 +++++ crypt/hash_test.go | 50 +++ crypt/hmac.go | 30 ++ crypt/hmac_test.go | 40 ++ crypt/kdf.go | 60 +++ crypt/kdf_test.go | 56 +++ crypt/lthn/lthn.go | 94 +++++ crypt/lthn/lthn_test.go | 66 ++++ crypt/openpgp/service.go | 191 +++++++++ crypt/openpgp/service_test.go | 43 ++ crypt/pgp/pgp.go | 230 +++++++++++ crypt/pgp/pgp_test.go | 164 ++++++++ crypt/rsa/rsa.go | 91 +++++ crypt/rsa/rsa_test.go | 101 +++++ crypt/symmetric.go | 100 +++++ crypt/symmetric_test.go | 55 +++ go.mod | 20 + go.sum | 18 + trust/policy.go | 238 ++++++++++++ trust/policy_test.go | 268 +++++++++++++ trust/trust.go | 165 ++++++++ trust/trust_test.go | 164 ++++++++ 30 files changed, 3746 insertions(+) create mode 100644 auth/auth.go create mode 100644 auth/auth_test.go create mode 100644 crypt/chachapoly/chachapoly.go create mode 100644 crypt/chachapoly/chachapoly_test.go create mode 100644 crypt/checksum.go create mode 100644 crypt/checksum_test.go create mode 100644 crypt/crypt.go create mode 100644 crypt/crypt_test.go create mode 100644 crypt/hash.go create mode 100644 crypt/hash_test.go create mode 100644 crypt/hmac.go create mode 100644 crypt/hmac_test.go create mode 100644 crypt/kdf.go create mode 100644 crypt/kdf_test.go create mode 100644 crypt/lthn/lthn.go create mode 100644 crypt/lthn/lthn_test.go create mode 100644 crypt/openpgp/service.go create mode 100644 crypt/openpgp/service_test.go create mode 100644 crypt/pgp/pgp.go create mode 100644 crypt/pgp/pgp_test.go create mode 100644 crypt/rsa/rsa.go create mode 100644 crypt/rsa/rsa_test.go create mode 100644 crypt/symmetric.go create mode 100644 crypt/symmetric_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 trust/policy.go create mode 100644 trust/policy_test.go create mode 100644 trust/trust.go create mode 100644 trust/trust_test.go diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 0000000..103ece3 --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,455 @@ +// Package auth implements OpenPGP challenge-response authentication with +// support for both online (HTTP) and air-gapped (file-based) transport. +// +// Ported from dAppServer's mod-auth/lethean.service.ts. +// +// Authentication Flow (Online): +// +// 1. Client sends public key to server +// 2. Server generates a random nonce, encrypts it with client's public key +// 3. Client decrypts the nonce and signs it with their private key +// 4. Server verifies the signature, creates a session token +// +// Authentication Flow (Air-Gapped / Courier): +// +// Same crypto but challenge/response are exchanged via files on a Medium. +// +// Storage Layout (via Medium): +// +// users/ +// {userID}.pub PGP public key (armored) +// {userID}.key PGP private key (armored, password-encrypted) +// {userID}.rev Revocation certificate (placeholder) +// {userID}.json User metadata (encrypted with user's public key) +// {userID}.lthn LTHN password hash +package auth + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "sync" + "time" + + coreerr "forge.lthn.ai/core/go/pkg/framework/core" + + "forge.lthn.ai/core/go-crypt/crypt/lthn" + "forge.lthn.ai/core/go-crypt/crypt/pgp" + "forge.lthn.ai/core/go/pkg/io" +) + +// Default durations for challenge and session lifetimes. +const ( + DefaultChallengeTTL = 5 * time.Minute + DefaultSessionTTL = 24 * time.Hour + nonceBytes = 32 +) + +// protectedUsers lists usernames that cannot be deleted. +// The "server" user holds the server keypair; deleting it would +// permanently destroy all joining data and require a full rebuild. +var protectedUsers = map[string]bool{ + "server": true, +} + +// User represents a registered user with PGP credentials. +type User struct { + PublicKey string `json:"public_key"` + KeyID string `json:"key_id"` + Fingerprint string `json:"fingerprint"` + PasswordHash string `json:"password_hash"` // LTHN hash + Created time.Time `json:"created"` + LastLogin time.Time `json:"last_login"` +} + +// Challenge is a PGP-encrypted nonce sent to a client during authentication. +type Challenge struct { + Nonce []byte `json:"nonce"` + Encrypted string `json:"encrypted"` // PGP-encrypted nonce (armored) + ExpiresAt time.Time `json:"expires_at"` +} + +// Session represents an authenticated session. +type Session struct { + Token string `json:"token"` + UserID string `json:"user_id"` + ExpiresAt time.Time `json:"expires_at"` +} + +// Option configures an Authenticator. +type Option func(*Authenticator) + +// WithChallengeTTL sets the lifetime of a challenge before it expires. +func WithChallengeTTL(d time.Duration) Option { + return func(a *Authenticator) { + a.challengeTTL = d + } +} + +// WithSessionTTL sets the lifetime of a session before it expires. +func WithSessionTTL(d time.Duration) Option { + return func(a *Authenticator) { + a.sessionTTL = d + } +} + +// Authenticator manages PGP-based challenge-response authentication. +// All user data and keys are persisted through an io.Medium, which may +// be backed by disk, memory (MockMedium), or any other storage backend. +type Authenticator struct { + medium io.Medium + sessions map[string]*Session + challenges map[string]*Challenge // userID -> pending challenge + mu sync.RWMutex + challengeTTL time.Duration + sessionTTL time.Duration +} + +// New creates an Authenticator that persists user data via the given Medium. +func New(m io.Medium, opts ...Option) *Authenticator { + a := &Authenticator{ + medium: m, + sessions: make(map[string]*Session), + challenges: make(map[string]*Challenge), + challengeTTL: DefaultChallengeTTL, + sessionTTL: DefaultSessionTTL, + } + for _, opt := range opts { + opt(a) + } + return a +} + +// userPath returns the storage path for a user artifact. +func userPath(userID, ext string) string { + return "users/" + userID + ext +} + +// Register creates a new user account. It hashes the username with LTHN to +// produce a userID, generates a PGP keypair (protected by the given password), +// and persists the public key, private key, revocation placeholder, password +// hash, and encrypted metadata via the Medium. +func (a *Authenticator) Register(username, password string) (*User, error) { + const op = "auth.Register" + + userID := lthn.Hash(username) + + // Check if user already exists + if a.medium.IsFile(userPath(userID, ".pub")) { + return nil, coreerr.E(op, "user already exists", nil) + } + + // Ensure users directory exists + if err := a.medium.EnsureDir("users"); err != nil { + return nil, coreerr.E(op, "failed to create users directory", err) + } + + // Generate PGP keypair + kp, err := pgp.CreateKeyPair(userID, userID+"@auth.local", password) + if err != nil { + return nil, coreerr.E(op, "failed to create PGP keypair", err) + } + + // Store public key + if err := a.medium.Write(userPath(userID, ".pub"), kp.PublicKey); err != nil { + return nil, coreerr.E(op, "failed to write public key", err) + } + + // Store private key (already encrypted by PGP if password is non-empty) + if err := a.medium.Write(userPath(userID, ".key"), kp.PrivateKey); err != nil { + return nil, coreerr.E(op, "failed to write private key", err) + } + + // Store revocation certificate placeholder + if err := a.medium.Write(userPath(userID, ".rev"), "REVOCATION_PLACEHOLDER"); err != nil { + return nil, coreerr.E(op, "failed to write revocation certificate", err) + } + + // Store LTHN password hash + passwordHash := lthn.Hash(password) + if err := a.medium.Write(userPath(userID, ".lthn"), passwordHash); err != nil { + return nil, coreerr.E(op, "failed to write password hash", err) + } + + // Build user metadata + now := time.Now() + user := &User{ + PublicKey: kp.PublicKey, + KeyID: userID, + Fingerprint: lthn.Hash(kp.PublicKey), + PasswordHash: passwordHash, + Created: now, + LastLogin: time.Time{}, + } + + // Encrypt metadata with the user's public key and store + metaJSON, err := json.Marshal(user) + if err != nil { + return nil, coreerr.E(op, "failed to marshal user metadata", err) + } + + encMeta, err := pgp.Encrypt(metaJSON, kp.PublicKey) + if err != nil { + return nil, coreerr.E(op, "failed to encrypt user metadata", err) + } + + if err := a.medium.Write(userPath(userID, ".json"), string(encMeta)); err != nil { + return nil, coreerr.E(op, "failed to write user metadata", err) + } + + return user, nil +} + +// CreateChallenge generates a cryptographic challenge for the given user. +// A random nonce is created and encrypted with the user's PGP public key. +// The client must decrypt the nonce and sign it to prove key ownership. +func (a *Authenticator) CreateChallenge(userID string) (*Challenge, error) { + const op = "auth.CreateChallenge" + + // Read user's public key + pubKey, err := a.medium.Read(userPath(userID, ".pub")) + if err != nil { + return nil, coreerr.E(op, "user not found", err) + } + + // Generate random nonce + nonce := make([]byte, nonceBytes) + if _, err := rand.Read(nonce); err != nil { + return nil, coreerr.E(op, "failed to generate nonce", err) + } + + // Encrypt nonce with user's public key + encrypted, err := pgp.Encrypt(nonce, pubKey) + if err != nil { + return nil, coreerr.E(op, "failed to encrypt nonce", err) + } + + challenge := &Challenge{ + Nonce: nonce, + Encrypted: string(encrypted), + ExpiresAt: time.Now().Add(a.challengeTTL), + } + + a.mu.Lock() + a.challenges[userID] = challenge + a.mu.Unlock() + + return challenge, nil +} + +// ValidateResponse verifies a signed nonce from the client. The client must +// have decrypted the challenge nonce and signed it with their private key. +// On success, a new session is created and returned. +func (a *Authenticator) ValidateResponse(userID string, signedNonce []byte) (*Session, error) { + const op = "auth.ValidateResponse" + + a.mu.Lock() + challenge, exists := a.challenges[userID] + if exists { + delete(a.challenges, userID) + } + a.mu.Unlock() + + if !exists { + return nil, coreerr.E(op, "no pending challenge for user", nil) + } + + // Check challenge expiry + if time.Now().After(challenge.ExpiresAt) { + return nil, coreerr.E(op, "challenge expired", nil) + } + + // Read user's public key + pubKey, err := a.medium.Read(userPath(userID, ".pub")) + if err != nil { + return nil, coreerr.E(op, "user not found", err) + } + + // Verify signature over the original nonce + if err := pgp.Verify(challenge.Nonce, signedNonce, pubKey); err != nil { + return nil, coreerr.E(op, "signature verification failed", err) + } + + return a.createSession(userID) +} + +// ValidateSession checks whether a token maps to a valid, non-expired session. +func (a *Authenticator) ValidateSession(token string) (*Session, error) { + const op = "auth.ValidateSession" + + a.mu.RLock() + session, exists := a.sessions[token] + a.mu.RUnlock() + + if !exists { + return nil, coreerr.E(op, "session not found", nil) + } + + if time.Now().After(session.ExpiresAt) { + a.mu.Lock() + delete(a.sessions, token) + a.mu.Unlock() + return nil, coreerr.E(op, "session expired", nil) + } + + return session, nil +} + +// RefreshSession extends the expiry of an existing valid session. +func (a *Authenticator) RefreshSession(token string) (*Session, error) { + const op = "auth.RefreshSession" + + a.mu.Lock() + defer a.mu.Unlock() + + session, exists := a.sessions[token] + if !exists { + return nil, coreerr.E(op, "session not found", nil) + } + + if time.Now().After(session.ExpiresAt) { + delete(a.sessions, token) + return nil, coreerr.E(op, "session expired", nil) + } + + session.ExpiresAt = time.Now().Add(a.sessionTTL) + return session, nil +} + +// RevokeSession removes a session, invalidating the token immediately. +func (a *Authenticator) RevokeSession(token string) error { + const op = "auth.RevokeSession" + + a.mu.Lock() + defer a.mu.Unlock() + + if _, exists := a.sessions[token]; !exists { + return coreerr.E(op, "session not found", nil) + } + + delete(a.sessions, token) + return nil +} + +// DeleteUser removes a user and all associated keys from storage. +// The "server" user is protected and cannot be deleted (mirroring the +// original TypeScript implementation's safeguard). +func (a *Authenticator) DeleteUser(userID string) error { + const op = "auth.DeleteUser" + + // Protect special users + if protectedUsers[userID] { + return coreerr.E(op, "cannot delete protected user", nil) + } + + // Check user exists + if !a.medium.IsFile(userPath(userID, ".pub")) { + return coreerr.E(op, "user not found", nil) + } + + // Remove all artifacts + extensions := []string{".pub", ".key", ".rev", ".json", ".lthn"} + for _, ext := range extensions { + p := userPath(userID, ext) + if a.medium.IsFile(p) { + if err := a.medium.Delete(p); err != nil { + return coreerr.E(op, "failed to delete "+ext, err) + } + } + } + + // Revoke any active sessions for this user + a.mu.Lock() + for token, session := range a.sessions { + if session.UserID == userID { + delete(a.sessions, token) + } + } + a.mu.Unlock() + + return nil +} + +// Login performs password-based authentication as a convenience method. +// It verifies the password against the stored LTHN hash and, on success, +// creates a new session. This bypasses the PGP challenge-response flow. +func (a *Authenticator) Login(userID, password string) (*Session, error) { + const op = "auth.Login" + + // Read stored password hash + storedHash, err := a.medium.Read(userPath(userID, ".lthn")) + if err != nil { + return nil, coreerr.E(op, "user not found", err) + } + + // Verify password + if !lthn.Verify(password, storedHash) { + return nil, coreerr.E(op, "invalid password", nil) + } + + return a.createSession(userID) +} + +// WriteChallengeFile writes an encrypted challenge to a file for air-gapped +// (courier) transport. The challenge is created and then its encrypted nonce +// is written to the specified path on the Medium. +func (a *Authenticator) WriteChallengeFile(userID, path string) error { + const op = "auth.WriteChallengeFile" + + challenge, err := a.CreateChallenge(userID) + if err != nil { + return coreerr.E(op, "failed to create challenge", err) + } + + data, err := json.Marshal(challenge) + if err != nil { + return coreerr.E(op, "failed to marshal challenge", err) + } + + if err := a.medium.Write(path, string(data)); err != nil { + return coreerr.E(op, "failed to write challenge file", err) + } + + return nil +} + +// ReadResponseFile reads a signed response from a file and validates it, +// completing the air-gapped authentication flow. The file must contain the +// raw PGP signature bytes (armored). +func (a *Authenticator) ReadResponseFile(userID, path string) (*Session, error) { + const op = "auth.ReadResponseFile" + + content, err := a.medium.Read(path) + if err != nil { + return nil, coreerr.E(op, "failed to read response file", err) + } + + session, err := a.ValidateResponse(userID, []byte(content)) + if err != nil { + return nil, coreerr.E(op, "failed to validate response", err) + } + + return session, nil +} + +// createSession generates a cryptographically random session token and +// stores the session in the in-memory session map. +func (a *Authenticator) createSession(userID string) (*Session, error) { + tokenBytes := make([]byte, 32) + if _, err := rand.Read(tokenBytes); err != nil { + return nil, fmt.Errorf("auth: failed to generate session token: %w", err) + } + + session := &Session{ + Token: hex.EncodeToString(tokenBytes), + UserID: userID, + ExpiresAt: time.Now().Add(a.sessionTTL), + } + + a.mu.Lock() + a.sessions[session.Token] = session + a.mu.Unlock() + + return session, nil +} diff --git a/auth/auth_test.go b/auth/auth_test.go new file mode 100644 index 0000000..ff1b0f3 --- /dev/null +++ b/auth/auth_test.go @@ -0,0 +1,581 @@ +package auth + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "forge.lthn.ai/core/go-crypt/crypt/lthn" + "forge.lthn.ai/core/go-crypt/crypt/pgp" + "forge.lthn.ai/core/go/pkg/io" +) + +// helper creates a fresh Authenticator backed by MockMedium. +func newTestAuth(opts ...Option) (*Authenticator, *io.MockMedium) { + m := io.NewMockMedium() + a := New(m, opts...) + return a, m +} + +// --- Register --- + +func TestRegister_Good(t *testing.T) { + a, m := newTestAuth() + + user, err := a.Register("alice", "hunter2") + require.NoError(t, err) + require.NotNil(t, user) + + userID := lthn.Hash("alice") + + // Verify public key is stored + assert.True(t, m.IsFile(userPath(userID, ".pub"))) + assert.True(t, m.IsFile(userPath(userID, ".key"))) + assert.True(t, m.IsFile(userPath(userID, ".rev"))) + assert.True(t, m.IsFile(userPath(userID, ".json"))) + assert.True(t, m.IsFile(userPath(userID, ".lthn"))) + + // Verify user fields + assert.NotEmpty(t, user.PublicKey) + assert.Equal(t, userID, user.KeyID) + assert.NotEmpty(t, user.Fingerprint) + assert.Equal(t, lthn.Hash("hunter2"), user.PasswordHash) + assert.False(t, user.Created.IsZero()) +} + +func TestRegister_Bad(t *testing.T) { + a, _ := newTestAuth() + + // Register first time succeeds + _, err := a.Register("bob", "pass1") + require.NoError(t, err) + + // Duplicate registration should fail + _, err = a.Register("bob", "pass2") + assert.Error(t, err) + assert.Contains(t, err.Error(), "user already exists") +} + +func TestRegister_Ugly(t *testing.T) { + a, _ := newTestAuth() + + // Empty username/password should still work (PGP allows it) + user, err := a.Register("", "") + require.NoError(t, err) + require.NotNil(t, user) +} + +// --- CreateChallenge --- + +func TestCreateChallenge_Good(t *testing.T) { + a, _ := newTestAuth() + + user, err := a.Register("charlie", "pass") + require.NoError(t, err) + + challenge, err := a.CreateChallenge(user.KeyID) + require.NoError(t, err) + require.NotNil(t, challenge) + + assert.Len(t, challenge.Nonce, nonceBytes) + assert.NotEmpty(t, challenge.Encrypted) + assert.True(t, challenge.ExpiresAt.After(time.Now())) +} + +func TestCreateChallenge_Bad(t *testing.T) { + a, _ := newTestAuth() + + // Challenge for non-existent user + _, err := a.CreateChallenge("nonexistent-user-id") + assert.Error(t, err) + assert.Contains(t, err.Error(), "user not found") +} + +func TestCreateChallenge_Ugly(t *testing.T) { + a, _ := newTestAuth() + + // Empty userID + _, err := a.CreateChallenge("") + assert.Error(t, err) +} + +// --- ValidateResponse (full challenge-response flow) --- + +func TestValidateResponse_Good(t *testing.T) { + a, m := newTestAuth() + + // Register user + _, err := a.Register("dave", "password123") + require.NoError(t, err) + + userID := lthn.Hash("dave") + + // Create challenge + challenge, err := a.CreateChallenge(userID) + require.NoError(t, err) + + // Client-side: decrypt nonce, then sign it + privKey, err := m.Read(userPath(userID, ".key")) + require.NoError(t, err) + + decryptedNonce, err := pgp.Decrypt([]byte(challenge.Encrypted), privKey, "password123") + require.NoError(t, err) + assert.Equal(t, challenge.Nonce, decryptedNonce) + + signedNonce, err := pgp.Sign(decryptedNonce, privKey, "password123") + require.NoError(t, err) + + // Validate response + session, err := a.ValidateResponse(userID, signedNonce) + require.NoError(t, err) + require.NotNil(t, session) + + assert.NotEmpty(t, session.Token) + assert.Equal(t, userID, session.UserID) + assert.True(t, session.ExpiresAt.After(time.Now())) +} + +func TestValidateResponse_Bad(t *testing.T) { + a, _ := newTestAuth() + + _, err := a.Register("eve", "pass") + require.NoError(t, err) + userID := lthn.Hash("eve") + + // No pending challenge + _, err = a.ValidateResponse(userID, []byte("fake-signature")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no pending challenge") +} + +func TestValidateResponse_Ugly(t *testing.T) { + a, m := newTestAuth(WithChallengeTTL(1 * time.Millisecond)) + + _, err := a.Register("frank", "pass") + require.NoError(t, err) + userID := lthn.Hash("frank") + + // Create challenge and let it expire + challenge, err := a.CreateChallenge(userID) + require.NoError(t, err) + + time.Sleep(5 * time.Millisecond) + + // Sign with valid key but expired challenge + privKey, err := m.Read(userPath(userID, ".key")) + require.NoError(t, err) + + signedNonce, err := pgp.Sign(challenge.Nonce, privKey, "pass") + require.NoError(t, err) + + _, err = a.ValidateResponse(userID, signedNonce) + assert.Error(t, err) + assert.Contains(t, err.Error(), "challenge expired") +} + +// --- ValidateSession --- + +func TestValidateSession_Good(t *testing.T) { + a, _ := newTestAuth() + + _, err := a.Register("grace", "pass") + require.NoError(t, err) + userID := lthn.Hash("grace") + + session, err := a.Login(userID, "pass") + require.NoError(t, err) + + validated, err := a.ValidateSession(session.Token) + require.NoError(t, err) + assert.Equal(t, session.Token, validated.Token) + assert.Equal(t, userID, validated.UserID) +} + +func TestValidateSession_Bad(t *testing.T) { + a, _ := newTestAuth() + + _, err := a.ValidateSession("nonexistent-token") + assert.Error(t, err) + assert.Contains(t, err.Error(), "session not found") +} + +func TestValidateSession_Ugly(t *testing.T) { + a, _ := newTestAuth(WithSessionTTL(1 * time.Millisecond)) + + _, err := a.Register("heidi", "pass") + require.NoError(t, err) + userID := lthn.Hash("heidi") + + session, err := a.Login(userID, "pass") + require.NoError(t, err) + + time.Sleep(5 * time.Millisecond) + + _, err = a.ValidateSession(session.Token) + assert.Error(t, err) + assert.Contains(t, err.Error(), "session expired") +} + +// --- RefreshSession --- + +func TestRefreshSession_Good(t *testing.T) { + a, _ := newTestAuth(WithSessionTTL(1 * time.Hour)) + + _, err := a.Register("ivan", "pass") + require.NoError(t, err) + userID := lthn.Hash("ivan") + + session, err := a.Login(userID, "pass") + require.NoError(t, err) + + originalExpiry := session.ExpiresAt + + // Small delay to ensure time moves forward + time.Sleep(2 * time.Millisecond) + + refreshed, err := a.RefreshSession(session.Token) + require.NoError(t, err) + assert.True(t, refreshed.ExpiresAt.After(originalExpiry)) +} + +func TestRefreshSession_Bad(t *testing.T) { + a, _ := newTestAuth() + + _, err := a.RefreshSession("nonexistent-token") + assert.Error(t, err) + assert.Contains(t, err.Error(), "session not found") +} + +func TestRefreshSession_Ugly(t *testing.T) { + a, _ := newTestAuth(WithSessionTTL(1 * time.Millisecond)) + + _, err := a.Register("judy", "pass") + require.NoError(t, err) + userID := lthn.Hash("judy") + + session, err := a.Login(userID, "pass") + require.NoError(t, err) + + time.Sleep(5 * time.Millisecond) + + _, err = a.RefreshSession(session.Token) + assert.Error(t, err) + assert.Contains(t, err.Error(), "session expired") +} + +// --- RevokeSession --- + +func TestRevokeSession_Good(t *testing.T) { + a, _ := newTestAuth() + + _, err := a.Register("karl", "pass") + require.NoError(t, err) + userID := lthn.Hash("karl") + + session, err := a.Login(userID, "pass") + require.NoError(t, err) + + err = a.RevokeSession(session.Token) + require.NoError(t, err) + + // Token should no longer be valid + _, err = a.ValidateSession(session.Token) + assert.Error(t, err) +} + +func TestRevokeSession_Bad(t *testing.T) { + a, _ := newTestAuth() + + err := a.RevokeSession("nonexistent-token") + assert.Error(t, err) + assert.Contains(t, err.Error(), "session not found") +} + +func TestRevokeSession_Ugly(t *testing.T) { + a, _ := newTestAuth() + + // Revoke empty token + err := a.RevokeSession("") + assert.Error(t, err) +} + +// --- DeleteUser --- + +func TestDeleteUser_Good(t *testing.T) { + a, m := newTestAuth() + + _, err := a.Register("larry", "pass") + require.NoError(t, err) + userID := lthn.Hash("larry") + + // Also create a session that should be cleaned up + _, err = a.Login(userID, "pass") + require.NoError(t, err) + + err = a.DeleteUser(userID) + require.NoError(t, err) + + // All files should be gone + assert.False(t, m.IsFile(userPath(userID, ".pub"))) + assert.False(t, m.IsFile(userPath(userID, ".key"))) + assert.False(t, m.IsFile(userPath(userID, ".rev"))) + assert.False(t, m.IsFile(userPath(userID, ".json"))) + assert.False(t, m.IsFile(userPath(userID, ".lthn"))) + + // Session should be gone + a.mu.RLock() + sessionCount := 0 + for _, s := range a.sessions { + if s.UserID == userID { + sessionCount++ + } + } + a.mu.RUnlock() + assert.Equal(t, 0, sessionCount) +} + +func TestDeleteUser_Bad(t *testing.T) { + a, _ := newTestAuth() + + // Protected user "server" cannot be deleted + err := a.DeleteUser("server") + assert.Error(t, err) + assert.Contains(t, err.Error(), "cannot delete protected user") +} + +func TestDeleteUser_Ugly(t *testing.T) { + a, _ := newTestAuth() + + // Non-existent user + err := a.DeleteUser("nonexistent-user-id") + assert.Error(t, err) + assert.Contains(t, err.Error(), "user not found") +} + +// --- Login --- + +func TestLogin_Good(t *testing.T) { + a, _ := newTestAuth() + + _, err := a.Register("mallory", "secret") + require.NoError(t, err) + userID := lthn.Hash("mallory") + + session, err := a.Login(userID, "secret") + require.NoError(t, err) + require.NotNil(t, session) + + assert.NotEmpty(t, session.Token) + assert.Equal(t, userID, session.UserID) + assert.True(t, session.ExpiresAt.After(time.Now())) +} + +func TestLogin_Bad(t *testing.T) { + a, _ := newTestAuth() + + _, err := a.Register("nancy", "correct-password") + require.NoError(t, err) + userID := lthn.Hash("nancy") + + // Wrong password + _, err = a.Login(userID, "wrong-password") + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid password") +} + +func TestLogin_Ugly(t *testing.T) { + a, _ := newTestAuth() + + // Login for non-existent user + _, err := a.Login("nonexistent-user-id", "pass") + assert.Error(t, err) + assert.Contains(t, err.Error(), "user not found") +} + +// --- WriteChallengeFile / ReadResponseFile (Air-Gapped) --- + +func TestAirGappedFlow_Good(t *testing.T) { + a, m := newTestAuth() + + _, err := a.Register("oscar", "airgap-pass") + require.NoError(t, err) + userID := lthn.Hash("oscar") + + // Write challenge to file + challengePath := "transfer/challenge.json" + err = a.WriteChallengeFile(userID, challengePath) + require.NoError(t, err) + assert.True(t, m.IsFile(challengePath)) + + // Read challenge file to get the encrypted nonce (simulating courier) + challengeData, err := m.Read(challengePath) + require.NoError(t, err) + + var challenge Challenge + err = json.Unmarshal([]byte(challengeData), &challenge) + require.NoError(t, err) + + // Client-side: decrypt nonce and sign it + privKey, err := m.Read(userPath(userID, ".key")) + require.NoError(t, err) + + decryptedNonce, err := pgp.Decrypt([]byte(challenge.Encrypted), privKey, "airgap-pass") + require.NoError(t, err) + + signedNonce, err := pgp.Sign(decryptedNonce, privKey, "airgap-pass") + require.NoError(t, err) + + // Write signed response to file + responsePath := "transfer/response.sig" + err = m.Write(responsePath, string(signedNonce)) + require.NoError(t, err) + + // Server reads response file + session, err := a.ReadResponseFile(userID, responsePath) + require.NoError(t, err) + require.NotNil(t, session) + + assert.NotEmpty(t, session.Token) + assert.Equal(t, userID, session.UserID) +} + +func TestWriteChallengeFile_Bad(t *testing.T) { + a, _ := newTestAuth() + + // Challenge for non-existent user + err := a.WriteChallengeFile("nonexistent-user", "challenge.json") + assert.Error(t, err) +} + +func TestReadResponseFile_Bad(t *testing.T) { + a, _ := newTestAuth() + + // Response file does not exist + _, err := a.ReadResponseFile("some-user", "nonexistent-file.sig") + assert.Error(t, err) +} + +func TestReadResponseFile_Ugly(t *testing.T) { + a, m := newTestAuth() + + _, err := a.Register("peggy", "pass") + require.NoError(t, err) + userID := lthn.Hash("peggy") + + // Create a challenge + _, err = a.CreateChallenge(userID) + require.NoError(t, err) + + // Write garbage to response file + responsePath := "transfer/bad-response.sig" + err = m.Write(responsePath, "not-a-valid-signature") + require.NoError(t, err) + + _, err = a.ReadResponseFile(userID, responsePath) + assert.Error(t, err) +} + +// --- Options --- + +func TestWithChallengeTTL_Good(t *testing.T) { + ttl := 30 * time.Second + a, _ := newTestAuth(WithChallengeTTL(ttl)) + assert.Equal(t, ttl, a.challengeTTL) +} + +func TestWithSessionTTL_Good(t *testing.T) { + ttl := 2 * time.Hour + a, _ := newTestAuth(WithSessionTTL(ttl)) + assert.Equal(t, ttl, a.sessionTTL) +} + +// --- Full Round-Trip (Online Flow) --- + +func TestFullRoundTrip_Good(t *testing.T) { + a, m := newTestAuth() + + // 1. Register + user, err := a.Register("quinn", "roundtrip-pass") + require.NoError(t, err) + require.NotNil(t, user) + + userID := lthn.Hash("quinn") + + // 2. Create challenge + challenge, err := a.CreateChallenge(userID) + require.NoError(t, err) + + // 3. Client decrypts + signs + privKey, err := m.Read(userPath(userID, ".key")) + require.NoError(t, err) + + nonce, err := pgp.Decrypt([]byte(challenge.Encrypted), privKey, "roundtrip-pass") + require.NoError(t, err) + + sig, err := pgp.Sign(nonce, privKey, "roundtrip-pass") + require.NoError(t, err) + + // 4. Server validates, issues session + session, err := a.ValidateResponse(userID, sig) + require.NoError(t, err) + require.NotNil(t, session) + + // 5. Validate session + validated, err := a.ValidateSession(session.Token) + require.NoError(t, err) + assert.Equal(t, session.Token, validated.Token) + + // 6. Refresh session + refreshed, err := a.RefreshSession(session.Token) + require.NoError(t, err) + assert.Equal(t, session.Token, refreshed.Token) + + // 7. Revoke session + err = a.RevokeSession(session.Token) + require.NoError(t, err) + + // 8. Session should be invalid now + _, err = a.ValidateSession(session.Token) + assert.Error(t, err) +} + +// --- Concurrent Access --- + +func TestConcurrentSessions_Good(t *testing.T) { + a, _ := newTestAuth() + + _, err := a.Register("ruth", "pass") + require.NoError(t, err) + userID := lthn.Hash("ruth") + + // Create multiple sessions concurrently + const n = 10 + sessions := make(chan *Session, n) + errs := make(chan error, n) + + for i := 0; i < n; i++ { + go func() { + s, err := a.Login(userID, "pass") + if err != nil { + errs <- err + return + } + sessions <- s + }() + } + + for i := 0; i < n; i++ { + select { + case s := <-sessions: + require.NotNil(t, s) + // Validate each session + _, err := a.ValidateSession(s.Token) + assert.NoError(t, err) + case err := <-errs: + t.Fatalf("concurrent login failed: %v", err) + } + } +} diff --git a/crypt/chachapoly/chachapoly.go b/crypt/chachapoly/chachapoly.go new file mode 100644 index 0000000..2520c67 --- /dev/null +++ b/crypt/chachapoly/chachapoly.go @@ -0,0 +1,50 @@ +package chachapoly + +import ( + "crypto/rand" + "fmt" + "io" + + "golang.org/x/crypto/chacha20poly1305" +) + +// Encrypt encrypts data using ChaCha20-Poly1305. +func Encrypt(plaintext []byte, key []byte) ([]byte, error) { + aead, err := chacha20poly1305.NewX(key) + if err != nil { + return nil, err + } + + nonce := make([]byte, aead.NonceSize(), aead.NonceSize()+len(plaintext)+aead.Overhead()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + return aead.Seal(nonce, nonce, plaintext, nil), nil +} + +// Decrypt decrypts data using ChaCha20-Poly1305. +func Decrypt(ciphertext []byte, key []byte) ([]byte, error) { + aead, err := chacha20poly1305.NewX(key) + if err != nil { + return nil, err + } + + minLen := aead.NonceSize() + aead.Overhead() + if len(ciphertext) < minLen { + return nil, fmt.Errorf("ciphertext too short: got %d bytes, need at least %d bytes", len(ciphertext), minLen) + } + + nonce, ciphertext := ciphertext[:aead.NonceSize()], ciphertext[aead.NonceSize():] + + decrypted, err := aead.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, err + } + + if len(decrypted) == 0 { + return []byte{}, nil + } + + return decrypted, nil +} diff --git a/crypt/chachapoly/chachapoly_test.go b/crypt/chachapoly/chachapoly_test.go new file mode 100644 index 0000000..1123f2c --- /dev/null +++ b/crypt/chachapoly/chachapoly_test.go @@ -0,0 +1,114 @@ +package chachapoly + +import ( + "crypto/rand" + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +// mockReader is a reader that returns an error. +type mockReader struct{} + +func (r *mockReader) Read(p []byte) (n int, err error) { + return 0, errors.New("read error") +} + +func TestEncryptDecrypt(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = 1 + } + + plaintext := []byte("Hello, world!") + ciphertext, err := Encrypt(plaintext, key) + assert.NoError(t, err) + + decrypted, err := Decrypt(ciphertext, key) + assert.NoError(t, err) + + assert.Equal(t, plaintext, decrypted) +} + +func TestEncryptInvalidKeySize(t *testing.T) { + key := make([]byte, 16) // Wrong size + plaintext := []byte("test") + _, err := Encrypt(plaintext, key) + assert.Error(t, err) +} + +func TestDecryptWithWrongKey(t *testing.T) { + key1 := make([]byte, 32) + key2 := make([]byte, 32) + key2[0] = 1 // Different key + + plaintext := []byte("secret") + ciphertext, err := Encrypt(plaintext, key1) + assert.NoError(t, err) + + _, err = Decrypt(ciphertext, key2) + assert.Error(t, err) // Should fail authentication +} + +func TestDecryptTamperedCiphertext(t *testing.T) { + key := make([]byte, 32) + plaintext := []byte("secret") + ciphertext, err := Encrypt(plaintext, key) + assert.NoError(t, err) + + // Tamper with the ciphertext + ciphertext[0] ^= 0xff + + _, err = Decrypt(ciphertext, key) + assert.Error(t, err) +} + +func TestEncryptEmptyPlaintext(t *testing.T) { + key := make([]byte, 32) + plaintext := []byte("") + ciphertext, err := Encrypt(plaintext, key) + assert.NoError(t, err) + + decrypted, err := Decrypt(ciphertext, key) + assert.NoError(t, err) + + assert.Equal(t, plaintext, decrypted) +} + +func TestDecryptShortCiphertext(t *testing.T) { + key := make([]byte, 32) + shortCiphertext := []byte("short") + + _, err := Decrypt(shortCiphertext, key) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too short") +} + +func TestCiphertextDiffersFromPlaintext(t *testing.T) { + key := make([]byte, 32) + plaintext := []byte("Hello, world!") + ciphertext, err := Encrypt(plaintext, key) + assert.NoError(t, err) + assert.NotEqual(t, plaintext, ciphertext) +} + +func TestEncryptNonceError(t *testing.T) { + key := make([]byte, 32) + plaintext := []byte("test") + + // Replace the rand.Reader with our mock reader + oldReader := rand.Reader + rand.Reader = &mockReader{} + defer func() { rand.Reader = oldReader }() + + _, err := Encrypt(plaintext, key) + assert.Error(t, err) +} + +func TestDecryptInvalidKeySize(t *testing.T) { + key := make([]byte, 16) // Wrong size + ciphertext := []byte("test") + _, err := Decrypt(ciphertext, key) + assert.Error(t, err) +} diff --git a/crypt/checksum.go b/crypt/checksum.go new file mode 100644 index 0000000..ddf501f --- /dev/null +++ b/crypt/checksum.go @@ -0,0 +1,55 @@ +package crypt + +import ( + "crypto/sha256" + "crypto/sha512" + "encoding/hex" + "io" + "os" + + core "forge.lthn.ai/core/go/pkg/framework/core" +) + +// SHA256File computes the SHA-256 checksum of a file and returns it as a hex string. +func SHA256File(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", core.E("crypt.SHA256File", "failed to open file", err) + } + defer func() { _ = f.Close() }() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", core.E("crypt.SHA256File", "failed to read file", err) + } + + return hex.EncodeToString(h.Sum(nil)), nil +} + +// SHA512File computes the SHA-512 checksum of a file and returns it as a hex string. +func SHA512File(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", core.E("crypt.SHA512File", "failed to open file", err) + } + defer func() { _ = f.Close() }() + + h := sha512.New() + if _, err := io.Copy(h, f); err != nil { + return "", core.E("crypt.SHA512File", "failed to read file", err) + } + + return hex.EncodeToString(h.Sum(nil)), nil +} + +// SHA256Sum computes the SHA-256 checksum of data and returns it as a hex string. +func SHA256Sum(data []byte) string { + h := sha256.Sum256(data) + return hex.EncodeToString(h[:]) +} + +// SHA512Sum computes the SHA-512 checksum of data and returns it as a hex string. +func SHA512Sum(data []byte) string { + h := sha512.Sum512(data) + return hex.EncodeToString(h[:]) +} diff --git a/crypt/checksum_test.go b/crypt/checksum_test.go new file mode 100644 index 0000000..ce98b3b --- /dev/null +++ b/crypt/checksum_test.go @@ -0,0 +1,23 @@ +package crypt + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSHA256Sum_Good(t *testing.T) { + data := []byte("hello") + expected := "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824" + + result := SHA256Sum(data) + assert.Equal(t, expected, result) +} + +func TestSHA512Sum_Good(t *testing.T) { + data := []byte("hello") + expected := "9b71d224bd62f3785d96d46ad3ea3d73319bfbc2890caadae2dff72519673ca72323c3d99ba5c11d7c7acc6e14b8c5da0c4663475c2e5c3adef46f73bcdec043" + + result := SHA512Sum(data) + assert.Equal(t, expected, result) +} diff --git a/crypt/crypt.go b/crypt/crypt.go new file mode 100644 index 0000000..a73f0ad --- /dev/null +++ b/crypt/crypt.go @@ -0,0 +1,90 @@ +package crypt + +import ( + core "forge.lthn.ai/core/go/pkg/framework/core" +) + +// Encrypt encrypts data with a passphrase using ChaCha20-Poly1305. +// A random salt is generated and prepended to the output. +// Format: salt (16 bytes) + nonce (24 bytes) + ciphertext. +func Encrypt(plaintext, passphrase []byte) ([]byte, error) { + salt, err := generateSalt(argon2SaltLen) + if err != nil { + return nil, core.E("crypt.Encrypt", "failed to generate salt", err) + } + + key := DeriveKey(passphrase, salt, argon2KeyLen) + + encrypted, err := ChaCha20Encrypt(plaintext, key) + if err != nil { + return nil, core.E("crypt.Encrypt", "failed to encrypt", err) + } + + // Prepend salt to the encrypted data (which already has nonce prepended) + result := make([]byte, 0, len(salt)+len(encrypted)) + result = append(result, salt...) + result = append(result, encrypted...) + return result, nil +} + +// Decrypt decrypts data encrypted with Encrypt. +// Expects format: salt (16 bytes) + nonce (24 bytes) + ciphertext. +func Decrypt(ciphertext, passphrase []byte) ([]byte, error) { + if len(ciphertext) < argon2SaltLen { + return nil, core.E("crypt.Decrypt", "ciphertext too short", nil) + } + + salt := ciphertext[:argon2SaltLen] + encrypted := ciphertext[argon2SaltLen:] + + key := DeriveKey(passphrase, salt, argon2KeyLen) + + plaintext, err := ChaCha20Decrypt(encrypted, key) + if err != nil { + return nil, core.E("crypt.Decrypt", "failed to decrypt", err) + } + + return plaintext, nil +} + +// EncryptAES encrypts data using AES-256-GCM with a passphrase. +// A random salt is generated and prepended to the output. +// Format: salt (16 bytes) + nonce (12 bytes) + ciphertext. +func EncryptAES(plaintext, passphrase []byte) ([]byte, error) { + salt, err := generateSalt(argon2SaltLen) + if err != nil { + return nil, core.E("crypt.EncryptAES", "failed to generate salt", err) + } + + key := DeriveKey(passphrase, salt, argon2KeyLen) + + encrypted, err := AESGCMEncrypt(plaintext, key) + if err != nil { + return nil, core.E("crypt.EncryptAES", "failed to encrypt", err) + } + + result := make([]byte, 0, len(salt)+len(encrypted)) + result = append(result, salt...) + result = append(result, encrypted...) + return result, nil +} + +// DecryptAES decrypts data encrypted with EncryptAES. +// Expects format: salt (16 bytes) + nonce (12 bytes) + ciphertext. +func DecryptAES(ciphertext, passphrase []byte) ([]byte, error) { + if len(ciphertext) < argon2SaltLen { + return nil, core.E("crypt.DecryptAES", "ciphertext too short", nil) + } + + salt := ciphertext[:argon2SaltLen] + encrypted := ciphertext[argon2SaltLen:] + + key := DeriveKey(passphrase, salt, argon2KeyLen) + + plaintext, err := AESGCMDecrypt(encrypted, key) + if err != nil { + return nil, core.E("crypt.DecryptAES", "failed to decrypt", err) + } + + return plaintext, nil +} diff --git a/crypt/crypt_test.go b/crypt/crypt_test.go new file mode 100644 index 0000000..b2e7a56 --- /dev/null +++ b/crypt/crypt_test.go @@ -0,0 +1,45 @@ +package crypt + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEncryptDecrypt_Good(t *testing.T) { + plaintext := []byte("hello, world!") + passphrase := []byte("correct-horse-battery-staple") + + encrypted, err := Encrypt(plaintext, passphrase) + assert.NoError(t, err) + assert.NotEqual(t, plaintext, encrypted) + + decrypted, err := Decrypt(encrypted, passphrase) + assert.NoError(t, err) + assert.Equal(t, plaintext, decrypted) +} + +func TestEncryptDecrypt_Bad(t *testing.T) { + plaintext := []byte("secret data") + passphrase := []byte("correct-passphrase") + wrongPassphrase := []byte("wrong-passphrase") + + encrypted, err := Encrypt(plaintext, passphrase) + assert.NoError(t, err) + + _, err = Decrypt(encrypted, wrongPassphrase) + assert.Error(t, err) +} + +func TestEncryptDecryptAES_Good(t *testing.T) { + plaintext := []byte("hello, AES world!") + passphrase := []byte("my-secure-passphrase") + + encrypted, err := EncryptAES(plaintext, passphrase) + assert.NoError(t, err) + assert.NotEqual(t, plaintext, encrypted) + + decrypted, err := DecryptAES(encrypted, passphrase) + assert.NoError(t, err) + assert.Equal(t, plaintext, decrypted) +} diff --git a/crypt/hash.go b/crypt/hash.go new file mode 100644 index 0000000..9b1273d --- /dev/null +++ b/crypt/hash.go @@ -0,0 +1,89 @@ +package crypt + +import ( + "crypto/subtle" + "encoding/base64" + "fmt" + "strings" + + core "forge.lthn.ai/core/go/pkg/framework/core" + "golang.org/x/crypto/argon2" + "golang.org/x/crypto/bcrypt" +) + +// HashPassword hashes a password using Argon2id with default parameters. +// Returns a string in the format: $argon2id$v=19$m=65536,t=3,p=4$$ +func HashPassword(password string) (string, error) { + salt, err := generateSalt(argon2SaltLen) + if err != nil { + return "", core.E("crypt.HashPassword", "failed to generate salt", err) + } + + hash := argon2.IDKey([]byte(password), salt, argon2Time, argon2Memory, argon2Parallelism, argon2KeyLen) + + b64Salt := base64.RawStdEncoding.EncodeToString(salt) + b64Hash := base64.RawStdEncoding.EncodeToString(hash) + + encoded := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", + argon2.Version, argon2Memory, argon2Time, argon2Parallelism, + b64Salt, b64Hash) + + return encoded, nil +} + +// VerifyPassword verifies a password against an Argon2id hash string. +// The hash must be in the format produced by HashPassword. +func VerifyPassword(password, hash string) (bool, error) { + parts := strings.Split(hash, "$") + if len(parts) != 6 { + return false, core.E("crypt.VerifyPassword", "invalid hash format", nil) + } + + var version int + if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil { + return false, core.E("crypt.VerifyPassword", "failed to parse version", err) + } + + var memory uint32 + var time uint32 + var parallelism uint8 + if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, ¶llelism); err != nil { + return false, core.E("crypt.VerifyPassword", "failed to parse parameters", err) + } + + salt, err := base64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + return false, core.E("crypt.VerifyPassword", "failed to decode salt", err) + } + + expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + return false, core.E("crypt.VerifyPassword", "failed to decode hash", err) + } + + computedHash := argon2.IDKey([]byte(password), salt, time, memory, parallelism, uint32(len(expectedHash))) + + return subtle.ConstantTimeCompare(computedHash, expectedHash) == 1, nil +} + +// HashBcrypt hashes a password using bcrypt with the given cost. +// Cost must be between bcrypt.MinCost and bcrypt.MaxCost. +func HashBcrypt(password string, cost int) (string, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(password), cost) + if err != nil { + return "", core.E("crypt.HashBcrypt", "failed to hash password", err) + } + return string(hash), nil +} + +// VerifyBcrypt verifies a password against a bcrypt hash. +func VerifyBcrypt(password, hash string) (bool, error) { + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + if err == bcrypt.ErrMismatchedHashAndPassword { + return false, nil + } + if err != nil { + return false, core.E("crypt.VerifyBcrypt", "failed to verify password", err) + } + return true, nil +} diff --git a/crypt/hash_test.go b/crypt/hash_test.go new file mode 100644 index 0000000..ad308a0 --- /dev/null +++ b/crypt/hash_test.go @@ -0,0 +1,50 @@ +package crypt + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" +) + +func TestHashPassword_Good(t *testing.T) { + password := "my-secure-password" + + hash, err := HashPassword(password) + assert.NoError(t, err) + assert.NotEmpty(t, hash) + assert.Contains(t, hash, "$argon2id$") + + match, err := VerifyPassword(password, hash) + assert.NoError(t, err) + assert.True(t, match) +} + +func TestVerifyPassword_Bad(t *testing.T) { + password := "my-secure-password" + wrongPassword := "wrong-password" + + hash, err := HashPassword(password) + assert.NoError(t, err) + + match, err := VerifyPassword(wrongPassword, hash) + assert.NoError(t, err) + assert.False(t, match) +} + +func TestHashBcrypt_Good(t *testing.T) { + password := "bcrypt-test-password" + + hash, err := HashBcrypt(password, bcrypt.DefaultCost) + assert.NoError(t, err) + assert.NotEmpty(t, hash) + + match, err := VerifyBcrypt(password, hash) + assert.NoError(t, err) + assert.True(t, match) + + // Wrong password should not match + match, err = VerifyBcrypt("wrong-password", hash) + assert.NoError(t, err) + assert.False(t, match) +} diff --git a/crypt/hmac.go b/crypt/hmac.go new file mode 100644 index 0000000..adb80c2 --- /dev/null +++ b/crypt/hmac.go @@ -0,0 +1,30 @@ +package crypt + +import ( + "crypto/hmac" + "crypto/sha256" + "crypto/sha512" + "hash" +) + +// HMACSHA256 computes the HMAC-SHA256 of a message using the given key. +func HMACSHA256(message, key []byte) []byte { + mac := hmac.New(sha256.New, key) + mac.Write(message) + return mac.Sum(nil) +} + +// HMACSHA512 computes the HMAC-SHA512 of a message using the given key. +func HMACSHA512(message, key []byte) []byte { + mac := hmac.New(sha512.New, key) + mac.Write(message) + return mac.Sum(nil) +} + +// VerifyHMAC verifies an HMAC using constant-time comparison. +// hashFunc should be sha256.New, sha512.New, etc. +func VerifyHMAC(message, key, mac []byte, hashFunc func() hash.Hash) bool { + expected := hmac.New(hashFunc, key) + expected.Write(message) + return hmac.Equal(mac, expected.Sum(nil)) +} diff --git a/crypt/hmac_test.go b/crypt/hmac_test.go new file mode 100644 index 0000000..31dc474 --- /dev/null +++ b/crypt/hmac_test.go @@ -0,0 +1,40 @@ +package crypt + +import ( + "crypto/sha256" + "encoding/hex" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHMACSHA256_Good(t *testing.T) { + // RFC 4231 Test Case 2 + key := []byte("Jefe") + message := []byte("what do ya want for nothing?") + expected := "5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843" + + mac := HMACSHA256(message, key) + assert.Equal(t, expected, hex.EncodeToString(mac)) +} + +func TestVerifyHMAC_Good(t *testing.T) { + key := []byte("secret-key") + message := []byte("test message") + + mac := HMACSHA256(message, key) + + valid := VerifyHMAC(message, key, mac, sha256.New) + assert.True(t, valid) +} + +func TestVerifyHMAC_Bad(t *testing.T) { + key := []byte("secret-key") + message := []byte("test message") + tampered := []byte("tampered message") + + mac := HMACSHA256(message, key) + + valid := VerifyHMAC(tampered, key, mac, sha256.New) + assert.False(t, valid) +} diff --git a/crypt/kdf.go b/crypt/kdf.go new file mode 100644 index 0000000..71fdff4 --- /dev/null +++ b/crypt/kdf.go @@ -0,0 +1,60 @@ +// Package crypt provides cryptographic utilities including encryption, +// hashing, key derivation, HMAC, and checksum functions. +package crypt + +import ( + "crypto/rand" + "crypto/sha256" + "io" + + core "forge.lthn.ai/core/go/pkg/framework/core" + "golang.org/x/crypto/argon2" + "golang.org/x/crypto/hkdf" + "golang.org/x/crypto/scrypt" +) + +// Argon2id default parameters. +const ( + argon2Memory = 64 * 1024 // 64 MB + argon2Time = 3 + argon2Parallelism = 4 + argon2KeyLen = 32 + argon2SaltLen = 16 +) + +// DeriveKey derives a key from a passphrase using Argon2id with default parameters. +// The salt must be argon2SaltLen bytes. keyLen specifies the desired key length. +func DeriveKey(passphrase, salt []byte, keyLen uint32) []byte { + return argon2.IDKey(passphrase, salt, argon2Time, argon2Memory, argon2Parallelism, keyLen) +} + +// DeriveKeyScrypt derives a key from a passphrase using scrypt. +// Uses recommended parameters: N=32768, r=8, p=1. +func DeriveKeyScrypt(passphrase, salt []byte, keyLen int) ([]byte, error) { + key, err := scrypt.Key(passphrase, salt, 32768, 8, 1, keyLen) + if err != nil { + return nil, core.E("crypt.DeriveKeyScrypt", "failed to derive key", err) + } + return key, nil +} + +// HKDF derives a key using HKDF-SHA256. +// secret is the input keying material, salt is optional (can be nil), +// info is optional context, and keyLen is the desired output length. +func HKDF(secret, salt, info []byte, keyLen int) ([]byte, error) { + reader := hkdf.New(sha256.New, secret, salt, info) + key := make([]byte, keyLen) + if _, err := io.ReadFull(reader, key); err != nil { + return nil, core.E("crypt.HKDF", "failed to derive key", err) + } + return key, nil +} + +// generateSalt creates a random salt of the given length. +func generateSalt(length int) ([]byte, error) { + salt := make([]byte, length) + if _, err := rand.Read(salt); err != nil { + return nil, core.E("crypt.generateSalt", "failed to generate random salt", err) + } + return salt, nil +} diff --git a/crypt/kdf_test.go b/crypt/kdf_test.go new file mode 100644 index 0000000..08ee76d --- /dev/null +++ b/crypt/kdf_test.go @@ -0,0 +1,56 @@ +package crypt + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDeriveKey_Good(t *testing.T) { + passphrase := []byte("test-passphrase") + salt := []byte("1234567890123456") // 16 bytes + + key1 := DeriveKey(passphrase, salt, 32) + key2 := DeriveKey(passphrase, salt, 32) + + assert.Len(t, key1, 32) + assert.Equal(t, key1, key2, "same inputs should produce same output") + + // Different passphrase should produce different key + key3 := DeriveKey([]byte("different-passphrase"), salt, 32) + assert.NotEqual(t, key1, key3) +} + +func TestDeriveKeyScrypt_Good(t *testing.T) { + passphrase := []byte("test-passphrase") + salt := []byte("1234567890123456") + + key, err := DeriveKeyScrypt(passphrase, salt, 32) + assert.NoError(t, err) + assert.Len(t, key, 32) + + // Deterministic + key2, err := DeriveKeyScrypt(passphrase, salt, 32) + assert.NoError(t, err) + assert.Equal(t, key, key2) +} + +func TestHKDF_Good(t *testing.T) { + secret := []byte("input-keying-material") + salt := []byte("optional-salt") + info := []byte("context-info") + + key1, err := HKDF(secret, salt, info, 32) + assert.NoError(t, err) + assert.Len(t, key1, 32) + + // Deterministic + key2, err := HKDF(secret, salt, info, 32) + assert.NoError(t, err) + assert.Equal(t, key1, key2) + + // Different info should produce different key + key3, err := HKDF(secret, salt, []byte("different-info"), 32) + assert.NoError(t, err) + assert.NotEqual(t, key1, key3) +} diff --git a/crypt/lthn/lthn.go b/crypt/lthn/lthn.go new file mode 100644 index 0000000..a9c04ef --- /dev/null +++ b/crypt/lthn/lthn.go @@ -0,0 +1,94 @@ +// Package lthn implements the LTHN quasi-salted hash algorithm (RFC-0004). +// +// LTHN produces deterministic, verifiable hashes without requiring separate salt +// storage. The salt is derived from the input itself through: +// 1. Reversing the input string +// 2. Applying "leet speak" style character substitutions +// +// The final hash is: SHA256(input || derived_salt) +// +// This is suitable for content identifiers, cache keys, and deduplication. +// NOT suitable for password hashing - use bcrypt, Argon2, or scrypt instead. +// +// Example: +// +// hash := lthn.Hash("hello") +// valid := lthn.Verify("hello", hash) // true +package lthn + +import ( + "crypto/sha256" + "encoding/hex" +) + +// keyMap defines the character substitutions for quasi-salt derivation. +// These are inspired by "leet speak" conventions for letter-number substitution. +// The mapping is bidirectional for most characters but NOT fully symmetric. +var keyMap = map[rune]rune{ + 'o': '0', // letter O -> zero + 'l': '1', // letter L -> one + 'e': '3', // letter E -> three + 'a': '4', // letter A -> four + 's': 'z', // letter S -> Z + 't': '7', // letter T -> seven + '0': 'o', // zero -> letter O + '1': 'l', // one -> letter L + '3': 'e', // three -> letter E + '4': 'a', // four -> letter A + '7': 't', // seven -> letter T +} + +// SetKeyMap replaces the default character substitution map. +// Use this to customize the quasi-salt derivation for specific applications. +// Changes affect all subsequent Hash and Verify calls. +func SetKeyMap(newKeyMap map[rune]rune) { + keyMap = newKeyMap +} + +// GetKeyMap returns the current character substitution map. +func GetKeyMap() map[rune]rune { + return keyMap +} + +// Hash computes the LTHN hash of the input string. +// +// The algorithm: +// 1. Derive a quasi-salt by reversing the input and applying character substitutions +// 2. Concatenate: input + salt +// 3. Compute SHA-256 of the concatenated string +// 4. Return the hex-encoded digest (64 characters, lowercase) +// +// The same input always produces the same hash, enabling verification +// without storing a separate salt value. +func Hash(input string) string { + salt := createSalt(input) + hash := sha256.Sum256([]byte(input + salt)) + return hex.EncodeToString(hash[:]) +} + +// createSalt derives a quasi-salt by reversing the input and applying substitutions. +// For example: "hello" -> reversed "olleh" -> substituted "011eh" +func createSalt(input string) string { + if input == "" { + return "" + } + runes := []rune(input) + salt := make([]rune, len(runes)) + for i := 0; i < len(runes); i++ { + char := runes[len(runes)-1-i] + if replacement, ok := keyMap[char]; ok { + salt[i] = replacement + } else { + salt[i] = char + } + } + return string(salt) +} + +// Verify checks if an input string produces the given hash. +// Returns true if Hash(input) equals the provided hash value. +// Uses direct string comparison - for security-critical applications, +// consider using constant-time comparison. +func Verify(input string, hash string) bool { + return Hash(input) == hash +} diff --git a/crypt/lthn/lthn_test.go b/crypt/lthn/lthn_test.go new file mode 100644 index 0000000..da0d655 --- /dev/null +++ b/crypt/lthn/lthn_test.go @@ -0,0 +1,66 @@ +package lthn + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHash(t *testing.T) { + hash := Hash("hello") + assert.NotEmpty(t, hash) +} + +func TestVerify(t *testing.T) { + hash := Hash("hello") + assert.True(t, Verify("hello", hash)) + assert.False(t, Verify("world", hash)) +} + +func TestCreateSalt_Good(t *testing.T) { + // "hello" reversed: "olleh" -> "0113h" + expected := "0113h" + actual := createSalt("hello") + assert.Equal(t, expected, actual, "Salt should be correctly created for 'hello'") +} + +func TestCreateSalt_Bad(t *testing.T) { + // Test with an empty string + expected := "" + actual := createSalt("") + assert.Equal(t, expected, actual, "Salt for an empty string should be empty") +} + +func TestCreateSalt_Ugly(t *testing.T) { + // Test with characters not in the keyMap + input := "world123" + // "world123" reversed: "321dlrow" -> "e2ld1r0w" + expected := "e2ld1r0w" + actual := createSalt(input) + assert.Equal(t, expected, actual, "Salt should handle characters not in the keyMap") + + // Test with only characters in the keyMap + input = "oleta" + // "oleta" reversed: "atelo" -> "47310" + expected = "47310" + actual = createSalt(input) + assert.Equal(t, expected, actual, "Salt should correctly handle strings with only keyMap characters") +} + +var testKeyMapMu sync.Mutex + +func TestSetKeyMap(t *testing.T) { + testKeyMapMu.Lock() + originalKeyMap := GetKeyMap() + t.Cleanup(func() { + SetKeyMap(originalKeyMap) + testKeyMapMu.Unlock() + }) + + newKeyMap := map[rune]rune{ + 'a': 'b', + } + SetKeyMap(newKeyMap) + assert.Equal(t, newKeyMap, GetKeyMap()) +} diff --git a/crypt/openpgp/service.go b/crypt/openpgp/service.go new file mode 100644 index 0000000..5064ea9 --- /dev/null +++ b/crypt/openpgp/service.go @@ -0,0 +1,191 @@ +package openpgp + +import ( + "bytes" + "crypto" + goio "io" + "strings" + + "github.com/ProtonMail/go-crypto/openpgp" + "github.com/ProtonMail/go-crypto/openpgp/armor" + "github.com/ProtonMail/go-crypto/openpgp/packet" + core "forge.lthn.ai/core/go/pkg/framework/core" +) + +// Service implements the core.Crypt interface using OpenPGP. +type Service struct { + core *core.Core +} + +// New creates a new OpenPGP service instance. +func New(c *core.Core) (any, error) { + return &Service{core: c}, nil +} + +// CreateKeyPair generates a new RSA-4096 PGP keypair. +// Returns the armored private key string. +func (s *Service) CreateKeyPair(name, passphrase string) (string, error) { + config := &packet.Config{ + Algorithm: packet.PubKeyAlgoRSA, + RSABits: 4096, + DefaultHash: crypto.SHA256, + DefaultCipher: packet.CipherAES256, + } + + entity, err := openpgp.NewEntity(name, "Workspace Key", "", config) + if err != nil { + return "", core.E("openpgp.CreateKeyPair", "failed to create entity", err) + } + + // Encrypt private key if passphrase is provided + if passphrase != "" { + err = entity.PrivateKey.Encrypt([]byte(passphrase)) + if err != nil { + return "", core.E("openpgp.CreateKeyPair", "failed to encrypt private key", err) + } + for _, subkey := range entity.Subkeys { + err = subkey.PrivateKey.Encrypt([]byte(passphrase)) + if err != nil { + return "", core.E("openpgp.CreateKeyPair", "failed to encrypt subkey", err) + } + } + } + + var buf bytes.Buffer + w, err := armor.Encode(&buf, openpgp.PrivateKeyType, nil) + if err != nil { + return "", core.E("openpgp.CreateKeyPair", "failed to create armor encoder", err) + } + + // Manual serialization to avoid panic from re-signing encrypted keys + err = s.serializeEntity(w, entity) + if err != nil { + w.Close() + return "", core.E("openpgp.CreateKeyPair", "failed to serialize private key", err) + } + w.Close() + + return buf.String(), nil +} + +// serializeEntity manually serializes an OpenPGP entity to avoid re-signing. +func (s *Service) serializeEntity(w goio.Writer, e *openpgp.Entity) error { + err := e.PrivateKey.Serialize(w) + if err != nil { + return err + } + for _, ident := range e.Identities { + err = ident.UserId.Serialize(w) + if err != nil { + return err + } + err = ident.SelfSignature.Serialize(w) + if err != nil { + return err + } + } + for _, subkey := range e.Subkeys { + err = subkey.PrivateKey.Serialize(w) + if err != nil { + return err + } + err = subkey.Sig.Serialize(w) + if err != nil { + return err + } + } + return nil +} + +// EncryptPGP encrypts data for a recipient identified by their public key (armored string in recipientPath). +// The encrypted data is written to the provided writer and also returned as an armored string. +func (s *Service) EncryptPGP(writer goio.Writer, recipientPath, data string, opts ...any) (string, error) { + entityList, err := openpgp.ReadArmoredKeyRing(strings.NewReader(recipientPath)) + if err != nil { + return "", core.E("openpgp.EncryptPGP", "failed to read recipient key", err) + } + + var armoredBuf bytes.Buffer + armoredWriter, err := armor.Encode(&armoredBuf, "PGP MESSAGE", nil) + if err != nil { + return "", core.E("openpgp.EncryptPGP", "failed to create armor encoder", err) + } + + // MultiWriter to write to both the provided writer and our armored buffer + mw := goio.MultiWriter(writer, armoredWriter) + + w, err := openpgp.Encrypt(mw, entityList, nil, nil, nil) + if err != nil { + armoredWriter.Close() + return "", core.E("openpgp.EncryptPGP", "failed to start encryption", err) + } + + _, err = goio.WriteString(w, data) + if err != nil { + w.Close() + armoredWriter.Close() + return "", core.E("openpgp.EncryptPGP", "failed to write data", err) + } + + w.Close() + armoredWriter.Close() + + return armoredBuf.String(), nil +} + +// DecryptPGP decrypts a PGP message using the provided armored private key and passphrase. +func (s *Service) DecryptPGP(privateKey, message, passphrase string, opts ...any) (string, error) { + entityList, err := openpgp.ReadArmoredKeyRing(strings.NewReader(privateKey)) + if err != nil { + return "", core.E("openpgp.DecryptPGP", "failed to read private key", err) + } + + entity := entityList[0] + if entity.PrivateKey.Encrypted { + err = entity.PrivateKey.Decrypt([]byte(passphrase)) + if err != nil { + return "", core.E("openpgp.DecryptPGP", "failed to decrypt private key", err) + } + for _, subkey := range entity.Subkeys { + _ = subkey.PrivateKey.Decrypt([]byte(passphrase)) + } + } + + // Decrypt armored message + block, err := armor.Decode(strings.NewReader(message)) + if err != nil { + return "", core.E("openpgp.DecryptPGP", "failed to decode armored message", err) + } + + md, err := openpgp.ReadMessage(block.Body, entityList, nil, nil) + if err != nil { + return "", core.E("openpgp.DecryptPGP", "failed to read message", err) + } + + var buf bytes.Buffer + _, err = goio.Copy(&buf, md.UnverifiedBody) + if err != nil { + return "", core.E("openpgp.DecryptPGP", "failed to read decrypted body", err) + } + + return buf.String(), nil +} + +// HandleIPCEvents handles PGP-related IPC messages. +func (s *Service) HandleIPCEvents(c *core.Core, msg core.Message) error { + switch m := msg.(type) { + case map[string]any: + action, _ := m["action"].(string) + switch action { + case "openpgp.create_key_pair": + name, _ := m["name"].(string) + passphrase, _ := m["passphrase"].(string) + _, err := s.CreateKeyPair(name, passphrase) + return err + } + } + return nil +} + +// Ensure Service implements core.Crypt. +var _ core.Crypt = (*Service)(nil) diff --git a/crypt/openpgp/service_test.go b/crypt/openpgp/service_test.go new file mode 100644 index 0000000..b74c334 --- /dev/null +++ b/crypt/openpgp/service_test.go @@ -0,0 +1,43 @@ +package openpgp + +import ( + "bytes" + "testing" + + core "forge.lthn.ai/core/go/pkg/framework/core" + "github.com/stretchr/testify/assert" +) + +func TestCreateKeyPair(t *testing.T) { + c, _ := core.New() + s := &Service{core: c} + + privKey, err := s.CreateKeyPair("test user", "password123") + assert.NoError(t, err) + assert.NotEmpty(t, privKey) + assert.Contains(t, privKey, "-----BEGIN PGP PRIVATE KEY BLOCK-----") +} + +func TestEncryptDecrypt(t *testing.T) { + c, _ := core.New() + s := &Service{core: c} + + passphrase := "secret" + privKey, err := s.CreateKeyPair("test user", passphrase) + assert.NoError(t, err) + + // In this simple test, the public key is also in the armored private key string + // (openpgp.ReadArmoredKeyRing reads both) + publicKey := privKey + + data := "hello openpgp" + var buf bytes.Buffer + armored, err := s.EncryptPGP(&buf, publicKey, data) + assert.NoError(t, err) + assert.NotEmpty(t, armored) + assert.NotEmpty(t, buf.String()) + + decrypted, err := s.DecryptPGP(privKey, armored, passphrase) + assert.NoError(t, err) + assert.Equal(t, data, decrypted) +} diff --git a/crypt/pgp/pgp.go b/crypt/pgp/pgp.go new file mode 100644 index 0000000..d5c93b9 --- /dev/null +++ b/crypt/pgp/pgp.go @@ -0,0 +1,230 @@ +// Package pgp provides OpenPGP key generation, encryption, decryption, +// signing, and verification using the ProtonMail go-crypto library. +// +// Ported from Enchantrix (github.com/Snider/Enchantrix/pkg/crypt/std/pgp). +package pgp + +import ( + "bytes" + "fmt" + "io" + + "github.com/ProtonMail/go-crypto/openpgp" + "github.com/ProtonMail/go-crypto/openpgp/armor" + "github.com/ProtonMail/go-crypto/openpgp/packet" +) + +// KeyPair holds armored PGP public and private keys. +type KeyPair struct { + PublicKey string + PrivateKey string +} + +// CreateKeyPair generates a new PGP key pair for the given identity. +// If password is non-empty, the private key is encrypted with it. +// Returns a KeyPair with armored public and private keys. +func CreateKeyPair(name, email, password string) (*KeyPair, error) { + entity, err := openpgp.NewEntity(name, "", email, nil) + if err != nil { + return nil, fmt.Errorf("pgp: failed to create entity: %w", err) + } + + // Sign all the identities + for _, id := range entity.Identities { + _ = id.SelfSignature.SignUserId(id.UserId.Id, entity.PrimaryKey, entity.PrivateKey, nil) + } + + // Encrypt private key with password if provided + if password != "" { + err = entity.PrivateKey.Encrypt([]byte(password)) + if err != nil { + return nil, fmt.Errorf("pgp: failed to encrypt private key: %w", err) + } + for _, subkey := range entity.Subkeys { + err = subkey.PrivateKey.Encrypt([]byte(password)) + if err != nil { + return nil, fmt.Errorf("pgp: failed to encrypt subkey: %w", err) + } + } + } + + // Serialize public key + pubKeyBuf := new(bytes.Buffer) + pubKeyWriter, err := armor.Encode(pubKeyBuf, openpgp.PublicKeyType, nil) + if err != nil { + return nil, fmt.Errorf("pgp: failed to create armored public key writer: %w", err) + } + if err := entity.Serialize(pubKeyWriter); err != nil { + pubKeyWriter.Close() + return nil, fmt.Errorf("pgp: failed to serialize public key: %w", err) + } + pubKeyWriter.Close() + + // Serialize private key + privKeyBuf := new(bytes.Buffer) + privKeyWriter, err := armor.Encode(privKeyBuf, openpgp.PrivateKeyType, nil) + if err != nil { + return nil, fmt.Errorf("pgp: failed to create armored private key writer: %w", err) + } + if password != "" { + // Manual serialization to avoid re-signing encrypted keys + if err := serializeEncryptedEntity(privKeyWriter, entity); err != nil { + privKeyWriter.Close() + return nil, fmt.Errorf("pgp: failed to serialize private key: %w", err) + } + } else { + if err := entity.SerializePrivate(privKeyWriter, nil); err != nil { + privKeyWriter.Close() + return nil, fmt.Errorf("pgp: failed to serialize private key: %w", err) + } + } + privKeyWriter.Close() + + return &KeyPair{ + PublicKey: pubKeyBuf.String(), + PrivateKey: privKeyBuf.String(), + }, nil +} + +// serializeEncryptedEntity manually serializes an entity with encrypted private keys +// to avoid the panic from re-signing encrypted keys. +func serializeEncryptedEntity(w io.Writer, e *openpgp.Entity) error { + if err := e.PrivateKey.Serialize(w); err != nil { + return err + } + for _, ident := range e.Identities { + if err := ident.UserId.Serialize(w); err != nil { + return err + } + if err := ident.SelfSignature.Serialize(w); err != nil { + return err + } + } + for _, subkey := range e.Subkeys { + if err := subkey.PrivateKey.Serialize(w); err != nil { + return err + } + if err := subkey.Sig.Serialize(w); err != nil { + return err + } + } + return nil +} + +// Encrypt encrypts data for the recipient identified by their armored public key. +// Returns the encrypted data as armored PGP output. +func Encrypt(data []byte, publicKeyArmor string) ([]byte, error) { + keyring, err := openpgp.ReadArmoredKeyRing(bytes.NewReader([]byte(publicKeyArmor))) + if err != nil { + return nil, fmt.Errorf("pgp: failed to read public key ring: %w", err) + } + + buf := new(bytes.Buffer) + armoredWriter, err := armor.Encode(buf, "PGP MESSAGE", nil) + if err != nil { + return nil, fmt.Errorf("pgp: failed to create armor encoder: %w", err) + } + + w, err := openpgp.Encrypt(armoredWriter, keyring, nil, nil, nil) + if err != nil { + armoredWriter.Close() + return nil, fmt.Errorf("pgp: failed to create encryption writer: %w", err) + } + + if _, err := w.Write(data); err != nil { + w.Close() + armoredWriter.Close() + return nil, fmt.Errorf("pgp: failed to write data: %w", err) + } + w.Close() + armoredWriter.Close() + + return buf.Bytes(), nil +} + +// Decrypt decrypts armored PGP data using the given armored private key. +// If the private key is encrypted, the password is used to decrypt it first. +func Decrypt(data []byte, privateKeyArmor, password string) ([]byte, error) { + keyring, err := openpgp.ReadArmoredKeyRing(bytes.NewReader([]byte(privateKeyArmor))) + if err != nil { + return nil, fmt.Errorf("pgp: failed to read private key ring: %w", err) + } + + // Decrypt the private key if it is encrypted + for _, entity := range keyring { + if entity.PrivateKey != nil && entity.PrivateKey.Encrypted { + if err := entity.PrivateKey.Decrypt([]byte(password)); err != nil { + return nil, fmt.Errorf("pgp: failed to decrypt private key: %w", err) + } + } + for _, subkey := range entity.Subkeys { + if subkey.PrivateKey != nil && subkey.PrivateKey.Encrypted { + _ = subkey.PrivateKey.Decrypt([]byte(password)) + } + } + } + + // Decode armored message + block, err := armor.Decode(bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("pgp: failed to decode armored message: %w", err) + } + + md, err := openpgp.ReadMessage(block.Body, keyring, nil, nil) + if err != nil { + return nil, fmt.Errorf("pgp: failed to read message: %w", err) + } + + plaintext, err := io.ReadAll(md.UnverifiedBody) + if err != nil { + return nil, fmt.Errorf("pgp: failed to read plaintext: %w", err) + } + + return plaintext, nil +} + +// Sign creates an armored detached signature for the given data using +// the armored private key. If the key is encrypted, the password is used +// to decrypt it first. +func Sign(data []byte, privateKeyArmor, password string) ([]byte, error) { + keyring, err := openpgp.ReadArmoredKeyRing(bytes.NewReader([]byte(privateKeyArmor))) + if err != nil { + return nil, fmt.Errorf("pgp: failed to read private key ring: %w", err) + } + + signer := keyring[0] + if signer.PrivateKey == nil { + return nil, fmt.Errorf("pgp: private key not found in keyring") + } + + if signer.PrivateKey.Encrypted { + if err := signer.PrivateKey.Decrypt([]byte(password)); err != nil { + return nil, fmt.Errorf("pgp: failed to decrypt private key: %w", err) + } + } + + buf := new(bytes.Buffer) + config := &packet.Config{} + err = openpgp.ArmoredDetachSign(buf, signer, bytes.NewReader(data), config) + if err != nil { + return nil, fmt.Errorf("pgp: failed to sign message: %w", err) + } + + return buf.Bytes(), nil +} + +// Verify verifies an armored detached signature against the given data +// and armored public key. Returns nil if the signature is valid. +func Verify(data, signature []byte, publicKeyArmor string) error { + keyring, err := openpgp.ReadArmoredKeyRing(bytes.NewReader([]byte(publicKeyArmor))) + if err != nil { + return fmt.Errorf("pgp: failed to read public key ring: %w", err) + } + + _, err = openpgp.CheckArmoredDetachedSignature(keyring, bytes.NewReader(data), bytes.NewReader(signature), nil) + if err != nil { + return fmt.Errorf("pgp: signature verification failed: %w", err) + } + + return nil +} diff --git a/crypt/pgp/pgp_test.go b/crypt/pgp/pgp_test.go new file mode 100644 index 0000000..4f7edd9 --- /dev/null +++ b/crypt/pgp/pgp_test.go @@ -0,0 +1,164 @@ +package pgp + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateKeyPair_Good(t *testing.T) { + kp, err := CreateKeyPair("Test User", "test@example.com", "") + require.NoError(t, err) + require.NotNil(t, kp) + assert.Contains(t, kp.PublicKey, "-----BEGIN PGP PUBLIC KEY BLOCK-----") + assert.Contains(t, kp.PrivateKey, "-----BEGIN PGP PRIVATE KEY BLOCK-----") +} + +func TestCreateKeyPair_Bad(t *testing.T) { + // Empty name still works (openpgp allows it), but test with password + kp, err := CreateKeyPair("Secure User", "secure@example.com", "strong-password") + require.NoError(t, err) + require.NotNil(t, kp) + assert.Contains(t, kp.PublicKey, "-----BEGIN PGP PUBLIC KEY BLOCK-----") + assert.Contains(t, kp.PrivateKey, "-----BEGIN PGP PRIVATE KEY BLOCK-----") +} + +func TestCreateKeyPair_Ugly(t *testing.T) { + // Minimal identity + kp, err := CreateKeyPair("", "", "") + require.NoError(t, err) + require.NotNil(t, kp) +} + +func TestEncryptDecrypt_Good(t *testing.T) { + kp, err := CreateKeyPair("Test User", "test@example.com", "") + require.NoError(t, err) + + plaintext := []byte("hello, OpenPGP!") + ciphertext, err := Encrypt(plaintext, kp.PublicKey) + require.NoError(t, err) + assert.NotEmpty(t, ciphertext) + assert.Contains(t, string(ciphertext), "-----BEGIN PGP MESSAGE-----") + + decrypted, err := Decrypt(ciphertext, kp.PrivateKey, "") + require.NoError(t, err) + assert.Equal(t, plaintext, decrypted) +} + +func TestEncryptDecrypt_Bad(t *testing.T) { + kp1, err := CreateKeyPair("User One", "one@example.com", "") + require.NoError(t, err) + kp2, err := CreateKeyPair("User Two", "two@example.com", "") + require.NoError(t, err) + + plaintext := []byte("secret data") + ciphertext, err := Encrypt(plaintext, kp1.PublicKey) + require.NoError(t, err) + + // Decrypting with wrong key should fail + _, err = Decrypt(ciphertext, kp2.PrivateKey, "") + assert.Error(t, err) +} + +func TestEncryptDecrypt_Ugly(t *testing.T) { + // Invalid public key for encryption + _, err := Encrypt([]byte("data"), "not-a-pgp-key") + assert.Error(t, err) + + // Invalid private key for decryption + _, err = Decrypt([]byte("data"), "not-a-pgp-key", "") + assert.Error(t, err) +} + +func TestEncryptDecryptWithPassword_Good(t *testing.T) { + password := "my-secret-passphrase" + kp, err := CreateKeyPair("Secure User", "secure@example.com", password) + require.NoError(t, err) + + plaintext := []byte("encrypted with password-protected key") + ciphertext, err := Encrypt(plaintext, kp.PublicKey) + require.NoError(t, err) + + decrypted, err := Decrypt(ciphertext, kp.PrivateKey, password) + require.NoError(t, err) + assert.Equal(t, plaintext, decrypted) +} + +func TestSignVerify_Good(t *testing.T) { + kp, err := CreateKeyPair("Signer", "signer@example.com", "") + require.NoError(t, err) + + data := []byte("message to sign") + signature, err := Sign(data, kp.PrivateKey, "") + require.NoError(t, err) + assert.NotEmpty(t, signature) + assert.Contains(t, string(signature), "-----BEGIN PGP SIGNATURE-----") + + err = Verify(data, signature, kp.PublicKey) + assert.NoError(t, err) +} + +func TestSignVerify_Bad(t *testing.T) { + kp, err := CreateKeyPair("Signer", "signer@example.com", "") + require.NoError(t, err) + + data := []byte("original message") + signature, err := Sign(data, kp.PrivateKey, "") + require.NoError(t, err) + + // Verify with tampered data should fail + err = Verify([]byte("tampered message"), signature, kp.PublicKey) + assert.Error(t, err) +} + +func TestSignVerify_Ugly(t *testing.T) { + // Invalid key for signing + _, err := Sign([]byte("data"), "not-a-key", "") + assert.Error(t, err) + + // Invalid key for verification + kp, err := CreateKeyPair("Signer", "signer@example.com", "") + require.NoError(t, err) + + data := []byte("message") + sig, err := Sign(data, kp.PrivateKey, "") + require.NoError(t, err) + + err = Verify(data, sig, "not-a-key") + assert.Error(t, err) +} + +func TestSignVerifyWithPassword_Good(t *testing.T) { + password := "signing-password" + kp, err := CreateKeyPair("Signer", "signer@example.com", password) + require.NoError(t, err) + + data := []byte("signed with password-protected key") + signature, err := Sign(data, kp.PrivateKey, password) + require.NoError(t, err) + + err = Verify(data, signature, kp.PublicKey) + assert.NoError(t, err) +} + +func TestFullRoundTrip_Good(t *testing.T) { + // Generate keys, encrypt, decrypt, sign, and verify - full round trip + kp, err := CreateKeyPair("Full Test", "full@example.com", "") + require.NoError(t, err) + + original := []byte("full round-trip test data") + + // Encrypt then decrypt + ciphertext, err := Encrypt(original, kp.PublicKey) + require.NoError(t, err) + decrypted, err := Decrypt(ciphertext, kp.PrivateKey, "") + require.NoError(t, err) + assert.Equal(t, original, decrypted) + + // Sign then verify + signature, err := Sign(original, kp.PrivateKey, "") + require.NoError(t, err) + err = Verify(original, signature, kp.PublicKey) + assert.NoError(t, err) +} diff --git a/crypt/rsa/rsa.go b/crypt/rsa/rsa.go new file mode 100644 index 0000000..5470ea8 --- /dev/null +++ b/crypt/rsa/rsa.go @@ -0,0 +1,91 @@ +package rsa + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "fmt" +) + +// Service provides RSA functionality. +type Service struct{} + +// NewService creates and returns a new Service instance for performing RSA-related operations. +func NewService() *Service { + return &Service{} +} + +// GenerateKeyPair creates a new RSA key pair. +func (s *Service) GenerateKeyPair(bits int) (publicKey, privateKey []byte, err error) { + if bits < 2048 { + return nil, nil, fmt.Errorf("rsa: key size too small: %d (minimum 2048)", bits) + } + privKey, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate private key: %w", err) + } + + privKeyBytes := x509.MarshalPKCS1PrivateKey(privKey) + privKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: privKeyBytes, + }) + + pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal public key: %w", err) + } + pubKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: pubKeyBytes, + }) + + return pubKeyPEM, privKeyPEM, nil +} + +// Encrypt encrypts data with a public key. +func (s *Service) Encrypt(publicKey, data, label []byte) ([]byte, error) { + block, _ := pem.Decode(publicKey) + if block == nil { + return nil, fmt.Errorf("failed to decode public key") + } + + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %w", err) + } + + rsaPub, ok := pub.(*rsa.PublicKey) + if !ok { + return nil, fmt.Errorf("not an RSA public key") + } + + ciphertext, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, rsaPub, data, label) + if err != nil { + return nil, fmt.Errorf("failed to encrypt data: %w", err) + } + + return ciphertext, nil +} + +// Decrypt decrypts data with a private key. +func (s *Service) Decrypt(privateKey, ciphertext, label []byte) ([]byte, error) { + block, _ := pem.Decode(privateKey) + if block == nil { + return nil, fmt.Errorf("failed to decode private key") + } + + priv, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + plaintext, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, priv, ciphertext, label) + if err != nil { + return nil, fmt.Errorf("failed to decrypt data: %w", err) + } + + return plaintext, nil +} diff --git a/crypt/rsa/rsa_test.go b/crypt/rsa/rsa_test.go new file mode 100644 index 0000000..c78d91d --- /dev/null +++ b/crypt/rsa/rsa_test.go @@ -0,0 +1,101 @@ +package rsa + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/pem" + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +// mockReader is a reader that returns an error. +type mockReader struct{} + +func (r *mockReader) Read(p []byte) (n int, err error) { + return 0, errors.New("read error") +} + +func TestRSA_Good(t *testing.T) { + s := NewService() + + // Generate a new key pair + pubKey, privKey, err := s.GenerateKeyPair(2048) + assert.NoError(t, err) + assert.NotEmpty(t, pubKey) + assert.NotEmpty(t, privKey) + + // Encrypt and decrypt a message + message := []byte("Hello, World!") + ciphertext, err := s.Encrypt(pubKey, message, nil) + assert.NoError(t, err) + plaintext, err := s.Decrypt(privKey, ciphertext, nil) + assert.NoError(t, err) + assert.Equal(t, message, plaintext) +} + +func TestRSA_Bad(t *testing.T) { + s := NewService() + + // Decrypt with wrong key + pubKey, _, err := s.GenerateKeyPair(2048) + assert.NoError(t, err) + _, otherPrivKey, err := s.GenerateKeyPair(2048) + assert.NoError(t, err) + message := []byte("Hello, World!") + ciphertext, err := s.Encrypt(pubKey, message, nil) + assert.NoError(t, err) + _, err = s.Decrypt(otherPrivKey, ciphertext, nil) + assert.Error(t, err) + + // Key size too small + _, _, err = s.GenerateKeyPair(512) + assert.Error(t, err) +} + +func TestRSA_Ugly(t *testing.T) { + s := NewService() + + // Malformed keys and messages + _, err := s.Encrypt([]byte("not-a-key"), []byte("message"), nil) + assert.Error(t, err) + _, err = s.Decrypt([]byte("not-a-key"), []byte("message"), nil) + assert.Error(t, err) + _, err = s.Encrypt([]byte("-----BEGIN PUBLIC KEY-----\nMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAJ/6j/y7/r/9/z/8/f/+/v7+/v7+/v7+\nv/7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4=\n-----END PUBLIC KEY-----"), []byte("message"), nil) + assert.Error(t, err) + _, err = s.Decrypt([]byte("-----BEGIN RSA PRIVATE KEY-----\nMIIBOQIBAAJBAL/6j/y7/r/9/z/8/f/+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CAwEAAQJB\nAL/6j/y7/r/9/z/8/f/+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CgYEA/f8/vLv+v/3/P/z9//7+/v7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4C\ngYEA/f8/vLv+v/3/P/z9//7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CgYEA/f8/vLv+v/3/P/z9//7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/4CgYEA/f8/vLv+v/3/P/z9//7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CgYEA/f8/vLv+v/3/P/z9//7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/4=\n-----END RSA PRIVATE KEY-----"), []byte("message"), nil) + assert.Error(t, err) + + // Key generation failure + oldReader := rand.Reader + rand.Reader = &mockReader{} + t.Cleanup(func() { rand.Reader = oldReader }) + _, _, err = s.GenerateKeyPair(2048) + assert.Error(t, err) + + // Encrypt with non-RSA key + rand.Reader = oldReader // Restore reader for this test + ecdsaPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.NoError(t, err) + ecdsaPubKeyBytes, err := x509.MarshalPKIXPublicKey(&ecdsaPrivKey.PublicKey) + assert.NoError(t, err) + ecdsaPubKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: ecdsaPubKeyBytes, + }) + _, err = s.Encrypt(ecdsaPubKeyPEM, []byte("message"), nil) + assert.Error(t, err) + rand.Reader = &mockReader{} // Set it back for the next test + + // Encrypt message too long + rand.Reader = oldReader // Restore reader for this test + pubKey, _, err := s.GenerateKeyPair(2048) + assert.NoError(t, err) + message := make([]byte, 2048) + _, err = s.Encrypt(pubKey, message, nil) + assert.Error(t, err) + rand.Reader = &mockReader{} // Set it back +} diff --git a/crypt/symmetric.go b/crypt/symmetric.go new file mode 100644 index 0000000..844e4a5 --- /dev/null +++ b/crypt/symmetric.go @@ -0,0 +1,100 @@ +package crypt + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + + core "forge.lthn.ai/core/go/pkg/framework/core" + "golang.org/x/crypto/chacha20poly1305" +) + +// ChaCha20Encrypt encrypts plaintext using ChaCha20-Poly1305. +// The key must be 32 bytes. The nonce is randomly generated and prepended +// to the ciphertext. +func ChaCha20Encrypt(plaintext, key []byte) ([]byte, error) { + aead, err := chacha20poly1305.NewX(key) + if err != nil { + return nil, core.E("crypt.ChaCha20Encrypt", "failed to create cipher", err) + } + + nonce := make([]byte, aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, core.E("crypt.ChaCha20Encrypt", "failed to generate nonce", err) + } + + ciphertext := aead.Seal(nonce, nonce, plaintext, nil) + return ciphertext, nil +} + +// ChaCha20Decrypt decrypts ciphertext encrypted with ChaCha20Encrypt. +// The key must be 32 bytes. Expects the nonce prepended to the ciphertext. +func ChaCha20Decrypt(ciphertext, key []byte) ([]byte, error) { + aead, err := chacha20poly1305.NewX(key) + if err != nil { + return nil, core.E("crypt.ChaCha20Decrypt", "failed to create cipher", err) + } + + nonceSize := aead.NonceSize() + if len(ciphertext) < nonceSize { + return nil, core.E("crypt.ChaCha20Decrypt", "ciphertext too short", nil) + } + + nonce, encrypted := ciphertext[:nonceSize], ciphertext[nonceSize:] + plaintext, err := aead.Open(nil, nonce, encrypted, nil) + if err != nil { + return nil, core.E("crypt.ChaCha20Decrypt", "failed to decrypt", err) + } + + return plaintext, nil +} + +// AESGCMEncrypt encrypts plaintext using AES-256-GCM. +// The key must be 32 bytes. The nonce is randomly generated and prepended +// to the ciphertext. +func AESGCMEncrypt(plaintext, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, core.E("crypt.AESGCMEncrypt", "failed to create cipher", err) + } + + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, core.E("crypt.AESGCMEncrypt", "failed to create GCM", err) + } + + nonce := make([]byte, aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, core.E("crypt.AESGCMEncrypt", "failed to generate nonce", err) + } + + ciphertext := aead.Seal(nonce, nonce, plaintext, nil) + return ciphertext, nil +} + +// AESGCMDecrypt decrypts ciphertext encrypted with AESGCMEncrypt. +// The key must be 32 bytes. Expects the nonce prepended to the ciphertext. +func AESGCMDecrypt(ciphertext, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, core.E("crypt.AESGCMDecrypt", "failed to create cipher", err) + } + + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, core.E("crypt.AESGCMDecrypt", "failed to create GCM", err) + } + + nonceSize := aead.NonceSize() + if len(ciphertext) < nonceSize { + return nil, core.E("crypt.AESGCMDecrypt", "ciphertext too short", nil) + } + + nonce, encrypted := ciphertext[:nonceSize], ciphertext[nonceSize:] + plaintext, err := aead.Open(nil, nonce, encrypted, nil) + if err != nil { + return nil, core.E("crypt.AESGCMDecrypt", "failed to decrypt", err) + } + + return plaintext, nil +} diff --git a/crypt/symmetric_test.go b/crypt/symmetric_test.go new file mode 100644 index 0000000..a060579 --- /dev/null +++ b/crypt/symmetric_test.go @@ -0,0 +1,55 @@ +package crypt + +import ( + "crypto/rand" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestChaCha20_Good(t *testing.T) { + key := make([]byte, 32) + _, err := rand.Read(key) + assert.NoError(t, err) + + plaintext := []byte("ChaCha20-Poly1305 test data") + + encrypted, err := ChaCha20Encrypt(plaintext, key) + assert.NoError(t, err) + assert.NotEqual(t, plaintext, encrypted) + + decrypted, err := ChaCha20Decrypt(encrypted, key) + assert.NoError(t, err) + assert.Equal(t, plaintext, decrypted) +} + +func TestChaCha20_Bad(t *testing.T) { + key := make([]byte, 32) + wrongKey := make([]byte, 32) + _, _ = rand.Read(key) + _, _ = rand.Read(wrongKey) + + plaintext := []byte("secret message") + + encrypted, err := ChaCha20Encrypt(plaintext, key) + assert.NoError(t, err) + + _, err = ChaCha20Decrypt(encrypted, wrongKey) + assert.Error(t, err) +} + +func TestAESGCM_Good(t *testing.T) { + key := make([]byte, 32) + _, err := rand.Read(key) + assert.NoError(t, err) + + plaintext := []byte("AES-256-GCM test data") + + encrypted, err := AESGCMEncrypt(plaintext, key) + assert.NoError(t, err) + assert.NotEqual(t, plaintext, encrypted) + + decrypted, err := AESGCMDecrypt(encrypted, key) + assert.NoError(t, err) + assert.Equal(t, plaintext, decrypted) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..fd2bcf4 --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module forge.lthn.ai/core/go-crypt + +go 1.25.5 + +require ( + forge.lthn.ai/core/go v0.0.0 + github.com/ProtonMail/go-crypto v1.3.0 + github.com/stretchr/testify v1.11.1 + golang.org/x/crypto v0.48.0 +) + +require ( + github.com/cloudflare/circl v1.6.3 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + golang.org/x/sys v0.41.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace forge.lthn.ai/core/go => ../go diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..ecc0c9a --- /dev/null +++ b/go.sum @@ -0,0 +1,18 @@ +github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw= +github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE= +github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= +github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/trust/policy.go b/trust/policy.go new file mode 100644 index 0000000..a7da2ca --- /dev/null +++ b/trust/policy.go @@ -0,0 +1,238 @@ +package trust + +import ( + "fmt" + "strings" +) + +// Policy defines the access rules for a given trust tier. +type Policy struct { + // Tier is the trust level this policy applies to. + Tier Tier + // Allowed lists the capabilities granted at this tier. + Allowed []Capability + // RequiresApproval lists capabilities that need human/higher-tier approval. + RequiresApproval []Capability + // Denied lists explicitly denied capabilities. + Denied []Capability +} + +// PolicyEngine evaluates capability requests against registered policies. +type PolicyEngine struct { + registry *Registry + policies map[Tier]*Policy +} + +// Decision is the result of a policy evaluation. +type Decision int + +const ( + // Deny means the action is not permitted. + Deny Decision = iota + // Allow means the action is permitted. + Allow + // NeedsApproval means the action requires human or higher-tier approval. + NeedsApproval +) + +// String returns the human-readable name of the decision. +func (d Decision) String() string { + switch d { + case Deny: + return "deny" + case Allow: + return "allow" + case NeedsApproval: + return "needs_approval" + default: + return fmt.Sprintf("unknown(%d)", int(d)) + } +} + +// EvalResult contains the outcome of a capability evaluation. +type EvalResult struct { + Decision Decision + Agent string + Cap Capability + Reason string +} + +// NewPolicyEngine creates a policy engine with the given registry and default policies. +func NewPolicyEngine(registry *Registry) *PolicyEngine { + pe := &PolicyEngine{ + registry: registry, + policies: make(map[Tier]*Policy), + } + pe.loadDefaults() + return pe +} + +// Evaluate checks whether the named agent can perform the given capability. +// If the agent has scoped repos and the capability is repo-scoped, the repo +// parameter is checked against the agent's allowed repos. +func (pe *PolicyEngine) Evaluate(agentName string, cap Capability, repo string) EvalResult { + agent := pe.registry.Get(agentName) + if agent == nil { + return EvalResult{ + Decision: Deny, + Agent: agentName, + Cap: cap, + Reason: "agent not registered", + } + } + + policy, ok := pe.policies[agent.Tier] + if !ok { + return EvalResult{ + Decision: Deny, + Agent: agentName, + Cap: cap, + Reason: fmt.Sprintf("no policy for tier %s", agent.Tier), + } + } + + // Check explicit denials first. + for _, denied := range policy.Denied { + if denied == cap { + return EvalResult{ + Decision: Deny, + Agent: agentName, + Cap: cap, + Reason: fmt.Sprintf("capability %s is denied for tier %s", cap, agent.Tier), + } + } + } + + // Check if capability requires approval. + for _, approval := range policy.RequiresApproval { + if approval == cap { + return EvalResult{ + Decision: NeedsApproval, + Agent: agentName, + Cap: cap, + Reason: fmt.Sprintf("capability %s requires approval for tier %s", cap, agent.Tier), + } + } + } + + // Check if capability is allowed. + for _, allowed := range policy.Allowed { + if allowed == cap { + // For repo-scoped capabilities, verify repo access. + if isRepoScoped(cap) && len(agent.ScopedRepos) > 0 { + if !repoAllowed(agent.ScopedRepos, repo) { + return EvalResult{ + Decision: Deny, + Agent: agentName, + Cap: cap, + Reason: fmt.Sprintf("agent %q does not have access to repo %q", agentName, repo), + } + } + } + return EvalResult{ + Decision: Allow, + Agent: agentName, + Cap: cap, + Reason: fmt.Sprintf("capability %s allowed for tier %s", cap, agent.Tier), + } + } + } + + return EvalResult{ + Decision: Deny, + Agent: agentName, + Cap: cap, + Reason: fmt.Sprintf("capability %s not granted for tier %s", cap, agent.Tier), + } +} + +// SetPolicy replaces the policy for a given tier. +func (pe *PolicyEngine) SetPolicy(p Policy) error { + if !p.Tier.Valid() { + return fmt.Errorf("trust.SetPolicy: invalid tier %d", p.Tier) + } + pe.policies[p.Tier] = &p + return nil +} + +// GetPolicy returns the policy for a tier, or nil if none is set. +func (pe *PolicyEngine) GetPolicy(t Tier) *Policy { + return pe.policies[t] +} + +// loadDefaults installs the default trust policies from the issue spec. +func (pe *PolicyEngine) loadDefaults() { + // Tier 3 — Full Trust + pe.policies[TierFull] = &Policy{ + Tier: TierFull, + Allowed: []Capability{ + CapPushRepo, + CapMergePR, + CapCreatePR, + CapCreateIssue, + CapCommentIssue, + CapReadSecrets, + CapRunPrivileged, + CapAccessWorkspace, + CapModifyFlows, + }, + } + + // Tier 2 — Verified + pe.policies[TierVerified] = &Policy{ + Tier: TierVerified, + Allowed: []Capability{ + CapPushRepo, // scoped to assigned repos + CapCreatePR, // can create, not merge + CapCreateIssue, + CapCommentIssue, + CapReadSecrets, // scoped to their repos + }, + RequiresApproval: []Capability{ + CapMergePR, + }, + Denied: []Capability{ + CapAccessWorkspace, // cannot access other agents' workspaces + CapModifyFlows, + CapRunPrivileged, + }, + } + + // Tier 1 — Untrusted + pe.policies[TierUntrusted] = &Policy{ + Tier: TierUntrusted, + Allowed: []Capability{ + CapCreatePR, // fork only, checked at enforcement layer + CapCommentIssue, + }, + Denied: []Capability{ + CapPushRepo, + CapMergePR, + CapCreateIssue, + CapReadSecrets, + CapRunPrivileged, + CapAccessWorkspace, + CapModifyFlows, + }, + } +} + +// isRepoScoped returns true if the capability is constrained by repo scope. +func isRepoScoped(cap Capability) bool { + return strings.HasPrefix(string(cap), "repo.") || + strings.HasPrefix(string(cap), "pr.") || + cap == CapReadSecrets +} + +// repoAllowed checks if repo is in the agent's scoped list. +func repoAllowed(scoped []string, repo string) bool { + if repo == "" { + return false + } + for _, r := range scoped { + if r == repo { + return true + } + } + return false +} diff --git a/trust/policy_test.go b/trust/policy_test.go new file mode 100644 index 0000000..cf975d4 --- /dev/null +++ b/trust/policy_test.go @@ -0,0 +1,268 @@ +package trust + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestEngine(t *testing.T) *PolicyEngine { + t.Helper() + r := NewRegistry() + require.NoError(t, r.Register(Agent{ + Name: "Athena", + Tier: TierFull, + })) + require.NoError(t, r.Register(Agent{ + Name: "Clotho", + Tier: TierVerified, + ScopedRepos: []string{"host-uk/core", "host-uk/docs"}, + })) + require.NoError(t, r.Register(Agent{ + Name: "BugSETI-001", + Tier: TierUntrusted, + })) + return NewPolicyEngine(r) +} + +// --- Decision --- + +func TestDecisionString_Good(t *testing.T) { + assert.Equal(t, "deny", Deny.String()) + assert.Equal(t, "allow", Allow.String()) + assert.Equal(t, "needs_approval", NeedsApproval.String()) +} + +func TestDecisionString_Bad_Unknown(t *testing.T) { + assert.Contains(t, Decision(99).String(), "unknown") +} + +// --- Tier 3 (Full Trust) --- + +func TestEvaluate_Good_Tier3CanDoAnything(t *testing.T) { + pe := newTestEngine(t) + + caps := []Capability{ + CapPushRepo, CapMergePR, CapCreatePR, CapCreateIssue, + CapCommentIssue, CapReadSecrets, CapRunPrivileged, + CapAccessWorkspace, CapModifyFlows, + } + for _, cap := range caps { + result := pe.Evaluate("Athena", cap, "") + assert.Equal(t, Allow, result.Decision, "Athena should be allowed %s", cap) + } +} + +// --- Tier 2 (Verified) --- + +func TestEvaluate_Good_Tier2CanCreatePR(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("Clotho", CapCreatePR, "host-uk/core") + assert.Equal(t, Allow, result.Decision) +} + +func TestEvaluate_Good_Tier2CanPushToScopedRepo(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("Clotho", CapPushRepo, "host-uk/core") + assert.Equal(t, Allow, result.Decision) +} + +func TestEvaluate_Good_Tier2NeedsApprovalToMerge(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("Clotho", CapMergePR, "host-uk/core") + assert.Equal(t, NeedsApproval, result.Decision) +} + +func TestEvaluate_Good_Tier2CanCreateIssue(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("Clotho", CapCreateIssue, "") + assert.Equal(t, Allow, result.Decision) +} + +func TestEvaluate_Bad_Tier2CannotAccessWorkspace(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("Clotho", CapAccessWorkspace, "") + assert.Equal(t, Deny, result.Decision) +} + +func TestEvaluate_Bad_Tier2CannotModifyFlows(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("Clotho", CapModifyFlows, "") + assert.Equal(t, Deny, result.Decision) +} + +func TestEvaluate_Bad_Tier2CannotRunPrivileged(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("Clotho", CapRunPrivileged, "") + assert.Equal(t, Deny, result.Decision) +} + +func TestEvaluate_Bad_Tier2CannotPushToUnscopedRepo(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("Clotho", CapPushRepo, "host-uk/secret-repo") + assert.Equal(t, Deny, result.Decision) + assert.Contains(t, result.Reason, "does not have access") +} + +func TestEvaluate_Bad_Tier2RepoScopeEmptyRepo(t *testing.T) { + pe := newTestEngine(t) + // Push without specifying a repo should be denied for scoped agents. + result := pe.Evaluate("Clotho", CapPushRepo, "") + assert.Equal(t, Deny, result.Decision) +} + +// --- Tier 1 (Untrusted) --- + +func TestEvaluate_Good_Tier1CanCreatePR(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("BugSETI-001", CapCreatePR, "") + assert.Equal(t, Allow, result.Decision) +} + +func TestEvaluate_Good_Tier1CanCommentIssue(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("BugSETI-001", CapCommentIssue, "") + assert.Equal(t, Allow, result.Decision) +} + +func TestEvaluate_Bad_Tier1CannotPush(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("BugSETI-001", CapPushRepo, "") + assert.Equal(t, Deny, result.Decision) +} + +func TestEvaluate_Bad_Tier1CannotMerge(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("BugSETI-001", CapMergePR, "") + assert.Equal(t, Deny, result.Decision) +} + +func TestEvaluate_Bad_Tier1CannotCreateIssue(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("BugSETI-001", CapCreateIssue, "") + assert.Equal(t, Deny, result.Decision) +} + +func TestEvaluate_Bad_Tier1CannotReadSecrets(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("BugSETI-001", CapReadSecrets, "") + assert.Equal(t, Deny, result.Decision) +} + +func TestEvaluate_Bad_Tier1CannotRunPrivileged(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("BugSETI-001", CapRunPrivileged, "") + assert.Equal(t, Deny, result.Decision) +} + +// --- Edge cases --- + +func TestEvaluate_Bad_UnknownAgent(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("Unknown", CapCreatePR, "") + assert.Equal(t, Deny, result.Decision) + assert.Contains(t, result.Reason, "not registered") +} + +func TestEvaluate_Good_EvalResultFields(t *testing.T) { + pe := newTestEngine(t) + result := pe.Evaluate("Athena", CapPushRepo, "") + assert.Equal(t, "Athena", result.Agent) + assert.Equal(t, CapPushRepo, result.Cap) + assert.NotEmpty(t, result.Reason) +} + +// --- SetPolicy --- + +func TestSetPolicy_Good(t *testing.T) { + pe := newTestEngine(t) + err := pe.SetPolicy(Policy{ + Tier: TierVerified, + Allowed: []Capability{CapPushRepo, CapMergePR}, + }) + require.NoError(t, err) + + // Verify the new policy is in effect. + result := pe.Evaluate("Clotho", CapMergePR, "host-uk/core") + assert.Equal(t, Allow, result.Decision) +} + +func TestSetPolicy_Bad_InvalidTier(t *testing.T) { + pe := newTestEngine(t) + err := pe.SetPolicy(Policy{Tier: Tier(0)}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid tier") +} + +func TestGetPolicy_Good(t *testing.T) { + pe := newTestEngine(t) + p := pe.GetPolicy(TierFull) + require.NotNil(t, p) + assert.Equal(t, TierFull, p.Tier) +} + +func TestGetPolicy_Bad_NotFound(t *testing.T) { + pe := newTestEngine(t) + assert.Nil(t, pe.GetPolicy(Tier(99))) +} + +// --- isRepoScoped / repoAllowed helpers --- + +func TestIsRepoScoped_Good(t *testing.T) { + assert.True(t, isRepoScoped(CapPushRepo)) + assert.True(t, isRepoScoped(CapCreatePR)) + assert.True(t, isRepoScoped(CapMergePR)) + assert.True(t, isRepoScoped(CapReadSecrets)) +} + +func TestIsRepoScoped_Bad_NotScoped(t *testing.T) { + assert.False(t, isRepoScoped(CapRunPrivileged)) + assert.False(t, isRepoScoped(CapAccessWorkspace)) + assert.False(t, isRepoScoped(CapModifyFlows)) +} + +func TestRepoAllowed_Good(t *testing.T) { + scoped := []string{"host-uk/core", "host-uk/docs"} + assert.True(t, repoAllowed(scoped, "host-uk/core")) + assert.True(t, repoAllowed(scoped, "host-uk/docs")) +} + +func TestRepoAllowed_Bad_NotInScope(t *testing.T) { + scoped := []string{"host-uk/core"} + assert.False(t, repoAllowed(scoped, "host-uk/secret")) +} + +func TestRepoAllowed_Bad_EmptyRepo(t *testing.T) { + scoped := []string{"host-uk/core"} + assert.False(t, repoAllowed(scoped, "")) +} + +func TestRepoAllowed_Bad_EmptyScope(t *testing.T) { + assert.False(t, repoAllowed(nil, "host-uk/core")) + assert.False(t, repoAllowed([]string{}, "host-uk/core")) +} + +// --- Tier 3 ignores repo scoping --- + +func TestEvaluate_Good_Tier3IgnoresRepoScope(t *testing.T) { + r := NewRegistry() + require.NoError(t, r.Register(Agent{ + Name: "Virgil", + Tier: TierFull, + ScopedRepos: []string{}, // empty scope should not restrict Tier 3 + })) + pe := NewPolicyEngine(r) + + result := pe.Evaluate("Virgil", CapPushRepo, "any-repo") + assert.Equal(t, Allow, result.Decision) +} + +// --- Default rate limits --- + +func TestDefaultRateLimit(t *testing.T) { + assert.Equal(t, 10, defaultRateLimit(TierUntrusted)) + assert.Equal(t, 60, defaultRateLimit(TierVerified)) + assert.Equal(t, 0, defaultRateLimit(TierFull)) + assert.Equal(t, 10, defaultRateLimit(Tier(99))) // unknown defaults to 10 +} diff --git a/trust/trust.go b/trust/trust.go new file mode 100644 index 0000000..d5c0636 --- /dev/null +++ b/trust/trust.go @@ -0,0 +1,165 @@ +// Package trust implements an agent trust model with tiered access control. +// +// Agents are assigned trust tiers that determine their capabilities: +// +// - Tier 3 (Full Trust): Internal agents with full access (e.g., Athena, Virgil, Charon) +// - Tier 2 (Verified): Partner agents with scoped access (e.g., Clotho, Hypnos) +// - Tier 1 (Untrusted): External/community agents with minimal access +// +// The package provides a Registry for managing agent identities and a PolicyEngine +// for evaluating capability requests against trust policies. +package trust + +import ( + "fmt" + "sync" + "time" +) + +// Tier represents an agent's trust level in the system. +type Tier int + +const ( + // TierUntrusted is for external/community agents with minimal access. + TierUntrusted Tier = 1 + // TierVerified is for partner agents with scoped access. + TierVerified Tier = 2 + // TierFull is for internal agents with full access. + TierFull Tier = 3 +) + +// String returns the human-readable name of the tier. +func (t Tier) String() string { + switch t { + case TierUntrusted: + return "untrusted" + case TierVerified: + return "verified" + case TierFull: + return "full" + default: + return fmt.Sprintf("unknown(%d)", int(t)) + } +} + +// Valid returns true if the tier is a recognised trust level. +func (t Tier) Valid() bool { + return t >= TierUntrusted && t <= TierFull +} + +// Capability represents a specific action an agent can perform. +type Capability string + +const ( + CapPushRepo Capability = "repo.push" + CapMergePR Capability = "pr.merge" + CapCreatePR Capability = "pr.create" + CapCreateIssue Capability = "issue.create" + CapCommentIssue Capability = "issue.comment" + CapReadSecrets Capability = "secrets.read" + CapRunPrivileged Capability = "cmd.privileged" + CapAccessWorkspace Capability = "workspace.access" + CapModifyFlows Capability = "flows.modify" +) + +// Agent represents an agent identity in the trust system. +type Agent struct { + // Name is the unique identifier for the agent (e.g., "Athena", "Clotho"). + Name string + // Tier is the agent's trust level. + Tier Tier + // ScopedRepos limits repo access for Tier 2 agents. Empty means no repo access. + // Tier 3 agents ignore this field (they have access to all repos). + ScopedRepos []string + // RateLimit is the maximum requests per minute. 0 means unlimited. + RateLimit int + // TokenExpiresAt is when the agent's token expires. + TokenExpiresAt time.Time + // CreatedAt is when the agent was registered. + CreatedAt time.Time +} + +// Registry manages agent identities and their trust tiers. +type Registry struct { + mu sync.RWMutex + agents map[string]*Agent +} + +// NewRegistry creates an empty agent registry. +func NewRegistry() *Registry { + return &Registry{ + agents: make(map[string]*Agent), + } +} + +// Register adds or updates an agent in the registry. +// Returns an error if the agent name is empty or the tier is invalid. +func (r *Registry) Register(agent Agent) error { + if agent.Name == "" { + return fmt.Errorf("trust.Register: agent name is required") + } + if !agent.Tier.Valid() { + return fmt.Errorf("trust.Register: invalid tier %d for agent %q", agent.Tier, agent.Name) + } + if agent.CreatedAt.IsZero() { + agent.CreatedAt = time.Now() + } + if agent.RateLimit == 0 { + agent.RateLimit = defaultRateLimit(agent.Tier) + } + + r.mu.Lock() + defer r.mu.Unlock() + r.agents[agent.Name] = &agent + return nil +} + +// Get returns the agent with the given name, or nil if not found. +func (r *Registry) Get(name string) *Agent { + r.mu.RLock() + defer r.mu.RUnlock() + return r.agents[name] +} + +// Remove deletes an agent from the registry. +func (r *Registry) Remove(name string) bool { + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.agents[name]; !ok { + return false + } + delete(r.agents, name) + return true +} + +// List returns all registered agents. The returned slice is a snapshot. +func (r *Registry) List() []Agent { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]Agent, 0, len(r.agents)) + for _, a := range r.agents { + out = append(out, *a) + } + return out +} + +// Len returns the number of registered agents. +func (r *Registry) Len() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.agents) +} + +// defaultRateLimit returns the default rate limit for a given tier. +func defaultRateLimit(t Tier) int { + switch t { + case TierUntrusted: + return 10 + case TierVerified: + return 60 + case TierFull: + return 0 // unlimited + default: + return 10 + } +} diff --git a/trust/trust_test.go b/trust/trust_test.go new file mode 100644 index 0000000..af0a9d3 --- /dev/null +++ b/trust/trust_test.go @@ -0,0 +1,164 @@ +package trust + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Tier --- + +func TestTierString_Good(t *testing.T) { + assert.Equal(t, "untrusted", TierUntrusted.String()) + assert.Equal(t, "verified", TierVerified.String()) + assert.Equal(t, "full", TierFull.String()) +} + +func TestTierString_Bad_Unknown(t *testing.T) { + assert.Contains(t, Tier(99).String(), "unknown") +} + +func TestTierValid_Good(t *testing.T) { + assert.True(t, TierUntrusted.Valid()) + assert.True(t, TierVerified.Valid()) + assert.True(t, TierFull.Valid()) +} + +func TestTierValid_Bad(t *testing.T) { + assert.False(t, Tier(0).Valid()) + assert.False(t, Tier(4).Valid()) + assert.False(t, Tier(-1).Valid()) +} + +// --- Registry --- + +func TestRegistryRegister_Good(t *testing.T) { + r := NewRegistry() + err := r.Register(Agent{Name: "Athena", Tier: TierFull}) + require.NoError(t, err) + assert.Equal(t, 1, r.Len()) +} + +func TestRegistryRegister_Good_SetsDefaults(t *testing.T) { + r := NewRegistry() + err := r.Register(Agent{Name: "Athena", Tier: TierFull}) + require.NoError(t, err) + + a := r.Get("Athena") + require.NotNil(t, a) + assert.Equal(t, 0, a.RateLimit) // full trust = unlimited + assert.False(t, a.CreatedAt.IsZero()) +} + +func TestRegistryRegister_Good_TierDefaults(t *testing.T) { + r := NewRegistry() + require.NoError(t, r.Register(Agent{Name: "A", Tier: TierUntrusted})) + require.NoError(t, r.Register(Agent{Name: "B", Tier: TierVerified})) + require.NoError(t, r.Register(Agent{Name: "C", Tier: TierFull})) + + assert.Equal(t, 10, r.Get("A").RateLimit) + assert.Equal(t, 60, r.Get("B").RateLimit) + assert.Equal(t, 0, r.Get("C").RateLimit) +} + +func TestRegistryRegister_Good_PreservesExplicitRateLimit(t *testing.T) { + r := NewRegistry() + err := r.Register(Agent{Name: "Custom", Tier: TierVerified, RateLimit: 30}) + require.NoError(t, err) + assert.Equal(t, 30, r.Get("Custom").RateLimit) +} + +func TestRegistryRegister_Good_Update(t *testing.T) { + r := NewRegistry() + require.NoError(t, r.Register(Agent{Name: "Athena", Tier: TierVerified})) + require.NoError(t, r.Register(Agent{Name: "Athena", Tier: TierFull})) + + assert.Equal(t, 1, r.Len()) + assert.Equal(t, TierFull, r.Get("Athena").Tier) +} + +func TestRegistryRegister_Bad_EmptyName(t *testing.T) { + r := NewRegistry() + err := r.Register(Agent{Tier: TierFull}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "name is required") +} + +func TestRegistryRegister_Bad_InvalidTier(t *testing.T) { + r := NewRegistry() + err := r.Register(Agent{Name: "Bad", Tier: Tier(0)}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid tier") +} + +func TestRegistryGet_Good(t *testing.T) { + r := NewRegistry() + require.NoError(t, r.Register(Agent{Name: "Athena", Tier: TierFull})) + a := r.Get("Athena") + require.NotNil(t, a) + assert.Equal(t, "Athena", a.Name) +} + +func TestRegistryGet_Bad_NotFound(t *testing.T) { + r := NewRegistry() + assert.Nil(t, r.Get("nonexistent")) +} + +func TestRegistryRemove_Good(t *testing.T) { + r := NewRegistry() + require.NoError(t, r.Register(Agent{Name: "Athena", Tier: TierFull})) + assert.True(t, r.Remove("Athena")) + assert.Equal(t, 0, r.Len()) +} + +func TestRegistryRemove_Bad_NotFound(t *testing.T) { + r := NewRegistry() + assert.False(t, r.Remove("nonexistent")) +} + +func TestRegistryList_Good(t *testing.T) { + r := NewRegistry() + require.NoError(t, r.Register(Agent{Name: "Athena", Tier: TierFull})) + require.NoError(t, r.Register(Agent{Name: "Clotho", Tier: TierVerified})) + + agents := r.List() + assert.Len(t, agents, 2) + + names := make(map[string]bool) + for _, a := range agents { + names[a.Name] = true + } + assert.True(t, names["Athena"]) + assert.True(t, names["Clotho"]) +} + +func TestRegistryList_Good_Empty(t *testing.T) { + r := NewRegistry() + assert.Empty(t, r.List()) +} + +func TestRegistryList_Good_Snapshot(t *testing.T) { + r := NewRegistry() + require.NoError(t, r.Register(Agent{Name: "Athena", Tier: TierFull})) + agents := r.List() + + // Modifying the returned slice should not affect the registry. + agents[0].Tier = TierUntrusted + assert.Equal(t, TierFull, r.Get("Athena").Tier) +} + +// --- Agent --- + +func TestAgentTokenExpiry(t *testing.T) { + agent := Agent{ + Name: "Test", + Tier: TierVerified, + TokenExpiresAt: time.Now().Add(-1 * time.Hour), + } + assert.True(t, time.Now().After(agent.TokenExpiresAt)) + + agent.TokenExpiresAt = time.Now().Add(1 * time.Hour) + assert.True(t, time.Now().Before(agent.TokenExpiresAt)) +}