NeRF-inspired technique for learning relational dynamics of language. Not what words mean, but how they behave together — rhythm, pacing, punctuation patterns, style transitions. v1: positional field over text (baseline, memorises) v2: masked feature prediction (relational, actually works) Trained on Wodehouse "My Man Jeeves" (public domain, Gutenberg). All 11 style features are highly relational — the field learns that Wodehouse's style is a tightly coupled system. Key finding: style interpolation between narrative and dialogue produces sensible predictions for unmeasured features, suggesting the continuous field captures real structural patterns. Co-Authored-By: Virgil <virgil@lethean.io>
390 lines · 14 KiB · Python
#!/usr/bin/env python3
"""
WoRF Experiment — Word Radiance Field
======================================

Feed Wodehouse's "My Man Jeeves" into a NeRF-like MLP and see what
the continuous field learns about writing style.

NeRF: (x, y, z, θ, φ) → (r, g, b, σ)
WoRF: (position_in_text, chunk_context) → (style_features)

First pass: 1D position → style feature vector. No viewing angle yet.
Just see if a continuous field over text position learns anything.
"""
|
||
|
||
import re
|
||
import math
|
||
import json
|
||
import torch
|
||
import torch.nn as nn
|
||
import numpy as np
|
||
from pathlib import Path
|
||
from collections import Counter
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 1. Text Splitting — each "page" is one observation (like one photo for NeRF)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def load_and_clean(path: str) -> str:
    """Strip the Project Gutenberg header/footer and return clean text.

    Args:
        path: Path to the raw Gutenberg plain-text file.

    Returns:
        The book text with PG boilerplate removed and surrounding
        whitespace stripped.
    """
    text = Path(path).read_text(encoding="utf-8")

    # Strip PG header: prefer the book's own first heading, fall back to
    # the standard "*** START OF" marker, then to the start of the file.
    start = text.find("LEAVE IT TO JEEVES")
    if start == -1:
        start = text.find("*** START OF")
        if start == -1:
            # Neither marker present — keep the whole file rather than
            # calling find("\n", -1), which searches from a negative index
            # and would truncate the text to its final line.
            start = 0
        else:
            # Skip past the marker line itself.
            start = text.find("\n", start) + 1

    # Strip PG footer; if absent, keep through end of file.
    end = text.find("*** END OF THE PROJECT GUTENBERG")
    if end == -1:
        end = len(text)

    return text[start:end].strip()
|
||
|
||
|
||
def split_into_chunks(text: str, chunk_size: int = 500,
                      min_chunk_words: int = 50) -> list[str]:
    """Split text into roughly equal word-count chunks ("pages").

    Args:
        text: Source text to split.
        chunk_size: Target number of words per chunk.
        min_chunk_words: Chunks with this many words or fewer are dropped
            (skips tiny trailing fragments; default matches the original
            hard-coded threshold of 50).

    Returns:
        List of chunk strings, each containing more than
        ``min_chunk_words`` words.
    """
    words = text.split()
    chunks = []
    for i in range(0, len(words), chunk_size):
        piece = words[i:i + chunk_size]
        # Count the slice directly instead of re-splitting the joined string.
        if len(piece) > min_chunk_words:
            chunks.append(" ".join(piece))
    return chunks
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 2. Feature Extraction — what we measure about each "page"
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def extract_features(chunk: str) -> dict:
|
||
"""Extract stylistic features from a chunk of text.
|
||
|
||
These are the 'RGB + density' equivalent — what the field predicts.
|
||
"""
|
||
words = chunk.split()
|
||
sentences = re.split(r'[.!?]+', chunk)
|
||
sentences = [s.strip() for s in sentences if s.strip()]
|
||
|
||
word_lengths = [len(w.strip(".,;:!?\"'()—-")) for w in words]
|
||
word_lengths = [l for l in word_lengths if l > 0]
|
||
|
||
# Dialogue detection
|
||
dialogue_chars = sum(1 for c in chunk if c == '"')
|
||
total_chars = len(chunk) or 1
|
||
|
||
# Punctuation patterns (Wodehouse loves dashes and exclamations)
|
||
dashes = chunk.count("—") + chunk.count("--")
|
||
exclamations = chunk.count("!")
|
||
questions = chunk.count("?")
|
||
commas = chunk.count(",")
|
||
|
||
# Vocabulary richness (unique words / total words)
|
||
unique_words = len(set(w.lower().strip(".,;:!?\"'()—-") for w in words))
|
||
total_words = len(words) or 1
|
||
|
||
# Sentence length variation (std dev) — captures rhythm
|
||
sent_lengths = [len(s.split()) for s in sentences]
|
||
sent_mean = np.mean(sent_lengths) if sent_lengths else 0
|
||
sent_std = np.std(sent_lengths) if sent_lengths else 0
|
||
|
||
# Short sentence ratio (punchy lines like "Injudicious, sir.")
|
||
short_sentences = sum(1 for l in sent_lengths if l <= 5)
|
||
short_ratio = short_sentences / (len(sent_lengths) or 1)
|
||
|
||
# Aside/parenthetical density (commas, dashes per word)
|
||
aside_density = (commas + dashes) / total_words
|
||
|
||
return {
|
||
"avg_word_length": np.mean(word_lengths) if word_lengths else 0,
|
||
"avg_sentence_length": sent_mean,
|
||
"sentence_length_variance": sent_std,
|
||
"dialogue_ratio": dialogue_chars / total_chars,
|
||
"vocabulary_richness": unique_words / total_words,
|
||
"dash_density": dashes / total_words,
|
||
"exclamation_density": exclamations / total_words,
|
||
"question_density": questions / total_words,
|
||
"short_sentence_ratio": short_ratio,
|
||
"aside_density": aside_density,
|
||
"avg_punct_per_sentence": (commas + dashes + exclamations + questions) / (len(sent_lengths) or 1),
|
||
}
|
||
|
||
|
||
# Derive the canonical feature ordering once by probing the extractor with a
# throwaway string; every tensor column downstream follows this order.
FEATURE_NAMES = list(extract_features("dummy text here for keys"))
NUM_FEATURES = len(FEATURE_NAMES)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 3. Positional Encoding — NeRF's trick for capturing high-frequency detail
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def positional_encoding(x: torch.Tensor, num_frequencies: int = 10) -> torch.Tensor:
    """NeRF-style sinusoidal positional encoding.

    Lifts a scalar position into a higher-dimensional space so the MLP can
    represent sharp transitions (the same reason NeRF needs it for edges).

    Args:
        x: Tensor of positions, shape (..., 1).
        num_frequencies: Number of octaves of sin/cos pairs to append.

    Returns:
        Tensor of shape (..., 1 + 2 * num_frequencies): the raw input
        followed by sin/cos of (2**k · π · x) for k = 0..num_frequencies-1.
    """
    parts = [x]
    parts.extend(
        fn((2.0 ** k) * math.pi * x)
        for k in range(num_frequencies)
        for fn in (torch.sin, torch.cos)
    )
    return torch.cat(parts, dim=-1)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 4. The WoRF Network — tiny MLP, same architecture as vanilla NeRF
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class WoRF(nn.Module):
    """Word Radiance Field — learns a continuous style field over text position.

    A small MLP in the spirit of vanilla NeRF: a positionally encoded scalar
    position goes in, a vector of style features comes out.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 128, num_layers: int = 4,
                 output_dim=None):
        """Build the MLP.

        Args:
            input_dim: Dimensionality of the (positionally encoded) input.
            hidden_dim: Width of every hidden layer.
            num_layers: Total number of Linear layers (input + hidden + output).
            output_dim: Number of predicted features. Defaults to the
                module-level NUM_FEATURES, resolved lazily at construction
                time so the class definition itself carries no dependency
                on that constant.
        """
        super().__init__()
        if output_dim is None:
            output_dim = NUM_FEATURES

        # NOTE(review): the original recorded a `skip_layer_idx` "like NeRF"
        # here, but never used it — nn.Sequential cannot express a skip
        # connection, so that bookkeeping was dead code and has been removed.
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_dim, output_dim))

        self.network = nn.Sequential(*layers)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map encoded positions (batch, input_dim) → features (batch, output_dim)."""
        return self.network(x)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 5. Training
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def prepare_data(chunks: list[str]) -> tuple[torch.Tensor, torch.Tensor, dict]:
    """Convert chunks to training data: normalised positions → feature vectors.

    Args:
        chunks: Text chunks in document order.

    Returns:
        Tuple of (positions, features_norm, norm_stats):
        positions — float32 tensor of shape (n, 1) in [0, 1];
        features_norm — float32 tensor of shape (n, NUM_FEATURES), each
        column min-max normalised to [0, 1];
        norm_stats — dict of 'min'/'max'/'range' tensors for denormalising.
    """
    n = len(chunks)
    positions = []
    features = []

    for i, chunk in enumerate(chunks):
        # Normalise position to [0, 1]; with a single chunk place it at 0.0
        # instead of dividing by zero.
        pos = i / (n - 1) if n > 1 else 0.0
        feat = extract_features(chunk)
        positions.append(pos)
        features.append([feat[k] for k in FEATURE_NAMES])

    positions = torch.tensor(positions, dtype=torch.float32).unsqueeze(-1)
    features = torch.tensor(features, dtype=torch.float32)

    # Min-max normalise each feature column for training stability.
    feat_min = features.min(dim=0).values
    feat_max = features.max(dim=0).values
    feat_range = feat_max - feat_min
    feat_range[feat_range == 0] = 1.0  # constant columns: avoid division by zero
    features_norm = (features - feat_min) / feat_range

    norm_stats = {"min": feat_min, "max": feat_max, "range": feat_range}

    return positions, features_norm, norm_stats
|
||
|
||
|
||
def train_worf(positions: torch.Tensor, features: torch.Tensor,
               num_frequencies: int = 10, epochs: int = 2000, lr: float = 1e-3):
    """Fit the WoRF field to (position, feature) pairs with plain MSE.

    Returns the trained model together with the frequency count, which is
    needed to encode query positions identically at inference time.
    """
    # Encode all positions once — they never change during training.
    encoded = positional_encoding(positions, num_frequencies)
    input_dim = encoded.shape[-1]

    model = WoRF(input_dim=input_dim, output_dim=features.shape[-1])
    optimiser = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    print(f"\nTraining WoRF: {len(positions)} chunks, {input_dim}D input, {features.shape[-1]} features")
    print(f"Positional encoding frequencies: {num_frequencies}")
    print("-" * 60)

    for epoch in range(epochs):
        loss = loss_fn(model(encoded), features)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        # Log every 200 epochs plus the final one.
        if epoch % 200 == 0 or epoch == epochs - 1:
            print(f" Epoch {epoch:4d}/{epochs} Loss: {loss.item():.6f}")

    return model, num_frequencies
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 6. Query the Field — the interesting bit
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def query_field(model: WoRF, num_frequencies: int,
                num_points: int = 500) -> tuple[np.ndarray, np.ndarray]:
    """Sample the learned field densely, including between training samples.

    Args:
        model: Trained WoRF network.
        num_frequencies: Must match the frequency count used in training.
        num_points: Number of evenly spaced positions in [0, 1] to query.

    Returns:
        Tuple of (positions, predictions): positions has shape
        (num_points,), predictions has shape (num_points, num_features),
        both as NumPy arrays. (The original annotation claimed a single
        ndarray but the function has always returned this 2-tuple.)
    """
    positions = torch.linspace(0, 1, num_points).unsqueeze(-1)
    encoded = positional_encoding(positions, num_frequencies)

    # Inference only — no gradients needed.
    with torch.no_grad():
        predictions = model(encoded).numpy()

    return positions.squeeze().numpy(), predictions
|
||
|
||
|
||
def analyse_field(positions: np.ndarray, predictions: np.ndarray,
                  norm_stats: dict, chunks: list[str]):
    """Analyse what the field learned.

    Prints per-feature dynamics, candidate story/scene boundaries, quartile
    rhythm summaries, and a smoothness score, then returns the predictions
    denormalised back to real feature units.

    Args:
        positions: Query positions in [0, 1], shape (num_points,).
        predictions: Normalised field outputs, shape (num_points, num_features).
        norm_stats: 'min'/'max'/'range' tensors produced by prepare_data.
        chunks: Original text chunks, used to preview boundary text.

    Returns:
        The predictions denormalised to real feature units (np.ndarray).
    """
    print("\n" + "=" * 60)
    print("FIELD ANALYSIS")
    print("=" * 60)

    # Denormalise for interpretability — undo prepare_data's min-max scaling.
    feat_min = norm_stats["min"].numpy()
    feat_range = norm_stats["range"].numpy()
    predictions_real = predictions * feat_range + feat_min

    # Find peaks and valleys for each feature
    print("\nFeature dynamics across the book:")
    print("-" * 60)

    for i, name in enumerate(FEATURE_NAMES):
        values = predictions_real[:, i]
        peak_pos = positions[np.argmax(values)]
        valley_pos = positions[np.argmin(values)]
        mean_val = np.mean(values)
        std_val = np.std(values)
        dynamic_range = np.max(values) - np.min(values)

        print(f" {name:30s} mean={mean_val:.4f} std={std_val:.4f} "
              f"range={dynamic_range:.4f} peak@{peak_pos:.2f} valley@{valley_pos:.2f}")

    # Find story boundaries by looking for sharp transitions
    print("\n\nSharp transitions (potential story/scene boundaries):")
    print("-" * 60)

    # Use total gradient magnitude across all features — computed in
    # normalised space so every feature contributes on a comparable scale.
    gradients = np.diff(predictions, axis=0)
    gradient_magnitude = np.sqrt(np.sum(gradients ** 2, axis=1))

    # Find top transition points
    top_transitions = np.argsort(gradient_magnitude)[-8:]  # top 8 (roughly one per story)
    top_transitions = np.sort(top_transitions)

    for idx in top_transitions:
        pos = positions[idx]
        # Estimate which chunk this corresponds to
        chunk_idx = int(pos * (len(chunks) - 1))
        chunk_preview = chunks[min(chunk_idx, len(chunks) - 1)][:80]
        print(f" Position {pos:.3f} (magnitude {gradient_magnitude[idx]:.4f})")
        print(f" Text: \"{chunk_preview}...\"")
        print()

    # Compare dialogue-heavy vs narrative-heavy regions
    print("\nDialogue vs Narrative rhythm:")
    print("-" * 60)

    dialogue_idx = FEATURE_NAMES.index("dialogue_ratio")
    sent_var_idx = FEATURE_NAMES.index("sentence_length_variance")
    short_idx = FEATURE_NAMES.index("short_sentence_ratio")

    # Split into quartiles
    n = len(positions)
    for q, label in [(0, "Opening"), (1, "Early-mid"), (2, "Late-mid"), (3, "Closing")]:
        start = q * n // 4
        end = (q + 1) * n // 4
        avg_dialogue = np.mean(predictions_real[start:end, dialogue_idx])
        avg_variance = np.mean(predictions_real[start:end, sent_var_idx])
        avg_short = np.mean(predictions_real[start:end, short_idx])
        print(f" {label:12s} dialogue={avg_dialogue:.4f} "
              f"sent_variance={avg_variance:.4f} short_ratio={avg_short:.4f}")

    # Interpolation test — what does the field predict BETWEEN chunks?
    print("\n\nInterpolation test (querying between training points):")
    print("-" * 60)
    print("The field predicts style features at positions where no text exists.")
    print("If interpolation is smooth and sensible, the field learned structure.")
    print("If it's noisy/random, it just memorised individual chunks.")

    # Check smoothness: average absolute second derivative, again in
    # normalised space like the gradient magnitudes above.
    second_deriv = np.diff(predictions, n=2, axis=0)
    smoothness = np.mean(np.abs(second_deriv))
    print(f"\n Smoothness score (lower = smoother): {smoothness:.6f}")

    # NOTE(review): these thresholds look heuristic (eyeballed), not derived.
    if smoothness < 0.01:
        print(" → Very smooth field — learned continuous style patterns")
    elif smoothness < 0.05:
        print(" → Moderately smooth — some structure learned")
    else:
        print(" → Rough field — mostly memorised chunks")

    return predictions_real
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 7. Save results for later
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def save_results(positions, predictions_real, output_path):
    """Persist the sampled field as JSON for potential visualisation later.

    Stores the query positions plus one named series per style feature.
    """
    per_feature = {
        name: predictions_real[:, col].tolist()
        for col, name in enumerate(FEATURE_NAMES)
    }
    results = {
        "positions": positions.tolist(),
        "features": per_feature,
        "feature_names": FEATURE_NAMES,
        "description": "WoRF continuous field over Wodehouse's 'My Man Jeeves'",
    }
    Path(output_path).write_text(json.dumps(results, indent=2))
    print(f"\nField data saved to {output_path}")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Main
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def main():
    """Run the full WoRF pipeline: load → chunk → featurise → train → analyse → save."""
    book_path = Path(__file__).parent / "pg-wood.txt"

    print("WoRF — Word Radiance Field Experiment")
    print("=" * 60)
    print(f"Source: {book_path.name}")

    # Load the book and carve it into fixed-size "pages".
    clean_text = load_and_clean(str(book_path))
    print(f"Clean text: {len(clean_text):,} characters, {len(clean_text.split()):,} words")

    pages = split_into_chunks(clean_text, chunk_size=300)
    print(f"Chunks: {len(pages)} (≈300 words each)")

    # Turn pages into (position, feature-vector) training pairs.
    xs, ys, norm_stats = prepare_data(pages)
    print(f"Feature dimensions: {NUM_FEATURES}")
    print(f"Features: {', '.join(FEATURE_NAMES)}")

    # Fit the field.
    field, freq_count = train_worf(xs, ys, epochs=3000)

    # Sample the continuous field densely, including between training points.
    sample_positions, sample_preds = query_field(field, freq_count, num_points=1000)

    # Inspect what it learned (also denormalises the samples).
    denorm_preds = analyse_field(sample_positions, sample_preds, norm_stats, pages)

    # Persist for later visualisation.
    save_results(sample_positions, denorm_preds,
                 str(book_path.parent / "worf-field-jeeves.json"))

    print("\n" + "=" * 60)
    print("Done. The field exists. Poke it and see what it tells you.")
    print("=" * 60)
|
||
|
||
|
||
# Script entry point: run the experiment when executed directly.
if __name__ == "__main__":
    main()
|