From ceb966b66b6fb6760c8a589d0d3512b812f7c920 Mon Sep 17 00:00:00 2001
From: Snider
Date: Thu, 19 Feb 2026 23:36:23 +0000
Subject: [PATCH] feat(metal): expose model metadata via Info()

Return architecture, vocab size, layer count, hidden dimension, and
quantisation config (bits + group size) for loaded models.

Gemma3-1B 4-bit: arch=gemma3, vocab=262144, layers=26, hidden=1152,
quant=4-bit/group64.

Co-Authored-By: Virgil
Co-Authored-By: Claude Opus 4.6
---
 internal/metal/generate.go | 35 +++++++++++++++++++++++++++++++++++
 mlx_test.go                | 34 ++++++++++++++++++++++++++++++++++
 register_metal.go          | 11 +++++++++++
 3 files changed, 80 insertions(+)

diff --git a/internal/metal/generate.go b/internal/metal/generate.go
index c149d49..0203391 100644
--- a/internal/metal/generate.go
+++ b/internal/metal/generate.go
@@ -63,6 +63,41 @@ func (m *Model) Err() error { return m.lastErr }
 // LastMetrics returns performance metrics from the last inference operation.
 func (m *Model) LastMetrics() Metrics { return m.lastMetrics }
 
+// ModelInfo holds metadata about a loaded model.
+type ModelInfo struct {
+	Architecture string
+	VocabSize    int
+	NumLayers    int
+	HiddenSize   int
+	QuantBits    int
+	QuantGroup   int
+}
+
+// Info returns metadata about the loaded model.
+func (m *Model) Info() ModelInfo {
+	info := ModelInfo{
+		Architecture: m.modelType,
+		NumLayers:    m.model.NumLayers(),
+	}
+	switch v := m.model.(type) {
+	case *GemmaModel:
+		info.VocabSize = int(v.Cfg.VocabSize)
+		info.HiddenSize = int(v.Cfg.HiddenSize)
+		if v.Cfg.Quantization != nil {
+			info.QuantBits = v.Cfg.Quantization.Bits
+			info.QuantGroup = v.Cfg.Quantization.GroupSize
+		}
+	case *Qwen3Model:
+		info.VocabSize = int(v.Cfg.VocabSize)
+		info.HiddenSize = int(v.Cfg.HiddenSize)
+		if v.Cfg.Quantization != nil {
+			info.QuantBits = v.Cfg.Quantization.Bits
+			info.QuantGroup = v.Cfg.Quantization.GroupSize
+		}
+	}
+	return info
+}
+
 // Close releases all model weight arrays and associated resources.
 // After Close, the Model must not be used for generation.
 func (m *Model) Close() error {
diff --git a/mlx_test.go b/mlx_test.go
index 7dc1265..a00eed7 100644
--- a/mlx_test.go
+++ b/mlx_test.go
@@ -443,6 +443,40 @@ func TestLlama_Inference(t *testing.T) {
 	}
 }
 
+// --- ModelInfo tests ---
+
+func TestModelInfo(t *testing.T) {
+	modelPath := gemma3ModelPath(t)
+
+	m, err := inference.LoadModel(modelPath)
+	if err != nil {
+		t.Fatalf("LoadModel: %v", err)
+	}
+	defer func() { m.Close(); mlx.ClearCache() }()
+
+	info := m.Info()
+	if info.Architecture != "gemma3" {
+		t.Errorf("Architecture = %q, want %q", info.Architecture, "gemma3")
+	}
+	if info.VocabSize == 0 {
+		t.Error("VocabSize should be > 0")
+	}
+	if info.NumLayers == 0 {
+		t.Error("NumLayers should be > 0")
+	}
+	if info.HiddenSize == 0 {
+		t.Error("HiddenSize should be > 0")
+	}
+	// Gemma3-1B 4-bit should report quantisation.
+	if info.QuantBits == 0 {
+		t.Log("Model is not quantised (or quantisation not detected)")
+	}
+
+	t.Logf("ModelInfo: arch=%s vocab=%d layers=%d hidden=%d quant=%d-bit group=%d",
+		info.Architecture, info.VocabSize, info.NumLayers, info.HiddenSize,
+		info.QuantBits, info.QuantGroup)
+}
+
 // --- Metrics tests ---
 
 // TestGenerate_Metrics validates that metrics are populated after generation.
diff --git a/register_metal.go b/register_metal.go
index 77435d1..4f90432 100644
--- a/register_metal.go
+++ b/register_metal.go
@@ -180,5 +180,16 @@ func (a *metalAdapter) Metrics() inference.GenerateMetrics {
 }
 
 func (a *metalAdapter) ModelType() string { return a.m.ModelType() }
+func (a *metalAdapter) Info() inference.ModelInfo {
+	i := a.m.Info()
+	return inference.ModelInfo{
+		Architecture: i.Architecture,
+		VocabSize:    i.VocabSize,
+		NumLayers:    i.NumLayers,
+		HiddenSize:   i.HiddenSize,
+		QuantBits:    i.QuantBits,
+		QuantGroup:   i.QuantGroup,
+	}
+}
 func (a *metalAdapter) Err() error { return a.m.Err() }
 func (a *metalAdapter) Close() error { return a.m.Close() }
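
Reviewer note: a minimal caller-side sketch of the new Info() API, following the usage in TestModelInfo above. The module import path and the model path are placeholders, not taken from this patch.

package main

import (
	"fmt"
	"log"

	"example.com/yourmodule/inference" // hypothetical import path
)

func main() {
	// LoadModel and Info() are used exactly as in mlx_test.go; the path is a placeholder.
	m, err := inference.LoadModel("/path/to/gemma3-1b-4bit")
	if err != nil {
		log.Fatalf("LoadModel: %v", err)
	}
	defer m.Close()

	// Info() reports architecture, vocab size, layer count, hidden dimension,
	// and quantisation config (bits + group size) for the loaded model.
	info := m.Info()
	fmt.Printf("arch=%s vocab=%d layers=%d hidden=%d quant=%d-bit group=%d\n",
		info.Architecture, info.VocabSize, info.NumLayers, info.HiddenSize,
		info.QuantBits, info.QuantGroup)
}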