Cover generate, chat, classify, batch generate, metrics, model info, discovery, and Metal memory controls. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
101 lines
2.9 KiB
Go
101 lines
2.9 KiB
Go
// Package mlx provides Apple Metal GPU inference via mlx-c bindings.
//
// This package implements the [inference.Backend] interface from
// forge.lthn.ai/core/go-inference for Apple Silicon (M1-M4) GPUs.
// Import it for its side effects (a blank import) to register the "metal"
// backend automatically:
//
//	import _ "forge.lthn.ai/core/go-mlx"
//
// Build the mlx-c library before use; this runs the package's go:generate
// directives, which configure, build and install it with CMake:
//
//	go generate ./...
//
// # Generate text
//
// Range over the tokens returned by Generate, then check m.Err after the loop:
//
//	m, err := inference.LoadModel("/path/to/model/")
//	if err != nil { log.Fatal(err) }
//	defer m.Close()
//
//	ctx := context.Background()
//	for tok := range m.Generate(ctx, "What is 2+2?", inference.WithMaxTokens(128)) {
//		fmt.Print(tok.Text)
//	}
//	if err := m.Err(); err != nil { log.Fatal(err) }
//
// # Multi-turn chat
//
// Chat formats the conversation with the model's native chat template
// (Gemma3, Qwen3, Llama3):
//
//	for tok := range m.Chat(ctx, []inference.Message{
//		{Role: "system", Content: "You are a helpful assistant."},
//		{Role: "user", Content: "Translate 'hello' to French."},
//	}, inference.WithMaxTokens(64)) {
//		fmt.Print(tok.Text)
//	}
//
// # Batch classification
//
// Classify runs a single forward pass per prompt (prefill only, no decoding):
//
//	results, err := m.Classify(ctx, []string{
//		"Bonjour, comment allez-vous?",
//		"The quarterly report shows growth.",
//	}, inference.WithTemperature(0))
//	if err != nil { log.Fatal(err) }
//	for i, r := range results {
//		fmt.Printf("prompt %d → %q\n", i, r.Token.Text)
//	}
//
// # Batch generation
//
// BatchGenerate completes several prompts in one call; each result carries
// the tokens generated for its prompt:
//
//	results, err := m.BatchGenerate(ctx, []string{
//		"The capital of France is",
//		"Water boils at",
//	}, inference.WithMaxTokens(32))
//	if err != nil { log.Fatal(err) }
//	for _, r := range results {
//		for _, tok := range r.Tokens {
//			fmt.Print(tok.Text)
//		}
//		fmt.Println()
//	}
//
// # Performance metrics
//
// After any inference call, m.Metrics returns timing and memory statistics:
//
//	for tok := range m.Generate(ctx, prompt, inference.WithMaxTokens(128)) {
//		fmt.Print(tok.Text)
//	}
//	met := m.Metrics()
//	fmt.Printf("decode: %.0f tok/s, peak GPU: %d MB\n",
//		met.DecodeTokensPerSec, met.PeakMemoryBytes/1024/1024)
//
// # Model info
//
// Info describes the loaded model's architecture, depth and quantisation:
//
//	info := m.Info()
//	fmt.Printf("%s %d-layer, %d-bit quantised\n",
//		info.Architecture, info.NumLayers, info.QuantBits)
//
// # Model discovery
//
// Discover lists the models found under a directory:
//
//	models, err := inference.Discover("/path/to/models/")
//	if err != nil { log.Fatal(err) }
//	for _, d := range models {
//		fmt.Printf("%s (%s, %d-bit)\n", d.Path, d.ModelType, d.QuantBits)
//	}
//
// # Metal memory controls
//
// These functions control the Metal allocator directly, not individual models:
//
//	mlx.SetCacheLimit(4 << 30)   // 4 GiB cache limit
//	mlx.SetMemoryLimit(32 << 30) // 32 GiB hard limit
//
//	// Between chat turns, reclaim prompt cache memory:
//	mlx.ClearCache()
//
//	fmt.Printf("active: %d MB, peak: %d MB\n",
//		mlx.GetActiveMemory()/1024/1024, mlx.GetPeakMemory()/1024/1024)
package mlx

// Running `go generate` in this package builds mlx-c from source with CMake.
// The directives run in order from the package directory: configure a Release
// build tree in ./build, compile it in parallel, then install the result into
// ./dist (the CMAKE_INSTALL_PREFIX given at configure time).
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
//go:generate cmake --build build --parallel
//go:generate cmake --install build
|