LEM/experiments/worf/v1_positional.py
Snider f79eaabdce feat: WoRF — Word Radiance Field experiments
NeRF-inspired technique for learning relational dynamics of language.
Not what words mean, but how they behave together — rhythm, pacing,
punctuation patterns, style transitions.

v1: positional field over text (baseline, memorises)
v2: masked feature prediction (relational, actually works)

Trained on Wodehouse "My Man Jeeves" (public domain, Gutenberg).
All 11 style features are highly relational — the field learns that
Wodehouse's style is a tightly coupled system.

Key finding: style interpolation between narrative and dialogue
produces sensible predictions for unmeasured features, suggesting
the continuous field captures real structural patterns.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-03-04 09:43:38 +00:00

390 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
WoRF Experiment — Word Radiance Field
======================================
Feed Wodehouse's "My Man Jeeves" into a NeRF-like MLP and see what
the continuous field learns about writing style.
NeRF: (x, y, z, θ, φ) → (r, g, b, σ)
WoRF: (position_in_text, chunk_context) → (style_features)
First pass: 1D position → style feature vector. No viewing angle yet.
Just see if a continuous field over text position learns anything.
"""
import re
import math
import json
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from collections import Counter
# ---------------------------------------------------------------------------
# 1. Text Splitting — each "page" is one observation (like one photo for NeRF)
# ---------------------------------------------------------------------------
def load_and_clean(path: str) -> str:
    """Read a Project Gutenberg text and strip its header/footer.

    Looks for the book's opening title first, falling back to the PG
    "*** START OF" marker. If neither marker exists, the whole file is
    kept (the original fell through to `text.find("\\n", -1) + 1`, which
    produces a garbage start index).

    Args:
        path: Filesystem path to the UTF-8 text file.

    Returns:
        The body text with header/footer removed and whitespace trimmed.
    """
    text = Path(path).read_text(encoding="utf-8")
    # Prefer the book's own opening title; fall back to the PG marker.
    start = text.find("LEAVE IT TO JEEVES")
    if start == -1:
        start = text.find("*** START OF")
        if start == -1:
            start = 0  # no recognisable header: keep everything
        else:
            # Skip past the marker line itself to the first content line.
            start = text.find("\n", start) + 1
    # Strip PG footer; keep to end of file if the marker is absent.
    end = text.find("*** END OF THE PROJECT GUTENBERG")
    if end == -1:
        end = len(text)
    return text[start:end].strip()
def split_into_chunks(text: str, chunk_size: int = 500) -> list[str]:
    """Partition *text* into consecutive pages of ~chunk_size words.

    Each page is one observation for the field (like one photo for NeRF).
    Trailing fragments of 50 words or fewer are discarded so every page
    carries enough signal for feature extraction.
    """
    tokens = text.split()
    pages = (
        " ".join(tokens[offset:offset + chunk_size])
        for offset in range(0, len(tokens), chunk_size)
    )
    # Keep only pages with a meaningful amount of text.
    return [page for page in pages if len(page.split()) > 50]
# ---------------------------------------------------------------------------
# 2. Feature Extraction — what we measure about each "page"
# ---------------------------------------------------------------------------
def extract_features(chunk: str) -> dict:
    """Extract stylistic features from a chunk of text.

    These are the 'RGB + density' equivalent — the per-page style vector
    the field learns to predict.

    Args:
        chunk: A page of text (whitespace-separated words).

    Returns:
        Dict of 11 float features; its key order defines FEATURE_NAMES.
    """
    words = chunk.split()
    sentences = re.split(r'[.!?]+', chunk)
    sentences = [s.strip() for s in sentences if s.strip()]
    # Word lengths with surrounding punctuation stripped; tokens that are
    # pure punctuation (e.g. a lone dash) strip to length 0 and are dropped.
    word_lengths = [len(w.strip(".,;:!?\"'()—-")) for w in words]
    word_lengths = [l for l in word_lengths if l > 0]
    # Dialogue detection: share of double-quote characters in the chunk.
    dialogue_chars = sum(1 for c in chunk if c == '"')
    total_chars = len(chunk) or 1
    # Punctuation patterns (Wodehouse loves dashes and exclamations).
    # BUG FIX: this previously read chunk.count("") — counting the empty
    # string, which returns len(chunk) + 1 — the em dash had been lost in
    # transit. Count em dashes plus double-hyphens, as intended.
    dashes = chunk.count("—") + chunk.count("--")
    exclamations = chunk.count("!")
    questions = chunk.count("?")
    commas = chunk.count(",")
    # Vocabulary richness (unique words / total words)
    unique_words = len(set(w.lower().strip(".,;:!?\"'()—-") for w in words))
    total_words = len(words) or 1
    # Sentence length variation (std dev) — captures rhythm
    sent_lengths = [len(s.split()) for s in sentences]
    sent_mean = np.mean(sent_lengths) if sent_lengths else 0
    sent_std = np.std(sent_lengths) if sent_lengths else 0
    # Short sentence ratio (punchy lines like "Injudicious, sir.")
    short_sentences = sum(1 for l in sent_lengths if l <= 5)
    short_ratio = short_sentences / (len(sent_lengths) or 1)
    # Aside/parenthetical density (commas, dashes per word)
    aside_density = (commas + dashes) / total_words
    return {
        "avg_word_length": np.mean(word_lengths) if word_lengths else 0,
        "avg_sentence_length": sent_mean,
        "sentence_length_variance": sent_std,
        "dialogue_ratio": dialogue_chars / total_chars,
        "vocabulary_richness": unique_words / total_words,
        "dash_density": dashes / total_words,
        "exclamation_density": exclamations / total_words,
        "question_density": questions / total_words,
        "short_sentence_ratio": short_ratio,
        "aside_density": aside_density,
        "avg_punct_per_sentence": (commas + dashes + exclamations + questions) / (len(sent_lengths) or 1),
    }
# Derive the canonical feature ordering once from a throwaway call, so the
# network's output dimension always matches extract_features' dict keys
# (dicts preserve insertion order, so this ordering is stable).
FEATURE_NAMES = list(extract_features("dummy text here for keys").keys())
NUM_FEATURES = len(FEATURE_NAMES)
# ---------------------------------------------------------------------------
# 3. Positional Encoding — NeRF's trick for capturing high-frequency detail
# ---------------------------------------------------------------------------
def positional_encoding(x: torch.Tensor, num_frequencies: int = 10) -> torch.Tensor:
    """NeRF-style sinusoidal encoding of a scalar position.

    Lifts the position into a higher-dimensional space of sin/cos bands at
    doubling frequencies, so the MLP can represent sharp transitions (the
    same reason NeRF needs it for edges).

    The layout is [x, sin(πx), cos(πx), sin(2πx), cos(2πx), …], giving an
    output of size input_dim * (1 + 2 * num_frequencies) on the last axis.
    """
    bands = [x]
    for k in range(num_frequencies):
        scaled = (2.0 ** k) * math.pi * x
        bands += [torch.sin(scaled), torch.cos(scaled)]
    return torch.cat(bands, dim=-1)
# ---------------------------------------------------------------------------
# 4. The WoRF Network — tiny MLP, same architecture as vanilla NeRF
# ---------------------------------------------------------------------------
class WoRF(nn.Module):
    """Word Radiance Field — a continuous style field over text position.

    A plain MLP mirroring vanilla NeRF's trunk: Linear/ReLU stacks mapping
    an encoded 1-D position to a vector of style features.

    Note: the previous version recorded a midpoint layer index for a
    NeRF-style skip connection but never used it — forward() runs a plain
    nn.Sequential. That dead bookkeeping (and its misleading comment) is
    removed; v1 is intentionally a straight sequential network. The layer
    structure is unchanged, so state dicts remain compatible.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 128, num_layers: int = 4,
                 output_dim: int = NUM_FEATURES):
        super().__init__()
        # Input projection + (num_layers - 2) hidden blocks + output head.
        layers: list[nn.Module] = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_dim, output_dim))
        self.network = nn.Sequential(*layers)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map encoded positions (..., input_dim) to features (..., output_dim)."""
        return self.network(x)
# ---------------------------------------------------------------------------
# 5. Training
# ---------------------------------------------------------------------------
def prepare_data(chunks: list[str]) -> tuple[torch.Tensor, torch.Tensor, dict]:
    """Convert chunks to training data: normalised positions → features.

    Args:
        chunks: Ordered pages of text; order defines position in [0, 1].

    Returns:
        positions: (n, 1) float32 tensor of positions in [0, 1].
        features: (n, NUM_FEATURES) float32 tensor, min-max scaled to [0, 1].
        norm_stats: {"min", "max", "range"} tensors for denormalising later.
    """
    n = len(chunks)
    positions = []
    features = []
    for i, chunk in enumerate(chunks):
        # Normalise to [0, 1]. Guard the single-chunk case, where the
        # original i / (n - 1) raised ZeroDivisionError.
        pos = i / (n - 1) if n > 1 else 0.0
        feat = extract_features(chunk)
        positions.append(pos)
        features.append([feat[k] for k in FEATURE_NAMES])
    positions = torch.tensor(positions, dtype=torch.float32).unsqueeze(-1)
    features = torch.tensor(features, dtype=torch.float32)
    # Min-max normalise each feature to [0, 1] for training stability.
    feat_min = features.min(dim=0).values
    feat_max = features.max(dim=0).values
    feat_range = feat_max - feat_min
    feat_range[feat_range == 0] = 1.0  # constant features: avoid division by zero
    features_norm = (features - feat_min) / feat_range
    norm_stats = {"min": feat_min, "max": feat_max, "range": feat_range}
    return positions, features_norm, norm_stats
def train_worf(positions: torch.Tensor, features: torch.Tensor,
               num_frequencies: int = 10, epochs: int = 2000, lr: float = 1e-3):
    """Fit a WoRF MLP on (encoded position → normalised features) pairs.

    Trains full-batch with Adam and MSE loss, logging every 200 epochs.
    Returns the trained model together with the frequency count so callers
    can encode query positions identically.
    """
    # Encode once up front — positions are fixed for the whole run.
    encoded = positional_encoding(positions, num_frequencies)
    input_dim = encoded.shape[-1]
    model = WoRF(input_dim=input_dim, output_dim=features.shape[-1])
    optimiser = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    print(f"\nTraining WoRF: {len(positions)} chunks, {input_dim}D input, {features.shape[-1]} features")
    print(f"Positional encoding frequencies: {num_frequencies}")
    print("-" * 60)
    last_epoch = epochs - 1
    for epoch in range(epochs):
        optimiser.zero_grad()
        loss = loss_fn(model(encoded), features)
        loss.backward()
        optimiser.step()
        if epoch % 200 == 0 or epoch == last_epoch:
            print(f" Epoch {epoch:4d}/{epochs} Loss: {loss.item():.6f}")
    return model, num_frequencies
# ---------------------------------------------------------------------------
# 6. Query the Field — the interesting bit
# ---------------------------------------------------------------------------
def query_field(model: WoRF, num_frequencies: int, num_points: int = 500) -> tuple[np.ndarray, np.ndarray]:
    """Sample the learned field at evenly spaced positions in [0, 1].

    Queries between training samples too — that interpolation is the
    interesting bit. (Return annotation fixed: the original claimed a
    single ndarray but the function returns a 2-tuple.)

    Args:
        model: Trained WoRF network.
        num_frequencies: Must match the value used during training.
        num_points: Number of query positions. NOTE(review): num_points=1
            would make .squeeze() yield a 0-d array — callers use >= 500.

    Returns:
        (positions, predictions): shapes (num_points,) and
        (num_points, NUM_FEATURES); predictions are in normalised space.
    """
    positions = torch.linspace(0, 1, num_points).unsqueeze(-1)
    encoded = positional_encoding(positions, num_frequencies)
    with torch.no_grad():  # inference only — no gradient bookkeeping
        predictions = model(encoded).numpy()
    return positions.squeeze().numpy(), predictions
def analyse_field(positions: np.ndarray, predictions: np.ndarray,
                  norm_stats: dict, chunks: list[str]):
    """Analyse what the field learned.

    Prints per-feature dynamics, sharp transitions (candidate story
    boundaries), dialogue-vs-narrative rhythm by quartile, and a field
    smoothness score.

    Args:
        positions: (num_points,) query positions in [0, 1], ascending.
        predictions: (num_points, NUM_FEATURES) normalised field samples.
        norm_stats: min/max/range tensors from prepare_data.
        chunks: Original text chunks, used only for boundary previews.

    Returns:
        Denormalised predictions array, same shape as `predictions`.
    """
    print("\n" + "=" * 60)
    print("FIELD ANALYSIS")
    print("=" * 60)
    # Denormalise for interpretability: invert the min-max scaling.
    feat_min = norm_stats["min"].numpy()
    feat_range = norm_stats["range"].numpy()
    predictions_real = predictions * feat_range + feat_min
    # Find peaks and valleys for each feature
    print("\nFeature dynamics across the book:")
    print("-" * 60)
    for i, name in enumerate(FEATURE_NAMES):
        values = predictions_real[:, i]
        peak_pos = positions[np.argmax(values)]
        valley_pos = positions[np.argmin(values)]
        mean_val = np.mean(values)
        std_val = np.std(values)
        dynamic_range = np.max(values) - np.min(values)
        print(f" {name:30s} mean={mean_val:.4f} std={std_val:.4f} "
              f"range={dynamic_range:.4f} peak@{peak_pos:.2f} valley@{valley_pos:.2f}")
    # Find story boundaries by looking for sharp transitions
    print("\n\nSharp transitions (potential story/scene boundaries):")
    print("-" * 60)
    # Use total gradient magnitude across all features (L2 norm of the
    # first difference along the position axis, in normalised space).
    gradients = np.diff(predictions, axis=0)
    gradient_magnitude = np.sqrt(np.sum(gradients ** 2, axis=1))
    # Find top transition points
    top_transitions = np.argsort(gradient_magnitude)[-8:]  # top 8 (roughly one per story)
    top_transitions = np.sort(top_transitions)  # report in reading order
    for idx in top_transitions:
        pos = positions[idx]
        # Estimate which chunk this corresponds to (positions are uniform,
        # so scale into chunk index space and clamp).
        chunk_idx = int(pos * (len(chunks) - 1))
        chunk_preview = chunks[min(chunk_idx, len(chunks) - 1)][:80]
        print(f" Position {pos:.3f} (magnitude {gradient_magnitude[idx]:.4f})")
        print(f" Text: \"{chunk_preview}...\"")
        print()
    # Compare dialogue-heavy vs narrative-heavy regions
    print("\nDialogue vs Narrative rhythm:")
    print("-" * 60)
    dialogue_idx = FEATURE_NAMES.index("dialogue_ratio")
    sent_var_idx = FEATURE_NAMES.index("sentence_length_variance")
    short_idx = FEATURE_NAMES.index("short_sentence_ratio")
    # Split into quartiles of the position axis and average each region.
    n = len(positions)
    for q, label in [(0, "Opening"), (1, "Early-mid"), (2, "Late-mid"), (3, "Closing")]:
        start = q * n // 4
        end = (q + 1) * n // 4
        avg_dialogue = np.mean(predictions_real[start:end, dialogue_idx])
        avg_variance = np.mean(predictions_real[start:end, sent_var_idx])
        avg_short = np.mean(predictions_real[start:end, short_idx])
        print(f" {label:12s} dialogue={avg_dialogue:.4f} "
              f"sent_variance={avg_variance:.4f} short_ratio={avg_short:.4f}")
    # Interpolation test — what does the field predict BETWEEN chunks?
    print("\n\nInterpolation test (querying between training points):")
    print("-" * 60)
    print("The field predicts style features at positions where no text exists.")
    print("If interpolation is smooth and sensible, the field learned structure.")
    print("If it's noisy/random, it just memorised individual chunks.")
    # Check smoothness: average absolute second derivative of the
    # normalised predictions (low curvature = smooth field).
    second_deriv = np.diff(predictions, n=2, axis=0)
    smoothness = np.mean(np.abs(second_deriv))
    print(f"\n Smoothness score (lower = smoother): {smoothness:.6f}")
    # Thresholds are heuristic cut-offs chosen by the experimenter.
    if smoothness < 0.01:
        print(" → Very smooth field — learned continuous style patterns")
    elif smoothness < 0.05:
        print(" → Moderately smooth — some structure learned")
    else:
        print(" → Rough field — mostly memorised chunks")
    return predictions_real
# ---------------------------------------------------------------------------
# 7. Save results for later
# ---------------------------------------------------------------------------
def save_results(positions, predictions_real, output_path):
    """Persist the sampled field as JSON for potential visualisation later.

    Writes query positions, one value-list per feature, the feature name
    ordering, and a human-readable description.
    """
    curves = {}
    for column, name in enumerate(FEATURE_NAMES):
        curves[name] = predictions_real[:, column].tolist()
    payload = {
        "positions": positions.tolist(),
        "features": curves,
        "feature_names": FEATURE_NAMES,
        "description": "WoRF continuous field over Wodehouse's 'My Man Jeeves'",
    }
    Path(output_path).write_text(json.dumps(payload, indent=2))
    print(f"\nField data saved to {output_path}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Run the end-to-end WoRF v1 experiment on 'My Man Jeeves'."""
    source = Path(__file__).parent / "pg-wood.txt"
    print("WoRF — Word Radiance Field Experiment")
    print("=" * 60)
    print(f"Source: {source.name}")
    # Load the cleaned text and cut it into fixed-size pages.
    body = load_and_clean(str(source))
    print(f"Clean text: {len(body):,} characters, {len(body.split()):,} words")
    pages = split_into_chunks(body, chunk_size=300)
    print(f"Chunks: {len(pages)} (≈300 words each)")
    # Build the (position → feature-vector) training pairs.
    xs, ys, stats = prepare_data(pages)
    print(f"Feature dimensions: {NUM_FEATURES}")
    print(f"Features: {', '.join(FEATURE_NAMES)}")
    # Fit the field, then sample it densely — including between chunks.
    field, freqs = train_worf(xs, ys, epochs=3000)
    qpos, qpred = query_field(field, freqs, num_points=1000)
    real = analyse_field(qpos, qpred, stats, pages)
    save_results(qpos, real, str(source.parent / "worf-field-jeeves.json"))
    print("\n" + "=" * 60)
    print("Done. The field exists. Poke it and see what it tells you.")
    print("=" * 60)


if __name__ == "__main__":
    main()