go-mlx/internal/metal/model_test.go
Snider a2493e0242 test(metal): add model loading robustness tests (Phase 2)
24 new tests covering error paths in model loading:
- Missing/invalid config.json, unsupported architecture
- Missing tokenizer.json for both Gemma3 and Qwen3
- Missing safetensors: was a nil-pointer panic in precomputeScaledWeights,
  fixed with early error return in both LoadGemma3 and LoadQwen3
- Config parsing: defaults, quantization, nested text_config
- isLayerSliding sliding window pattern logic
- resolveWeight with language_model. prefix fallback

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 21:49:07 +00:00

413 lines
11 KiB
Go

//go:build darwin && arm64
package metal
import (
"os"
"path/filepath"
"strings"
"testing"
)
// --- loadModel dispatch ---
// TestLoadModel_MissingConfigJSON points loadModel at an empty temporary
// directory and expects an error whose text mentions "config".
func TestLoadModel_MissingConfigJSON(t *testing.T) {
	// A fresh temp dir contains no config.json.
	_, loadErr := loadModel(t.TempDir())
	if loadErr == nil {
		t.Fatal("expected error for missing config.json")
	}
	if msg := loadErr.Error(); !strings.Contains(msg, "config") {
		t.Errorf("error should mention config, got: %v", loadErr)
	}
}
// TestLoadModel_InvalidConfigJSON writes a syntactically broken config.json
// and expects loadModel to return an error.
func TestLoadModel_InvalidConfigJSON(t *testing.T) {
	dir := t.TempDir()
	// Abort if the fixture cannot be written: this test only checks for a
	// non-nil error, so a silently missing config.json would let it pass
	// without ever exercising the invalid-JSON path.
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte("{invalid"), 0644); err != nil {
		t.Fatalf("write config.json: %v", err)
	}
	_, err := loadModel(dir)
	if err == nil {
		t.Fatal("expected error for invalid JSON")
	}
}
// TestLoadModel_UnsupportedArchitecture writes a config.json whose model_type
// is "gpt99" and expects loadModel to fail with an error that names it.
func TestLoadModel_UnsupportedArchitecture(t *testing.T) {
	dir := t.TempDir()
	// Report a fixture write failure as a setup problem instead of letting it
	// surface later as a confusing loader error.
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gpt99"}`), 0644); err != nil {
		t.Fatalf("write config.json: %v", err)
	}
	_, err := loadModel(dir)
	if err == nil {
		t.Fatal("expected error for unsupported architecture")
	}
	if !strings.Contains(err.Error(), "gpt99") {
		t.Errorf("error should mention architecture name, got: %v", err)
	}
}
// TestLoadModel_Gemma3TextType writes a config.json with model_type
// "gemma3_text" (and no tokenizer) and expects loadModel to fail with an error
// mentioning "tokenizer" or "gemma3".
func TestLoadModel_Gemma3TextType(t *testing.T) {
	// "gemma3_text" should route to Gemma3 loader (will fail on missing tokenizer, but
	// that proves the dispatch happened).
	dir := t.TempDir()
	config := []byte(`{
"model_type": "gemma3_text",
"hidden_size": 1152,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"num_key_value_heads": 1,
"head_dim": 256,
"vocab_size": 1000
}`)
	// Report a fixture write failure as a setup problem instead of letting it
	// surface later as a confusing loader error.
	if err := os.WriteFile(filepath.Join(dir, "config.json"), config, 0644); err != nil {
		t.Fatalf("write config.json: %v", err)
	}
	_, err := loadModel(dir)
	if err == nil {
		t.Fatal("expected error (missing tokenizer), but dispatch should have reached gemma3")
	}
	// If the error mentions "tokenizer" or "gemma3", dispatch worked correctly.
	if !strings.Contains(err.Error(), "tokenizer") && !strings.Contains(err.Error(), "gemma3") {
		t.Errorf("expected gemma3 loader error, got: %v", err)
	}
}
// --- LoadGemma3 error paths ---
// TestLoadGemma3_MissingTokenizer writes only a config.json and expects
// LoadGemma3 to fail with an error mentioning "tokenizer".
func TestLoadGemma3_MissingTokenizer(t *testing.T) {
	dir := t.TempDir()
	config := []byte(`{
"model_type": "gemma3",
"hidden_size": 1152,
"num_hidden_layers": 1,
"num_attention_heads": 4,
"num_key_value_heads": 1,
"vocab_size": 1000
}`)
	// Report a fixture write failure as a setup problem instead of letting it
	// surface later as a confusing loader error.
	if err := os.WriteFile(filepath.Join(dir, "config.json"), config, 0644); err != nil {
		t.Fatalf("write config.json: %v", err)
	}
	_, err := LoadGemma3(dir)
	if err == nil {
		t.Fatal("expected error for missing tokenizer")
	}
	if !strings.Contains(err.Error(), "tokenizer") {
		t.Errorf("error should mention tokenizer, got: %v", err)
	}
}
// TestLoadGemma3_InvalidConfig writes a config.json that is not JSON and
// expects LoadGemma3 to return an error.
func TestLoadGemma3_InvalidConfig(t *testing.T) {
	dir := t.TempDir()
	// Abort if the fixture cannot be written: this test only checks for a
	// non-nil error, so a failed write could let it pass for the wrong reason.
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte("not json"), 0644); err != nil {
		t.Fatalf("write config.json: %v", err)
	}
	_, err := LoadGemma3(dir)
	if err == nil {
		t.Fatal("expected error for invalid config")
	}
}
// TestLoadGemma3_NoSafetensors prepares a directory holding a minimal
// config.json and tokenizer.json but no weight files, and expects LoadGemma3
// to fail with an error mentioning "safetensors".
func TestLoadGemma3_NoSafetensors(t *testing.T) {
	modelDir := t.TempDir()
	writeMinimalConfig(t, modelDir, "gemma3")
	writeMinimalTokenizer(t, modelDir)

	_, loadErr := LoadGemma3(modelDir)
	if loadErr == nil {
		t.Fatal("expected error for missing safetensors files")
	}
	if mentioned := strings.Contains(loadErr.Error(), "safetensors"); !mentioned {
		t.Errorf("error should mention safetensors, got: %v", loadErr)
	}
}
// --- LoadQwen3 error paths ---
// TestLoadQwen3_MissingConfig calls LoadQwen3 on an empty temporary directory
// and expects an error.
func TestLoadQwen3_MissingConfig(t *testing.T) {
	if _, loadErr := LoadQwen3(t.TempDir()); loadErr == nil {
		t.Fatal("expected error for missing config.json")
	}
}
// TestLoadQwen3_InvalidConfig writes a truncated config.json and expects
// LoadQwen3 to return an error.
func TestLoadQwen3_InvalidConfig(t *testing.T) {
	dir := t.TempDir()
	// Abort if the fixture cannot be written: this test only checks for a
	// non-nil error, so a failed write could let it pass for the wrong reason.
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte("{broken"), 0644); err != nil {
		t.Fatalf("write config.json: %v", err)
	}
	_, err := LoadQwen3(dir)
	if err == nil {
		t.Fatal("expected error for invalid config")
	}
}
// TestLoadQwen3_MissingTokenizer writes only a config.json and expects
// LoadQwen3 to fail with an error mentioning "tokenizer".
func TestLoadQwen3_MissingTokenizer(t *testing.T) {
	dir := t.TempDir()
	config := []byte(`{
"model_type": "qwen3",
"hidden_size": 1024,
"num_hidden_layers": 1,
"num_attention_heads": 8,
"num_key_value_heads": 4,
"vocab_size": 1000
}`)
	// Report a fixture write failure as a setup problem instead of letting it
	// surface later as a confusing loader error.
	if err := os.WriteFile(filepath.Join(dir, "config.json"), config, 0644); err != nil {
		t.Fatalf("write config.json: %v", err)
	}
	_, err := LoadQwen3(dir)
	if err == nil {
		t.Fatal("expected error for missing tokenizer")
	}
	if !strings.Contains(err.Error(), "tokenizer") {
		t.Errorf("error should mention tokenizer, got: %v", err)
	}
}
// TestLoadQwen3_NoSafetensors prepares a directory holding a minimal
// config.json and tokenizer.json but no weight files, and expects LoadQwen3
// to fail with an error mentioning "safetensors".
func TestLoadQwen3_NoSafetensors(t *testing.T) {
	modelDir := t.TempDir()
	writeMinimalConfig(t, modelDir, "qwen3")
	writeMinimalTokenizer(t, modelDir)

	_, loadErr := LoadQwen3(modelDir)
	if loadErr == nil {
		t.Fatal("expected error for missing safetensors files")
	}
	if mentioned := strings.Contains(loadErr.Error(), "safetensors"); !mentioned {
		t.Errorf("error should mention safetensors, got: %v", loadErr)
	}
}
// --- LoadAndInit error paths ---
// TestLoadAndInit_MissingPath expects LoadAndInit to return an error when
// given a path that does not exist.
func TestLoadAndInit_MissingPath(t *testing.T) {
	const missing = "/nonexistent/model/path"
	if _, loadErr := LoadAndInit(missing); loadErr == nil {
		t.Fatal("expected error for nonexistent path")
	}
}
// TestLoadAndInit_UnsupportedArch writes a config.json whose model_type is
// "falcon" and expects LoadAndInit to fail with an error that names it.
func TestLoadAndInit_UnsupportedArch(t *testing.T) {
	dir := t.TempDir()
	// Report a fixture write failure as a setup problem instead of letting it
	// surface later as a confusing loader error.
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "falcon"}`), 0644); err != nil {
		t.Fatalf("write config.json: %v", err)
	}
	_, err := LoadAndInit(dir)
	if err == nil {
		t.Fatal("expected error for unsupported architecture")
	}
	if !strings.Contains(err.Error(), "falcon") {
		t.Errorf("error should mention architecture, got: %v", err)
	}
}
// TestLoadAndInit_NoSafetensors prepares a directory with a minimal gemma3
// config.json and tokenizer.json but no weight files, and expects LoadAndInit
// (called with ContextLen 2048) to return an error.
func TestLoadAndInit_NoSafetensors(t *testing.T) {
	modelDir := t.TempDir()
	writeMinimalConfig(t, modelDir, "gemma3")
	writeMinimalTokenizer(t, modelDir)

	if _, loadErr := LoadAndInit(modelDir, LoadConfig{ContextLen: 2048}); loadErr == nil {
		t.Fatal("expected error for missing safetensors")
	}
}
// --- parseConfig ---
// TestParseConfig_Defaults feeds parseConfig a JSON document that sets only
// hidden_size, num_hidden_layers, num_attention_heads, num_key_value_heads and
// head_dim, then checks the values the resulting config reports for
// RopeTheta, RopeLocalBaseFreq, RMSNormEps, SlidingWindowPattern and VocabSize.
func TestParseConfig_Defaults(t *testing.T) {
	cfg, err := parseConfig([]byte(`{
"hidden_size": 1024,
"num_hidden_layers": 8,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"head_dim": 128
}`))
	if err != nil {
		t.Fatalf("parseConfig: %v", err)
	}
	if cfg.RopeTheta != 1000000 {
		t.Errorf("RopeTheta default = %f, want 1000000", cfg.RopeTheta)
	}
	if cfg.RopeLocalBaseFreq != 10000 {
		t.Errorf("RopeLocalBaseFreq default = %f, want 10000", cfg.RopeLocalBaseFreq)
	}
	if cfg.RMSNormEps != 1e-6 {
		// %g rather than %f: with %f's six-decimal default, a wrong epsilon
		// such as 1e-7 would be printed as "0.000000" and hide the real value.
		t.Errorf("RMSNormEps default = %g, want 1e-6", cfg.RMSNormEps)
	}
	if cfg.SlidingWindowPattern != 6 {
		t.Errorf("SlidingWindowPattern default = %d, want 6", cfg.SlidingWindowPattern)
	}
	if cfg.VocabSize != 262208 {
		t.Errorf("VocabSize default = %d, want 262208", cfg.VocabSize)
	}
}
// TestParseConfig_QuantizationTopLevel parses a config carrying a top-level
// "quantization" object and checks that its group_size and bits values are
// exposed through cfg.Quantization.
func TestParseConfig_QuantizationTopLevel(t *testing.T) {
	input := []byte(`{
"hidden_size": 1024,
"num_hidden_layers": 8,
"num_attention_heads": 4,
"head_dim": 128,
"quantization": {"group_size": 64, "bits": 4}
}`)

	cfg, parseErr := parseConfig(input)
	if parseErr != nil {
		t.Fatalf("parseConfig: %v", parseErr)
	}

	// Everything below dereferences the quantization section, so stop here
	// if it was not populated.
	quant := cfg.Quantization
	if quant == nil {
		t.Fatal("expected quantization config")
	}
	if got := quant.GroupSize; got != 64 {
		t.Errorf("GroupSize = %d, want 64", got)
	}
	if got := quant.Bits; got != 4 {
		t.Errorf("Bits = %d, want 4", got)
	}
}
// TestParseConfig_NestedTextConfig parses a config whose model dimensions
// live under a nested "text_config" object and checks that HiddenSize and
// NumHiddenLayers are taken from it.
func TestParseConfig_NestedTextConfig(t *testing.T) {
	// Multimodal Gemma3 has text_config nested inside a wrapper.
	input := []byte(`{
"model_type": "gemma3",
"text_config": {
"hidden_size": 2048,
"num_hidden_layers": 16,
"num_attention_heads": 8,
"num_key_value_heads": 2,
"head_dim": 256,
"vocab_size": 262144
}
}`)

	cfg, parseErr := parseConfig(input)
	if parseErr != nil {
		t.Fatalf("parseConfig: %v", parseErr)
	}
	if got := cfg.HiddenSize; got != 2048 {
		t.Errorf("HiddenSize = %d, want 2048", got)
	}
	if got := cfg.NumHiddenLayers; got != 16 {
		t.Errorf("NumHiddenLayers = %d, want 16", got)
	}
}
// TestParseConfig_InvalidJSON expects parseConfig to reject input that is not
// JSON.
func TestParseConfig_InvalidJSON(t *testing.T) {
	if _, parseErr := parseConfig([]byte("not json")); parseErr == nil {
		t.Fatal("expected error for invalid JSON")
	}
}
// --- parseQwen3Config ---
// TestParseQwen3Config_Defaults parses a config that sets only hidden_size,
// num_hidden_layers, num_attention_heads and num_key_value_heads, then checks
// the HeadDim, RopeTheta and VocabSize values reported by the result.
func TestParseQwen3Config_Defaults(t *testing.T) {
	input := []byte(`{
"hidden_size": 1024,
"num_hidden_layers": 8,
"num_attention_heads": 4,
"num_key_value_heads": 2
}`)

	cfg, parseErr := parseQwen3Config(input)
	if parseErr != nil {
		t.Fatalf("parseQwen3Config: %v", parseErr)
	}

	// head_dim is absent from the input; the expected 256 is 1024/4
	// (hidden_size / num_attention_heads).
	if got := cfg.HeadDim; got != 256 {
		t.Errorf("HeadDim = %d, want 256 (hidden/heads)", got)
	}
	if got := cfg.RopeTheta; got != 1000000 {
		t.Errorf("RopeTheta default = %f, want 1000000", got)
	}
	if got := cfg.VocabSize; got != 151936 {
		t.Errorf("VocabSize default = %d, want 151936", got)
	}
}
// TestParseQwen3Config_InvalidJSON expects parseQwen3Config to reject
// truncated JSON.
func TestParseQwen3Config_InvalidJSON(t *testing.T) {
	if _, parseErr := parseQwen3Config([]byte("{broken")); parseErr == nil {
		t.Fatal("expected error for invalid JSON")
	}
}
// --- isLayerSliding ---
// TestIsLayerSliding runs isLayerSliding over a table of (layer index,
// pattern) pairs and compares the result with the expected flag.
//
// The table expects that, with pattern 6, a layer whose 1-based position is a
// multiple of 6 is reported as not sliding (global attention) and every other
// layer as sliding; a pattern of zero or below is expected to yield false.
func TestIsLayerSliding(t *testing.T) {
	cases := []struct {
		layer   int32
		period  int32
		sliding bool
	}{
		{layer: 0, period: 6, sliding: true},    // layer 1: 1%6=1 → sliding
		{layer: 4, period: 6, sliding: true},    // layer 5: 5%6=5 → sliding
		{layer: 5, period: 6, sliding: false},   // layer 6: 6%6=0 → global
		{layer: 11, period: 6, sliding: false},  // layer 12: 12%6=0 → global
		{layer: 0, period: 0, sliding: false},   // pattern=0 → no sliding
		{layer: 0, period: -1, sliding: false},  // pattern<0 → no sliding
	}

	for i := range cases {
		c := cases[i]
		if actual := isLayerSliding(c.layer, c.period); actual != c.sliding {
			t.Errorf("isLayerSliding(%d, %d) = %v, want %v", c.layer, c.period, actual, c.sliding)
		}
	}
}
// --- resolveWeight ---
// TestResolveWeight_Direct expects resolveWeight to return the array stored
// under exactly the requested name.
func TestResolveWeight_Direct(t *testing.T) {
	want := FromValue(float32(1))
	table := map[string]*Array{"model.norm.weight": want}

	if resolved := resolveWeight(table, "model.norm.weight"); resolved != want {
		t.Error("expected direct name resolution")
	}
}
// TestResolveWeight_LanguageModelPrefix stores an array only under a
// "language_model."-prefixed key and expects resolveWeight to find it when
// asked for the unprefixed name.
func TestResolveWeight_LanguageModelPrefix(t *testing.T) {
	want := FromValue(float32(1))
	table := map[string]*Array{"language_model.model.norm.weight": want}

	if resolved := resolveWeight(table, "model.norm.weight"); resolved != want {
		t.Error("expected language_model. prefix fallback")
	}
}
// TestResolveWeight_NotFound expects resolveWeight to return nil when the
// weight map is empty.
func TestResolveWeight_NotFound(t *testing.T) {
	empty := map[string]*Array{}
	if resolved := resolveWeight(empty, "nonexistent"); resolved != nil {
		t.Error("expected nil for missing weight")
	}
}
// --- helpers ---
// writeMinimalConfig creates dir/config.json containing a small fixed model
// configuration whose "model_type" is modelType. modelType is spliced into
// the JSON text verbatim. The calling test is failed immediately if the file
// cannot be written.
func writeMinimalConfig(t *testing.T, dir string, modelType string) {
	t.Helper()

	target := filepath.Join(dir, "config.json")
	payload := []byte(`{
"model_type": "` + modelType + `",
"hidden_size": 64,
"num_hidden_layers": 1,
"intermediate_size": 128,
"num_attention_heads": 2,
"num_key_value_heads": 1,
"head_dim": 32,
"vocab_size": 100,
"rms_norm_eps": 1e-6
}`)

	writeErr := os.WriteFile(target, payload, 0644)
	if writeErr != nil {
		t.Fatalf("write config.json: %v", writeErr)
	}
}
// writeMinimalTokenizer creates dir/tokenizer.json containing a tiny fixed
// tokenizer definition: a "BPE" model with a five-entry vocabulary, no merges,
// and three special added tokens (<pad>, <eos>, <bos>). The calling test is
// failed immediately if the file cannot be written.
func writeMinimalTokenizer(t *testing.T, dir string) {
	t.Helper()

	target := filepath.Join(dir, "tokenizer.json")
	payload := []byte(`{
"model": {
"type": "BPE",
"vocab": {"<pad>": 0, "<eos>": 1, "<bos>": 2, "hello": 3, "world": 4},
"merges": []
},
"added_tokens": [
{"id": 0, "content": "<pad>", "special": true},
{"id": 1, "content": "<eos>", "special": true},
{"id": 2, "content": "<bos>", "special": true}
]
}`)

	writeErr := os.WriteFile(target, payload, 0644)
	if writeErr != nil {
		t.Fatalf("write tokenizer.json: %v", writeErr)
	}
}