24 new tests covering error paths in model loading: - Missing/invalid config.json, unsupported architecture - Missing tokenizer.json for both Gemma3 and Qwen3 - Missing safetensors: was a nil-pointer panic in precomputeScaledWeights, fixed with early error return in both LoadGemma3 and LoadQwen3 - Config parsing: defaults, quantization, nested text_config - isLayerSliding sliding window pattern logic - resolveWeight with language_model. prefix fallback Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
413 lines
11 KiB
Go
413 lines
11 KiB
Go
//go:build darwin && arm64
|
|
|
|
package metal
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
// --- loadModel dispatch ---
|
|
|
|
func TestLoadModel_MissingConfigJSON(t *testing.T) {
|
|
dir := t.TempDir()
|
|
_, err := loadModel(dir)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing config.json")
|
|
}
|
|
if !strings.Contains(err.Error(), "config") {
|
|
t.Errorf("error should mention config, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestLoadModel_InvalidConfigJSON(t *testing.T) {
|
|
dir := t.TempDir()
|
|
os.WriteFile(filepath.Join(dir, "config.json"), []byte("{invalid"), 0644)
|
|
|
|
_, err := loadModel(dir)
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid JSON")
|
|
}
|
|
}
|
|
|
|
func TestLoadModel_UnsupportedArchitecture(t *testing.T) {
|
|
dir := t.TempDir()
|
|
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gpt99"}`), 0644)
|
|
|
|
_, err := loadModel(dir)
|
|
if err == nil {
|
|
t.Fatal("expected error for unsupported architecture")
|
|
}
|
|
if !strings.Contains(err.Error(), "gpt99") {
|
|
t.Errorf("error should mention architecture name, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestLoadModel_Gemma3TextType(t *testing.T) {
|
|
// "gemma3_text" should route to Gemma3 loader (will fail on missing tokenizer, but
|
|
// that proves the dispatch happened).
|
|
dir := t.TempDir()
|
|
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{
|
|
"model_type": "gemma3_text",
|
|
"hidden_size": 1152,
|
|
"num_hidden_layers": 2,
|
|
"num_attention_heads": 4,
|
|
"num_key_value_heads": 1,
|
|
"head_dim": 256,
|
|
"vocab_size": 1000
|
|
}`), 0644)
|
|
|
|
_, err := loadModel(dir)
|
|
if err == nil {
|
|
t.Fatal("expected error (missing tokenizer), but dispatch should have reached gemma3")
|
|
}
|
|
// If the error mentions "tokenizer" or "gemma3", dispatch worked correctly.
|
|
if !strings.Contains(err.Error(), "tokenizer") && !strings.Contains(err.Error(), "gemma3") {
|
|
t.Errorf("expected gemma3 loader error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// --- LoadGemma3 error paths ---
|
|
|
|
func TestLoadGemma3_MissingTokenizer(t *testing.T) {
|
|
dir := t.TempDir()
|
|
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{
|
|
"model_type": "gemma3",
|
|
"hidden_size": 1152,
|
|
"num_hidden_layers": 1,
|
|
"num_attention_heads": 4,
|
|
"num_key_value_heads": 1,
|
|
"vocab_size": 1000
|
|
}`), 0644)
|
|
|
|
_, err := LoadGemma3(dir)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing tokenizer")
|
|
}
|
|
if !strings.Contains(err.Error(), "tokenizer") {
|
|
t.Errorf("error should mention tokenizer, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestLoadGemma3_InvalidConfig(t *testing.T) {
|
|
dir := t.TempDir()
|
|
os.WriteFile(filepath.Join(dir, "config.json"), []byte("not json"), 0644)
|
|
|
|
_, err := LoadGemma3(dir)
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid config")
|
|
}
|
|
}
|
|
|
|
func TestLoadGemma3_NoSafetensors(t *testing.T) {
|
|
dir := t.TempDir()
|
|
writeMinimalConfig(t, dir, "gemma3")
|
|
writeMinimalTokenizer(t, dir)
|
|
|
|
_, err := LoadGemma3(dir)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing safetensors files")
|
|
}
|
|
if !strings.Contains(err.Error(), "safetensors") {
|
|
t.Errorf("error should mention safetensors, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// --- LoadQwen3 error paths ---
|
|
|
|
func TestLoadQwen3_MissingConfig(t *testing.T) {
|
|
dir := t.TempDir()
|
|
_, err := LoadQwen3(dir)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing config.json")
|
|
}
|
|
}
|
|
|
|
func TestLoadQwen3_InvalidConfig(t *testing.T) {
|
|
dir := t.TempDir()
|
|
os.WriteFile(filepath.Join(dir, "config.json"), []byte("{broken"), 0644)
|
|
|
|
_, err := LoadQwen3(dir)
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid config")
|
|
}
|
|
}
|
|
|
|
func TestLoadQwen3_MissingTokenizer(t *testing.T) {
|
|
dir := t.TempDir()
|
|
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{
|
|
"model_type": "qwen3",
|
|
"hidden_size": 1024,
|
|
"num_hidden_layers": 1,
|
|
"num_attention_heads": 8,
|
|
"num_key_value_heads": 4,
|
|
"vocab_size": 1000
|
|
}`), 0644)
|
|
|
|
_, err := LoadQwen3(dir)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing tokenizer")
|
|
}
|
|
if !strings.Contains(err.Error(), "tokenizer") {
|
|
t.Errorf("error should mention tokenizer, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestLoadQwen3_NoSafetensors(t *testing.T) {
|
|
dir := t.TempDir()
|
|
writeMinimalConfig(t, dir, "qwen3")
|
|
writeMinimalTokenizer(t, dir)
|
|
|
|
_, err := LoadQwen3(dir)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing safetensors files")
|
|
}
|
|
if !strings.Contains(err.Error(), "safetensors") {
|
|
t.Errorf("error should mention safetensors, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// --- LoadAndInit error paths ---
|
|
|
|
func TestLoadAndInit_MissingPath(t *testing.T) {
|
|
_, err := LoadAndInit("/nonexistent/model/path")
|
|
if err == nil {
|
|
t.Fatal("expected error for nonexistent path")
|
|
}
|
|
}
|
|
|
|
func TestLoadAndInit_UnsupportedArch(t *testing.T) {
|
|
dir := t.TempDir()
|
|
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "falcon"}`), 0644)
|
|
|
|
_, err := LoadAndInit(dir)
|
|
if err == nil {
|
|
t.Fatal("expected error for unsupported architecture")
|
|
}
|
|
if !strings.Contains(err.Error(), "falcon") {
|
|
t.Errorf("error should mention architecture, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestLoadAndInit_NoSafetensors(t *testing.T) {
|
|
dir := t.TempDir()
|
|
writeMinimalConfig(t, dir, "gemma3")
|
|
writeMinimalTokenizer(t, dir)
|
|
|
|
_, err := LoadAndInit(dir, LoadConfig{ContextLen: 2048})
|
|
if err == nil {
|
|
t.Fatal("expected error for missing safetensors")
|
|
}
|
|
}
|
|
|
|
// --- parseConfig ---
|
|
|
|
func TestParseConfig_Defaults(t *testing.T) {
|
|
cfg, err := parseConfig([]byte(`{
|
|
"hidden_size": 1024,
|
|
"num_hidden_layers": 8,
|
|
"num_attention_heads": 4,
|
|
"num_key_value_heads": 2,
|
|
"head_dim": 128
|
|
}`))
|
|
if err != nil {
|
|
t.Fatalf("parseConfig: %v", err)
|
|
}
|
|
if cfg.RopeTheta != 1000000 {
|
|
t.Errorf("RopeTheta default = %f, want 1000000", cfg.RopeTheta)
|
|
}
|
|
if cfg.RopeLocalBaseFreq != 10000 {
|
|
t.Errorf("RopeLocalBaseFreq default = %f, want 10000", cfg.RopeLocalBaseFreq)
|
|
}
|
|
if cfg.RMSNormEps != 1e-6 {
|
|
t.Errorf("RMSNormEps default = %f, want 1e-6", cfg.RMSNormEps)
|
|
}
|
|
if cfg.SlidingWindowPattern != 6 {
|
|
t.Errorf("SlidingWindowPattern default = %d, want 6", cfg.SlidingWindowPattern)
|
|
}
|
|
if cfg.VocabSize != 262208 {
|
|
t.Errorf("VocabSize default = %d, want 262208", cfg.VocabSize)
|
|
}
|
|
}
|
|
|
|
func TestParseConfig_QuantizationTopLevel(t *testing.T) {
|
|
cfg, err := parseConfig([]byte(`{
|
|
"hidden_size": 1024,
|
|
"num_hidden_layers": 8,
|
|
"num_attention_heads": 4,
|
|
"head_dim": 128,
|
|
"quantization": {"group_size": 64, "bits": 4}
|
|
}`))
|
|
if err != nil {
|
|
t.Fatalf("parseConfig: %v", err)
|
|
}
|
|
if cfg.Quantization == nil {
|
|
t.Fatal("expected quantization config")
|
|
}
|
|
if cfg.Quantization.GroupSize != 64 {
|
|
t.Errorf("GroupSize = %d, want 64", cfg.Quantization.GroupSize)
|
|
}
|
|
if cfg.Quantization.Bits != 4 {
|
|
t.Errorf("Bits = %d, want 4", cfg.Quantization.Bits)
|
|
}
|
|
}
|
|
|
|
func TestParseConfig_NestedTextConfig(t *testing.T) {
|
|
// Multimodal Gemma3 has text_config nested inside a wrapper.
|
|
cfg, err := parseConfig([]byte(`{
|
|
"model_type": "gemma3",
|
|
"text_config": {
|
|
"hidden_size": 2048,
|
|
"num_hidden_layers": 16,
|
|
"num_attention_heads": 8,
|
|
"num_key_value_heads": 2,
|
|
"head_dim": 256,
|
|
"vocab_size": 262144
|
|
}
|
|
}`))
|
|
if err != nil {
|
|
t.Fatalf("parseConfig: %v", err)
|
|
}
|
|
if cfg.HiddenSize != 2048 {
|
|
t.Errorf("HiddenSize = %d, want 2048", cfg.HiddenSize)
|
|
}
|
|
if cfg.NumHiddenLayers != 16 {
|
|
t.Errorf("NumHiddenLayers = %d, want 16", cfg.NumHiddenLayers)
|
|
}
|
|
}
|
|
|
|
func TestParseConfig_InvalidJSON(t *testing.T) {
|
|
_, err := parseConfig([]byte("not json"))
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid JSON")
|
|
}
|
|
}
|
|
|
|
// --- parseQwen3Config ---
|
|
|
|
func TestParseQwen3Config_Defaults(t *testing.T) {
|
|
cfg, err := parseQwen3Config([]byte(`{
|
|
"hidden_size": 1024,
|
|
"num_hidden_layers": 8,
|
|
"num_attention_heads": 4,
|
|
"num_key_value_heads": 2
|
|
}`))
|
|
if err != nil {
|
|
t.Fatalf("parseQwen3Config: %v", err)
|
|
}
|
|
if cfg.HeadDim != 256 { // 1024/4
|
|
t.Errorf("HeadDim = %d, want 256 (hidden/heads)", cfg.HeadDim)
|
|
}
|
|
if cfg.RopeTheta != 1000000 {
|
|
t.Errorf("RopeTheta default = %f, want 1000000", cfg.RopeTheta)
|
|
}
|
|
if cfg.VocabSize != 151936 {
|
|
t.Errorf("VocabSize default = %d, want 151936", cfg.VocabSize)
|
|
}
|
|
}
|
|
|
|
func TestParseQwen3Config_InvalidJSON(t *testing.T) {
|
|
_, err := parseQwen3Config([]byte("{broken"))
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid JSON")
|
|
}
|
|
}
|
|
|
|
// --- isLayerSliding ---
|
|
|
|
func TestIsLayerSliding(t *testing.T) {
|
|
// Pattern=6: every 6th layer is NOT sliding (global attention).
|
|
// Layer 5 (index=5, i+1=6) → 6%6=0 → not sliding (global)
|
|
// Layer 0 (index=0, i+1=1) → 1%6=1 → sliding
|
|
tests := []struct {
|
|
idx int32
|
|
pattern int32
|
|
want bool
|
|
}{
|
|
{0, 6, true}, // layer 1: 1%6=1 → sliding
|
|
{4, 6, true}, // layer 5: 5%6=5 → sliding
|
|
{5, 6, false}, // layer 6: 6%6=0 → global
|
|
{11, 6, false}, // layer 12: 12%6=0 → global
|
|
{0, 0, false}, // pattern=0 → no sliding
|
|
{0, -1, false}, // pattern<0 → no sliding
|
|
}
|
|
for _, tt := range tests {
|
|
got := isLayerSliding(tt.idx, tt.pattern)
|
|
if got != tt.want {
|
|
t.Errorf("isLayerSliding(%d, %d) = %v, want %v", tt.idx, tt.pattern, got, tt.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
// --- resolveWeight ---
|
|
|
|
func TestResolveWeight_Direct(t *testing.T) {
|
|
a := FromValue(float32(1))
|
|
weights := map[string]*Array{"model.norm.weight": a}
|
|
|
|
got := resolveWeight(weights, "model.norm.weight")
|
|
if got != a {
|
|
t.Error("expected direct name resolution")
|
|
}
|
|
}
|
|
|
|
func TestResolveWeight_LanguageModelPrefix(t *testing.T) {
|
|
a := FromValue(float32(1))
|
|
weights := map[string]*Array{"language_model.model.norm.weight": a}
|
|
|
|
got := resolveWeight(weights, "model.norm.weight")
|
|
if got != a {
|
|
t.Error("expected language_model. prefix fallback")
|
|
}
|
|
}
|
|
|
|
func TestResolveWeight_NotFound(t *testing.T) {
|
|
weights := map[string]*Array{}
|
|
got := resolveWeight(weights, "nonexistent")
|
|
if got != nil {
|
|
t.Error("expected nil for missing weight")
|
|
}
|
|
}
|
|
|
|
// --- helpers ---
|
|
|
|
// writeMinimalConfig creates dir/config.json describing a tiny one-layer model
// whose "model_type" is modelType. The value is spliced into the JSON
// unescaped, so callers should pass plain identifiers such as "gemma3" or
// "qwen3". Any write failure aborts the calling test.
func writeMinimalConfig(t *testing.T, dir string, modelType string) {
	t.Helper()

	target := filepath.Join(dir, "config.json")
	body := `{
	"model_type": "` + modelType + `",
	"hidden_size": 64,
	"num_hidden_layers": 1,
	"intermediate_size": 128,
	"num_attention_heads": 2,
	"num_key_value_heads": 1,
	"head_dim": 32,
	"vocab_size": 100,
	"rms_norm_eps": 1e-6
}`

	writeErr := os.WriteFile(target, []byte(body), 0o644)
	if writeErr != nil {
		t.Fatalf("write config.json: %v", writeErr)
	}
}
|
|
|
|
// writeMinimalTokenizer creates dir/tokenizer.json containing a tiny BPE
// tokenizer: a five-entry vocabulary with an empty merge list, where ids 0-2
// (<pad>, <eos>, <bos>) are also listed as special added tokens. Any write
// failure aborts the calling test.
func writeMinimalTokenizer(t *testing.T, dir string) {
	t.Helper()

	const tokenizerJSON = `{
	"model": {
		"type": "BPE",
		"vocab": {"<pad>": 0, "<eos>": 1, "<bos>": 2, "hello": 3, "world": 4},
		"merges": []
	},
	"added_tokens": [
		{"id": 0, "content": "<pad>", "special": true},
		{"id": 1, "content": "<eos>", "special": true},
		{"id": 2, "content": "<bos>", "special": true}
	]
}`

	target := filepath.Join(dir, "tokenizer.json")
	writeErr := os.WriteFile(target, []byte(tokenizerJSON), 0o644)
	if writeErr != nil {
		t.Fatalf("write tokenizer.json: %v", writeErr)
	}
}
|