feat(metal): validate Gemma3-1B inference end-to-end (Phase 2)

- Fix model_type "gemma3_text" not matched in architecture dispatch
- Fix GPT-2 BPE false detection on large SentencePiece vocabs (Gemma3
  262K vocab contains Ġ but uses ▁ for spaces — check "Ġthe" not bare "Ġ")
- Add TestGemma3_1B_Inference: greedy decode, 46 tok/s, coherent output
- Add TestGemma3_1B_Chat: validates chat template formatting
- Add TestGemma3_1B_ContextCancel: validates ctx.Done() stops generation

4-bit quantised Gemma3-1B loads in ~700ms, generates at 46 tok/s on M3 Ultra.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Snider 2026-02-19 21:44:28 +00:00
parent 443347a2f8
commit 18e8dca9f8
4 changed files with 137 additions and 8 deletions

View file

@ -13,7 +13,7 @@ Dispatched from core/go orchestration. Pick up tasks in order.
## Phase 2: Model Support
- [ ] **Gemma3-1B inference validation** — The go-i18n Phase 2a needs 1B model inference for domain classification at ~5K sentences/sec. Validate Gemma3-1B loads and generates correctly via `mlx.LoadModel()` + `m.Generate()`. Report tokens/sec.
- [x] **Gemma3-1B inference validation** — ✅ End-to-end inference works. 4-bit quantised Gemma3-1B loads and generates coherently at **46 tok/s** on M3 Ultra. Fixed: `model_type: "gemma3_text"` not matched in architecture dispatch, GPT-2 BPE false detection on 262K SentencePiece vocab (checked `Ġthe` instead of bare `Ġ`). 3 new tests: inference (greedy, timing), chat template, context cancellation.
- [ ] **Model loading robustness** — Test with missing files, corrupted safetensors, wrong dtype. Currently no error handling tests for `io.go`.
- [ ] **Add Llama model support** — Only Gemma3 and Qwen3 exist. Llama architecture would cover Meta's model family (Llama 3, CodeLlama).

View file

@ -65,7 +65,7 @@ func loadModel(modelPath string) (InternalModel, error) {
switch probe.ModelType {
case "qwen3":
return LoadQwen3(modelPath)
case "gemma3", "gemma2":
case "gemma3", "gemma3_text", "gemma2":
return LoadGemma3(modelPath)
default:
return nil, fmt.Errorf("model: unsupported architecture %q", probe.ModelType)

View file

@ -111,8 +111,11 @@ func LoadTokenizer(path string) (*Tokenizer, error) {
t.invVocab[tok.ID] = tok.Content
}
// Detect GPT-2 byte-level BPE (Qwen, GPT, Llama use Ġ for space)
if _, ok := t.vocab["Ġ"]; ok {
// Detect GPT-2 byte-level BPE (Qwen, GPT, DeepSeek use Ġ for space).
// Check for "Ġthe" rather than bare "Ġ" — large SentencePiece vocabs
// (Gemma3 262K) may include Ġ as an obscure character without using
// GPT-2 byte encoding.
if _, ok := t.vocab["Ġthe"]; ok {
t.isGPT2BPE = true
t.gpt2Decoder, t.gpt2Encoder = buildGPT2ByteMaps()
}

View file

@ -5,7 +5,9 @@ package mlx_test
import (
"context"
"os"
"strings"
"testing"
"time"
"forge.lthn.ai/core/go-inference"
_ "forge.lthn.ai/core/go-mlx"
@ -137,12 +139,25 @@ func TestLoadOptionsDefaults(t *testing.T) {
}
}
// gemma3ModelPath returns the path to a Gemma3-1B model on disk, or skips.
func gemma3ModelPath(t *testing.T) string {
t.Helper()
paths := []string{
"/Volumes/Data/lem/gemma-3-1b-it-base",
"/Volumes/Data/lem/safetensors/gemma-3/",
}
for _, p := range paths {
if _, err := os.Stat(p); err == nil {
return p
}
}
t.Skip("no Gemma3 model available")
return ""
}
// TestLoadModel_Generate requires a model on disk. Skipped in CI.
func TestLoadModel_Generate(t *testing.T) {
const modelPath = "/Volumes/Data/lem/safetensors/gemma-3/"
if _, err := os.Stat(modelPath); err != nil {
t.Skip("model not available at", modelPath)
}
modelPath := gemma3ModelPath(t)
m, err := inference.LoadModel(modelPath)
if err != nil {
@ -168,3 +183,114 @@ func TestLoadModel_Generate(t *testing.T) {
}
t.Logf("Generated %d tokens", count)
}
// TestGemma3_1B_Inference validates end-to-end inference with Gemma3-1B.
// Reports tokens/sec for prefill and decode phases.
func TestGemma3_1B_Inference(t *testing.T) {
modelPath := gemma3ModelPath(t)
loadStart := time.Now()
m, err := inference.LoadModel(modelPath)
loadDur := time.Since(loadStart)
if err != nil {
t.Fatalf("LoadModel: %v", err)
}
defer m.Close()
t.Logf("Model loaded in %s", loadDur)
if m.ModelType() != "gemma3" {
t.Fatalf("ModelType() = %q, want %q", m.ModelType(), "gemma3")
}
// Generate with greedy sampling (temperature=0) for deterministic output.
ctx := context.Background()
const maxTokens = 64
genStart := time.Now()
var tokens []inference.Token
var output strings.Builder
for tok := range m.Generate(ctx, "What is 2+2?", inference.WithMaxTokens(maxTokens)) {
tokens = append(tokens, tok)
output.WriteString(tok.Text)
}
genDur := time.Since(genStart)
if err := m.Err(); err != nil {
t.Fatalf("Generate error: %v", err)
}
nTokens := len(tokens)
if nTokens == 0 {
t.Fatal("Generate produced no tokens")
}
tps := float64(nTokens) / genDur.Seconds()
t.Logf("Generated %d tokens in %s (%.1f tok/s)", nTokens, genDur, tps)
t.Logf("Output: %s", output.String())
// Log individual tokens for debugging.
for i, tok := range tokens {
t.Logf(" [%d] id=%d %q", i, tok.ID, tok.Text)
}
// Sanity: the output should contain something related to "4".
if !strings.Contains(output.String(), "4") {
t.Errorf("Expected output to contain '4' for 'What is 2+2?', got: %s", output.String())
}
}
// TestGemma3_1B_Chat validates chat template formatting and generation.
func TestGemma3_1B_Chat(t *testing.T) {
modelPath := gemma3ModelPath(t)
m, err := inference.LoadModel(modelPath)
if err != nil {
t.Fatalf("LoadModel: %v", err)
}
defer m.Close()
ctx := context.Background()
var output strings.Builder
var count int
for tok := range m.Chat(ctx, []inference.Message{
{Role: "user", Content: "Reply with exactly one word: the capital of France."},
}, inference.WithMaxTokens(16)) {
output.WriteString(tok.Text)
count++
}
if err := m.Err(); err != nil {
t.Fatalf("Chat error: %v", err)
}
if count == 0 {
t.Fatal("Chat produced no tokens")
}
t.Logf("Chat output (%d tokens): %s", count, output.String())
}
// TestGemma3_1B_ContextCancel validates that context cancellation stops generation.
func TestGemma3_1B_ContextCancel(t *testing.T) {
modelPath := gemma3ModelPath(t)
m, err := inference.LoadModel(modelPath)
if err != nil {
t.Fatalf("LoadModel: %v", err)
}
defer m.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var count int
for range m.Generate(ctx, "Tell me a long story about dragons.", inference.WithMaxTokens(1000)) {
count++
if count >= 5 {
cancel()
}
}
if count > 20 {
t.Errorf("Expected generation to stop near 5 tokens after cancel, got %d", count)
}
if err := m.Err(); err != context.Canceled {
t.Logf("Err() = %v (expected context.Canceled or nil)", err)
}
t.Logf("Stopped after %d tokens", count)
}