feat(metal): expose model metadata via Info()
Return architecture, vocab size, layer count, hidden dimension, and quantisation config (bits + group size) for loaded models. Gemma3-1B 4-bit: arch=gemma3, vocab=262144, layers=26, hidden=1152, quant=4-bit/group64. Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
a44e9f5789
commit
ceb966b66b
3 changed files with 80 additions and 0 deletions
|
|
@ -63,6 +63,41 @@ func (m *Model) Err() error { return m.lastErr }
|
|||
// LastMetrics returns performance metrics from the last inference operation.
// The zero value is returned if no inference has been run yet.
func (m *Model) LastMetrics() Metrics { return m.lastMetrics }
|
||||
|
||||
// ModelInfo holds metadata about a loaded model.
type ModelInfo struct {
	// Architecture is the model family identifier (e.g. "gemma3").
	Architecture string
	// VocabSize is the tokenizer vocabulary size.
	VocabSize int
	// NumLayers is the number of transformer layers.
	NumLayers int
	// HiddenSize is the hidden (embedding) dimension.
	HiddenSize int
	// QuantBits is the quantization bit width; zero when the model is
	// not quantized (or quantization was not detected).
	QuantBits int
	// QuantGroup is the quantization group size; zero when the model is
	// not quantized.
	QuantGroup int
}
|
||||
|
||||
// Info returns metadata about the loaded model.
|
||||
func (m *Model) Info() ModelInfo {
|
||||
info := ModelInfo{
|
||||
Architecture: m.modelType,
|
||||
NumLayers: m.model.NumLayers(),
|
||||
}
|
||||
switch v := m.model.(type) {
|
||||
case *GemmaModel:
|
||||
info.VocabSize = int(v.Cfg.VocabSize)
|
||||
info.HiddenSize = int(v.Cfg.HiddenSize)
|
||||
if v.Cfg.Quantization != nil {
|
||||
info.QuantBits = v.Cfg.Quantization.Bits
|
||||
info.QuantGroup = v.Cfg.Quantization.GroupSize
|
||||
}
|
||||
case *Qwen3Model:
|
||||
info.VocabSize = int(v.Cfg.VocabSize)
|
||||
info.HiddenSize = int(v.Cfg.HiddenSize)
|
||||
if v.Cfg.Quantization != nil {
|
||||
info.QuantBits = v.Cfg.Quantization.Bits
|
||||
info.QuantGroup = v.Cfg.Quantization.GroupSize
|
||||
}
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// Close releases all model weight arrays and associated resources.
|
||||
// After Close, the Model must not be used for generation.
|
||||
func (m *Model) Close() error {
|
||||
|
|
|
|||
34
mlx_test.go
34
mlx_test.go
|
|
@ -443,6 +443,40 @@ func TestLlama_Inference(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// --- ModelInfo tests ---
|
||||
|
||||
func TestModelInfo(t *testing.T) {
|
||||
modelPath := gemma3ModelPath(t)
|
||||
|
||||
m, err := inference.LoadModel(modelPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModel: %v", err)
|
||||
}
|
||||
defer func() { m.Close(); mlx.ClearCache() }()
|
||||
|
||||
info := m.Info()
|
||||
if info.Architecture != "gemma3" {
|
||||
t.Errorf("Architecture = %q, want %q", info.Architecture, "gemma3")
|
||||
}
|
||||
if info.VocabSize == 0 {
|
||||
t.Error("VocabSize should be > 0")
|
||||
}
|
||||
if info.NumLayers == 0 {
|
||||
t.Error("NumLayers should be > 0")
|
||||
}
|
||||
if info.HiddenSize == 0 {
|
||||
t.Error("HiddenSize should be > 0")
|
||||
}
|
||||
// Gemma3-1B 4-bit should report quantisation.
|
||||
if info.QuantBits == 0 {
|
||||
t.Log("Model is not quantised (or quantisation not detected)")
|
||||
}
|
||||
|
||||
t.Logf("ModelInfo: arch=%s vocab=%d layers=%d hidden=%d quant=%d-bit group=%d",
|
||||
info.Architecture, info.VocabSize, info.NumLayers, info.HiddenSize,
|
||||
info.QuantBits, info.QuantGroup)
|
||||
}
|
||||
|
||||
// --- Metrics tests ---
|
||||
|
||||
// TestGenerate_Metrics validates that metrics are populated after generation.
|
||||
|
|
|
|||
|
|
@ -180,5 +180,16 @@ func (a *metalAdapter) Metrics() inference.GenerateMetrics {
|
|||
}
|
||||
|
||||
func (a *metalAdapter) ModelType() string { return a.m.ModelType() }
|
||||
func (a *metalAdapter) Info() inference.ModelInfo {
|
||||
i := a.m.Info()
|
||||
return inference.ModelInfo{
|
||||
Architecture: i.Architecture,
|
||||
VocabSize: i.VocabSize,
|
||||
NumLayers: i.NumLayers,
|
||||
HiddenSize: i.HiddenSize,
|
||||
QuantBits: i.QuantBits,
|
||||
QuantGroup: i.QuantGroup,
|
||||
}
|
||||
}
|
||||
func (a *metalAdapter) Err() error { return a.m.Err() }
|
||||
func (a *metalAdapter) Close() error { return a.m.Close() }
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue