From eb8dee31bf8d892fc82398fecbdf0f40a37b2e98 Mon Sep 17 00:00:00 2001 From: Snider Date: Thu, 19 Feb 2026 20:05:14 +0000 Subject: [PATCH] test(api): integration tests for public LoadModel + Generate Tests: MetalAvailable, DefaultBackend, GetBackend, LoadModel error paths, options, defaults. Model-dependent test skips when model not available on disk. Co-Authored-By: Virgil --- mlx_test.go | 123 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 mlx_test.go diff --git a/mlx_test.go b/mlx_test.go new file mode 100644 index 0000000..9e1689d --- /dev/null +++ b/mlx_test.go @@ -0,0 +1,123 @@ +//go:build darwin && arm64 + +package mlx_test + +import ( + "context" + "os" + "testing" + + "forge.lthn.ai/core/go-mlx" +) + +func TestMetalAvailable(t *testing.T) { + if !mlx.MetalAvailable() { + t.Fatal("MetalAvailable() = false on darwin/arm64") + } +} + +func TestDefaultBackend(t *testing.T) { + b, err := mlx.Default() + if err != nil { + t.Fatalf("Default() error: %v", err) + } + if b.Name() != "metal" { + t.Errorf("Default().Name() = %q, want %q", b.Name(), "metal") + } +} + +func TestGetBackend(t *testing.T) { + b, ok := mlx.Get("metal") + if !ok { + t.Fatal("Get(\"metal\") returned false") + } + if b.Name() != "metal" { + t.Errorf("Name() = %q, want %q", b.Name(), "metal") + } + + _, ok = mlx.Get("nonexistent") + if ok { + t.Error("Get(\"nonexistent\") should return false") + } +} + +func TestLoadModel_NoBackend(t *testing.T) { + _, err := mlx.LoadModel("/nonexistent/path") + if err == nil { + t.Error("expected error for nonexistent model path") + } +} + +func TestLoadModel_WithBackend(t *testing.T) { + _, err := mlx.LoadModel("/nonexistent/path", mlx.WithBackend("nonexistent")) + if err == nil { + t.Error("expected error for nonexistent backend") + } +} + +func TestOptions(t *testing.T) { + cfg := mlx.ApplyGenerateOpts([]mlx.GenerateOption{ + mlx.WithMaxTokens(64), + mlx.WithTemperature(0.7), + mlx.WithTopK(40), + mlx.WithTopP(0.9), + mlx.WithStopTokens(1, 2, 3), + }) + if cfg.MaxTokens != 64 { + t.Errorf("MaxTokens = %d, want 64", cfg.MaxTokens) + } + if cfg.Temperature != 0.7 { + t.Errorf("Temperature = %f, want 0.7", cfg.Temperature) + } + if cfg.TopK != 40 { + t.Errorf("TopK = %d, want 40", cfg.TopK) + } + if cfg.TopP != 0.9 { + t.Errorf("TopP = %f, want 0.9", cfg.TopP) + } + if len(cfg.StopTokens) != 3 { + t.Errorf("StopTokens len = %d, want 3", len(cfg.StopTokens)) + } +} + +func TestDefaults(t *testing.T) { + cfg := mlx.DefaultGenerateConfig() + if cfg.MaxTokens != 256 { + t.Errorf("default MaxTokens = %d, want 256", cfg.MaxTokens) + } + if cfg.Temperature != 0.0 { + t.Errorf("default Temperature = %f, want 0.0", cfg.Temperature) + } +} + +// TestLoadModel_Generate requires a model on disk. Skipped in CI. +func TestLoadModel_Generate(t *testing.T) { + const modelPath = "/Volumes/Data/lem/safetensors/gemma-3/" + if _, err := os.Stat(modelPath); err != nil { + t.Skip("model not available at", modelPath) + } + + m, err := mlx.LoadModel(modelPath) + if err != nil { + t.Fatalf("LoadModel: %v", err) + } + defer m.Close() + + if m.ModelType() != "gemma3" { + t.Errorf("ModelType() = %q, want %q", m.ModelType(), "gemma3") + } + + ctx := context.Background() + var count int + for tok := range m.Generate(ctx, "What is 2+2?", mlx.WithMaxTokens(16)) { + count++ + t.Logf("[%d] %q", tok.ID, tok.Text) + } + if err := m.Err(); err != nil { + t.Fatalf("Generate error: %v", err) + } + if count == 0 { + t.Error("Generate produced no tokens") + } + t.Logf("Generated %d tokens", count) +}