// Package mlx provides Apple Metal GPU inference via mlx-c bindings. // // This package implements the [inference.Backend] interface from // forge.lthn.ai/core/go-inference for Apple Silicon (M1-M4) GPUs. // Import it blank to register the "metal" backend automatically: // // import _ "forge.lthn.ai/core/go-mlx" // // Build mlx-c before use: // // go generate ./... // // # Generate text // // m, err := inference.LoadModel("/path/to/model/") // if err != nil { log.Fatal(err) } // defer m.Close() // // ctx := context.Background() // for tok := range m.Generate(ctx, "What is 2+2?", inference.WithMaxTokens(128)) { // fmt.Print(tok.Text) // } // if err := m.Err(); err != nil { log.Fatal(err) } // // # Multi-turn chat // // Chat applies the model's native template (Gemma3, Qwen3, Llama3): // // for tok := range m.Chat(ctx, []inference.Message{ // {Role: "system", Content: "You are a helpful assistant."}, // {Role: "user", Content: "Translate 'hello' to French."}, // }, inference.WithMaxTokens(64)) { // fmt.Print(tok.Text) // } // // # Batch classification // // Classify runs a single forward pass per prompt (prefill only, no decoding): // // results, err := m.Classify(ctx, []string{ // "Bonjour, comment allez-vous?", // "The quarterly report shows growth.", // }, inference.WithTemperature(0)) // for i, r := range results { // fmt.Printf("prompt %d → %q\n", i, r.Token.Text) // } // // # Batch generation // // results, err := m.BatchGenerate(ctx, []string{ // "The capital of France is", // "Water boils at", // }, inference.WithMaxTokens(32)) // for i, r := range results { // for _, tok := range r.Tokens { // fmt.Print(tok.Text) // } // fmt.Println() // } // // # Performance metrics // // After any inference call, retrieve timing and memory statistics: // // for tok := range m.Generate(ctx, prompt, inference.WithMaxTokens(128)) { // fmt.Print(tok.Text) // } // met := m.Metrics() // fmt.Printf("decode: %.0f tok/s, peak GPU: %d MB\n", // met.DecodeTokensPerSec, met.PeakMemoryBytes/1024/1024) // // # Model info // // info := m.Info() // fmt.Printf("%s %d-layer, %d-bit quantised\n", // info.Architecture, info.NumLayers, info.QuantBits) // // # Model discovery // // models, err := inference.Discover("/path/to/models/") // for _, d := range models { // fmt.Printf("%s (%s, %d-bit)\n", d.Path, d.ModelType, d.QuantBits) // } // // # Metal memory controls // // These control the Metal allocator directly, not individual models: // // mlx.SetCacheLimit(4 << 30) // 4 GB cache limit // mlx.SetMemoryLimit(32 << 30) // 32 GB hard limit // // // Between chat turns, reclaim prompt cache memory: // mlx.ClearCache() // // fmt.Printf("active: %d MB, peak: %d MB\n", // mlx.GetActiveMemory()/1024/1024, mlx.GetPeakMemory()/1024/1024) package mlx //go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release //go:generate cmake --build build --parallel //go:generate cmake --install build