feat(api): migrate to go-inference shared interfaces

Replace local TextModel, Backend, Token, Message, and option types with
forge.lthn.ai/core/go-inference. go-mlx is now a pure backend that
registers "metal" into the shared inference registry via init().

Deleted: textmodel.go, options.go, backend.go
Updated: register_metal.go (implements inference.Backend with Available()),
mlx_test.go (uses inference.* types, 4 new tests), go.mod,
internal/metal/generate.go (added RepeatPenalty)

159 tests passing (148 internal/metal + 11 root).

Co-Authored-By: Virgil <virgil@lethean.io>
parent 95d92fffff · commit bff97ccf19
11 changed files with 196 additions and 249 deletions

CLAUDE.md (35 changed lines)

@@ -4,7 +4,7 @@
 Native Apple Metal GPU inference via mlx-c bindings. Module: `forge.lthn.ai/core/go-mlx`
 
-Pure Go + CGO package that wraps Apple's [MLX framework](https://github.com/ml-explore/mlx) through the [mlx-c](https://github.com/ml-explore/mlx-c) C API. Runs LLM inference on Apple Silicon GPUs (M1-M4) using Metal compute shaders.
+Implements the `inference.Backend` interface from [`forge.lthn.ai/core/go-inference`](https://forge.lthn.ai/core/go-inference) for Apple Silicon (M1-M4) GPUs using Metal compute shaders via the [mlx-c](https://github.com/ml-explore/mlx-c) C API.
 
 ## Platform
 
@@ -41,12 +41,9 @@ CMake installs to `dist/` inside the package directory. The `#cgo` directives in
 go-mlx/
 ├── Public API (compiles on all platforms)
 │   ├── mlx.go — Package doc + go:generate CMake directives
-│   ├── textmodel.go — TextModel interface, Token, Message types
-│   ├── options.go — GenerateOption, LoadOption functional options
-│   ├── backend.go — Backend interface, Register/Get/Default/LoadModel
-│   ├── register_metal.go — //go:build darwin && arm64 — auto-registers metal
+│   ├── register_metal.go — //go:build darwin && arm64 — auto-registers metal backend
 │   ├── mlx_stub.go — //go:build !darwin || !arm64 — MetalAvailable() false
-│   └── mlx_test.go — Integration tests (public API)
+│   └── mlx_test.go — Integration tests (public API via go-inference)
 │
 ├── internal/metal/ — All CGO code (darwin/arm64 only)
 │   ├── metal.go — Init, Materialize, error handler, stream
@@ -81,46 +78,50 @@ go-mlx/
 └── docs/plans/ — Design and implementation docs
 ```
 
-## Public API
+## Usage
 
 ```go
-import "forge.lthn.ai/core/go-mlx"
+import (
+	"forge.lthn.ai/core/go-inference"
+	_ "forge.lthn.ai/core/go-mlx" // register Metal backend
+)
 
-// Load and generate
-m, err := mlx.LoadModel("/path/to/model/")
+// Load and generate — types come from go-inference
+m, err := inference.LoadModel("/path/to/model/")
 if err != nil { log.Fatal(err) }
 defer m.Close()
 
 ctx := context.Background()
-for tok := range m.Generate(ctx, "What is 2+2?", mlx.WithMaxTokens(128)) {
+for tok := range m.Generate(ctx, "What is 2+2?", inference.WithMaxTokens(128)) {
 	fmt.Print(tok.Text)
 }
 if err := m.Err(); err != nil { log.Fatal(err) }
 
 // Chat with template formatting
-for tok := range m.Chat(ctx, []mlx.Message{
+for tok := range m.Chat(ctx, []inference.Message{
 	{Role: "user", Content: "Hello"},
-}, mlx.WithMaxTokens(64)) {
+}, inference.WithMaxTokens(64)) {
 	fmt.Print(tok.Text)
 }
 
-// Memory controls
+// Metal-specific memory controls (from mlx package directly)
 mlx.SetCacheLimit(4 * 1024 * 1024 * 1024) // 4GB
 mlx.ClearCache()
 ```
 
 ## Dependencies
 
+- **[`forge.lthn.ai/core/go-inference`](https://forge.lthn.ai/core/go-inference)** — shared Backend/TextModel interfaces
 - **mlx-c v0.4.1** (fetched by CMake at build time)
 - **Apple frameworks**: Foundation, Metal, Accelerate
 - **Go stdlib only** — no external Go dependencies
 
 ## Downstream Consumers
 
-- `forge.lthn.ai/core/go-ml` — `backend_mlx.go` uses `mlx.LoadModel()` + `m.Generate()`
+- `forge.lthn.ai/core/go-ml` — imports go-inference + go-mlx for Metal backend
 - `forge.lthn.ai/core/go-i18n` — Phase 2a needs Gemma3-1B inference for domain classification
+- `forge.lthn.ai/core/go-rocm` — sibling backend for AMD GPUs, same go-inference interfaces
 
-The old API (`Array`, `MatMul`, `model.LoadModel`, etc.) is no longer public — all CGO is in `internal/metal/`. Consumers use the clean `TextModel` interface.
+Import `_ "forge.lthn.ai/core/go-mlx"` to register the Metal backend. All types (`TextModel`, `Token`, `Backend`, options) come from `go-inference`.
 
 ## Model Format
 
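For reference, explicit backend selection through the shared registry looks like this; a sketch assembled from the calls exercised in mlx_test.go further down, where the exact go-inference signatures are assumed and the model path is a placeholder:

```go
package main

import (
	"fmt"
	"log"

	"forge.lthn.ai/core/go-inference"
	_ "forge.lthn.ai/core/go-mlx" // side-effect import: registers "metal"
)

func main() {
	// Inspect the registry before loading anything.
	if b, ok := inference.Get("metal"); ok {
		fmt.Println("backend:", b.Name(), "available:", b.Available())
	}

	// WithBackend routes the load through a named backend;
	// without it, LoadModel falls back to the default backend.
	m, err := inference.LoadModel("/path/to/model/", inference.WithBackend("metal"))
	if err != nil {
		log.Fatal(err)
	}
	defer m.Close()
}
```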

FINDINGS.md (62 changed lines)

@@ -233,3 +233,65 @@ for tok := range m.Generate(ctx, sentence, mlx.WithMaxTokens(32)) { ... }
 - 148 existing tests moved to `internal/metal/` — all pass
 - 7 new integration tests for public API — all pass
 - Total: 155 tests passing
+
+---
+
+## 2026-02-19: Migration to go-inference Shared Interfaces
+
+### What changed
+
+go-mlx no longer defines its own `TextModel`, `Backend`, `Token`, `Message`, `GenerateConfig`, `GenerateOption`, `LoadConfig`, `LoadOption` types. These are now provided by `forge.lthn.ai/core/go-inference`, a zero-dependency shared interface package.
+
+### Files removed
+
+- `textmodel.go` — `Token`, `Message`, `TextModel` now in go-inference
+- `options.go` — `GenerateConfig`, `GenerateOption`, `LoadConfig`, `LoadOption` now in go-inference
+- `backend.go` — `Backend`, `Register`, `Get`, `Default`, `LoadModel` now in go-inference
+
+### Files updated
+
+- `register_metal.go` — implements `inference.Backend` (added `Available() bool`), adapts `inference.Token`/`inference.Message`
+- `mlx_test.go` — all tests use `inference.*` types; added `TestListBackends`, `TestLoadOptions`, `TestLoadOptionsDefaults`
+- `mlx.go` — package doc updated to show go-inference import pattern
+- `go.mod` — added `forge.lthn.ai/core/go-inference` dependency (replace directive for local dev)
+- `internal/metal/generate.go` — `GenerateConfig` gained `RepeatPenalty float32`
+
+### What go-mlx still exports
+
+- `MetalAvailable() bool` — convenience check
+- `SetCacheLimit`, `SetMemoryLimit`, `GetActiveMemory`, `GetPeakMemory`, `ClearCache` — Metal-specific memory controls
+- Side-effect import (`_ "forge.lthn.ai/core/go-mlx"`) registers the `"metal"` backend into go-inference's registry
+
+### Consumer migration
+
+Before:
+```go
+import "forge.lthn.ai/core/go-mlx"
+m, _ := mlx.LoadModel(path)
+for tok := range m.Generate(ctx, prompt, mlx.WithMaxTokens(128)) { ... }
+```
+
+After:
+```go
+import (
+	"forge.lthn.ai/core/go-inference"
+	_ "forge.lthn.ai/core/go-mlx" // register Metal backend
+)
+m, _ := inference.LoadModel(path)
+for tok := range m.Generate(ctx, prompt, inference.WithMaxTokens(128)) { ... }
+```
+
+### New go-inference features available
+
+- `inference.List()` — returns all registered backend names
+- `inference.Backend.Available()` — hardware availability check
+- `inference.WithRepeatPenalty(p)` — repetition penalty option
+- `inference.WithContextLen(n)` — context window size
+- `inference.WithGPULayers(n)` — GPU layer offload control (-1 = all)
+- `inference.LoadConfig.GPULayers` defaults to -1 (full GPU offload)
+
+### Test results
+
+- 148 internal/metal tests — all pass
+- 11 root integration tests — all pass
+- Total: 159 tests passing
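Put together, the new load and generate options compose in a single call. A sketch using only the option names listed above, with `ctx` and `prompt` as in the Usage example:

```go
// Sketch: new go-inference options combined with the existing ones.
m, err := inference.LoadModel("/path/to/model/",
	inference.WithBackend("metal"), // explicit backend selection
	inference.WithContextLen(4096), // context window size
	inference.WithGPULayers(-1),    // -1 = full GPU offload (the default)
)
if err != nil {
	log.Fatal(err)
}
defer m.Close()

for tok := range m.Generate(ctx, prompt,
	inference.WithMaxTokens(128),
	inference.WithRepeatPenalty(1.1), // mild repetition penalty
) {
	fmt.Print(tok.Text)
}
```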

TODO.md (7 changed lines)

@@ -57,11 +57,16 @@ Implementation plan: `docs/plans/2026-02-19-backend-abstraction-plan.md`
 
 ---
 
+## go-inference Integration — ✅ COMPLETE (19 Feb 2026)
+
+All types (`TextModel`, `Backend`, `Token`, `Message`, options) moved to shared `forge.lthn.ai/core/go-inference` package. go-mlx is now a pure backend implementation — import `_ "forge.lthn.ai/core/go-mlx"` to register the `"metal"` backend. See FINDINGS.md for migration details.
+
 ## Upstream Dependencies
 
 - **go-i18n Phase 2a** is blocked on this package providing working Gemma3-1B inference
-- **go-ml/backend_mlx.go** needs updating to use new API: `mlx.LoadModel()` + `m.Generate()`. Old imports (`mlx.Array`, `model.LoadModel`, sub-packages) are now internal.
+- **go-ml/backend_mlx.go** needs updating to use `inference.LoadModel()` + `m.Generate()` (types from go-inference, `_ "go-mlx"` for Metal registration)
 - **go-ai** has a `replace` directive pointing at `../go-mlx`. No code changes needed in go-ai itself.
+- **go-rocm** — sibling backend for AMD GPUs, implements same `inference.Backend` interface
 - **LEM Lab** uses `MLXBackend` via go-ml. Migration transparent once go-ml updates.
 
 ## Functional Options Convention
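For the go-ml item specifically, the migrated `backend_mlx.go` would presumably reduce to something along these lines (a hypothetical sketch, not the actual go-ml code; the function and package names are illustrative):

```go
package ml

import (
	"context"
	"strings"

	"forge.lthn.ai/core/go-inference"
	_ "forge.lthn.ai/core/go-mlx" // registers the "metal" backend
)

// generateMLX is a hypothetical helper: load via the shared registry,
// stream tokens, and return the accumulated text.
func generateMLX(ctx context.Context, modelPath, prompt string) (string, error) {
	m, err := inference.LoadModel(modelPath)
	if err != nil {
		return "", err
	}
	defer m.Close()

	var sb strings.Builder
	for tok := range m.Generate(ctx, prompt, inference.WithMaxTokens(256)) {
		sb.WriteString(tok.Text)
	}
	return sb.String(), m.Err() // Err() distinguishes EOS from failures
}
```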

backend.go (66 changed lines — file deleted)

@@ -1,66 +0,0 @@
-package mlx
-
-import (
-	"fmt"
-	"sync"
-)
-
-// Backend is a named inference engine that can load models.
-type Backend interface {
-	// Name returns the backend identifier (e.g. "metal", "mlx_lm").
-	Name() string
-
-	// LoadModel loads a model from the given path.
-	LoadModel(path string, opts ...LoadOption) (TextModel, error)
-}
-
-var (
-	backendsMu sync.RWMutex
-	backends   = map[string]Backend{}
-)
-
-// Register adds a backend to the registry.
-func Register(b Backend) {
-	backendsMu.Lock()
-	defer backendsMu.Unlock()
-	backends[b.Name()] = b
-}
-
-// Get returns a registered backend by name.
-func Get(name string) (Backend, bool) {
-	backendsMu.RLock()
-	defer backendsMu.RUnlock()
-	b, ok := backends[name]
-	return b, ok
-}
-
-// Default returns the first available backend.
-// Prefers "metal" if registered.
-func Default() (Backend, error) {
-	backendsMu.RLock()
-	defer backendsMu.RUnlock()
-	if b, ok := backends["metal"]; ok {
-		return b, nil
-	}
-	for _, b := range backends {
-		return b, nil
-	}
-	return nil, fmt.Errorf("mlx: no backends registered")
-}
-
-// LoadModel loads a model using the specified or default backend.
-func LoadModel(path string, opts ...LoadOption) (TextModel, error) {
-	cfg := ApplyLoadOpts(opts)
-	if cfg.Backend != "" {
-		b, ok := Get(cfg.Backend)
-		if !ok {
-			return nil, fmt.Errorf("mlx: backend %q not registered", cfg.Backend)
-		}
-		return b.LoadModel(path, opts...)
-	}
-	b, err := Default()
-	if err != nil {
-		return nil, err
-	}
-	return b.LoadModel(path, opts...)
-}
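This registry was deleted rather than rewritten: per the commit message it moved wholesale into go-inference, gaining an `Available()` method along the way. Inferred from the deleted code above and the adapter in register_metal.go below (not the actual go-inference source), the shared interface presumably looks like:

```go
package inference

// Placeholder stand-ins for types moved in this same commit (see the
// textmodel.go and options.go diffs below for the real definitions).
type (
	TextModel  interface{ Close() error }
	LoadConfig struct{ Backend string }
	LoadOption func(*LoadConfig)
)

// Backend is the deleted interface above plus the hardware check that
// register_metal.go now implements.
type Backend interface {
	Name() string    // backend identifier, e.g. "metal", "rocm"
	Available() bool // can this backend run on the current hardware?
	LoadModel(path string, opts ...LoadOption) (TextModel, error)
}
```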

go.mod (4 changed lines)

@@ -1,3 +1,7 @@
 module forge.lthn.ai/core/go-mlx
 
 go 1.25.5
+
+require forge.lthn.ai/core/go-inference v0.0.0
+
+replace forge.lthn.ai/core/go-inference => ../go-inference

internal/metal/generate.go

@@ -21,11 +21,12 @@ type ChatMessage struct {
 
 // GenerateConfig holds generation parameters.
 type GenerateConfig struct {
-	MaxTokens   int
-	Temperature float32
-	TopK        int
-	TopP        float32
-	StopTokens  []int32
+	MaxTokens     int
+	Temperature   float32
+	TopK          int
+	TopP          float32
+	StopTokens    []int32
+	RepeatPenalty float32
 }
 
 // Model wraps a loaded transformer model for text generation.
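The diff only threads the new field through to the sampler; for context, a repeat penalty is conventionally applied by rescaling the logits of already-emitted tokens before sampling (CTRL-style). A generic sketch, not the actual internal/metal implementation:

```go
// applyRepeatPenalty dampens tokens that have already been generated:
// positive logits are divided by the penalty, negative ones multiplied,
// so penalty > 1 always makes a repeat less likely.
func applyRepeatPenalty(logits []float32, seen []int32, penalty float32) {
	if penalty == 0 || penalty == 1 {
		return // 0 or 1 means the penalty is disabled
	}
	for _, id := range seen {
		if l := logits[id]; l > 0 {
			logits[id] = l / penalty
		} else {
			logits[id] = l * penalty
		}
	}
}
```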

mlx.go (12 changed lines)

@@ -1,4 +1,7 @@
-// Package mlx provides Go bindings for Apple's MLX framework.
+// Package mlx provides Apple Metal GPU inference via mlx-c bindings.
+//
+// This package implements the [inference.Backend] interface from
+// forge.lthn.ai/core/go-inference for Apple Silicon (M1-M4) GPUs.
 //
 // Build mlx-c before use:
 //
@@ -6,11 +9,14 @@
 //
 // Load a model and generate text:
 //
-//	m, err := mlx.LoadModel("/path/to/model/")
+//	import "forge.lthn.ai/core/go-inference"
+//	import _ "forge.lthn.ai/core/go-mlx" // register Metal backend
+//
+//	m, err := inference.LoadModel("/path/to/model/")
 //	if err != nil { log.Fatal(err) }
 //	defer m.Close()
 //
-//	for tok := range m.Generate(ctx, "What is 2+2?", mlx.WithMaxTokens(128)) {
+//	for tok := range m.Generate(ctx, "What is 2+2?", inference.WithMaxTokens(128)) {
 //		fmt.Print(tok.Text)
 //	}
 package mlx

mlx_test.go (81 changed lines)

@@ -7,17 +7,23 @@ import (
 	"os"
 	"testing"
 
-	"forge.lthn.ai/core/go-mlx"
+	"forge.lthn.ai/core/go-inference"
+	_ "forge.lthn.ai/core/go-mlx"
 )
 
 func TestMetalAvailable(t *testing.T) {
-	if !mlx.MetalAvailable() {
-		t.Fatal("MetalAvailable() = false on darwin/arm64")
+	// Metal backend should be registered via init()
+	b, ok := inference.Get("metal")
+	if !ok {
+		t.Fatal("metal backend not registered")
+	}
+	if !b.Available() {
+		t.Fatal("metal backend reports not available on darwin/arm64")
 	}
 }
 
 func TestDefaultBackend(t *testing.T) {
-	b, err := mlx.Default()
+	b, err := inference.Default()
 	if err != nil {
 		t.Fatalf("Default() error: %v", err)
 	}
@@ -27,7 +33,7 @@ func TestDefaultBackend(t *testing.T) {
 }
 
 func TestGetBackend(t *testing.T) {
-	b, ok := mlx.Get("metal")
+	b, ok := inference.Get("metal")
 	if !ok {
 		t.Fatal("Get(\"metal\") returned false")
 	}
@@ -35,33 +41,47 @@ func TestGetBackend(t *testing.T) {
 		t.Errorf("Name() = %q, want %q", b.Name(), "metal")
 	}
 
-	_, ok = mlx.Get("nonexistent")
+	_, ok = inference.Get("nonexistent")
 	if ok {
 		t.Error("Get(\"nonexistent\") should return false")
 	}
 }
 
+func TestListBackends(t *testing.T) {
+	names := inference.List()
+	found := false
+	for _, name := range names {
+		if name == "metal" {
+			found = true
+		}
+	}
+	if !found {
+		t.Errorf("List() = %v, want \"metal\" included", names)
+	}
+}
+
 func TestLoadModel_NoBackend(t *testing.T) {
-	_, err := mlx.LoadModel("/nonexistent/path")
+	_, err := inference.LoadModel("/nonexistent/path")
 	if err == nil {
 		t.Error("expected error for nonexistent model path")
 	}
 }
 
 func TestLoadModel_WithBackend(t *testing.T) {
-	_, err := mlx.LoadModel("/nonexistent/path", mlx.WithBackend("nonexistent"))
+	_, err := inference.LoadModel("/nonexistent/path", inference.WithBackend("nonexistent"))
 	if err == nil {
 		t.Error("expected error for nonexistent backend")
 	}
 }
 
 func TestOptions(t *testing.T) {
-	cfg := mlx.ApplyGenerateOpts([]mlx.GenerateOption{
-		mlx.WithMaxTokens(64),
-		mlx.WithTemperature(0.7),
-		mlx.WithTopK(40),
-		mlx.WithTopP(0.9),
-		mlx.WithStopTokens(1, 2, 3),
+	cfg := inference.ApplyGenerateOpts([]inference.GenerateOption{
+		inference.WithMaxTokens(64),
+		inference.WithTemperature(0.7),
+		inference.WithTopK(40),
+		inference.WithTopP(0.9),
+		inference.WithStopTokens(1, 2, 3),
+		inference.WithRepeatPenalty(1.1),
 	})
 	if cfg.MaxTokens != 64 {
 		t.Errorf("MaxTokens = %d, want 64", cfg.MaxTokens)
@@ -78,10 +98,13 @@ func TestOptions(t *testing.T) {
 	if len(cfg.StopTokens) != 3 {
 		t.Errorf("StopTokens len = %d, want 3", len(cfg.StopTokens))
 	}
+	if cfg.RepeatPenalty != 1.1 {
+		t.Errorf("RepeatPenalty = %f, want 1.1", cfg.RepeatPenalty)
+	}
 }
 
 func TestDefaults(t *testing.T) {
-	cfg := mlx.DefaultGenerateConfig()
+	cfg := inference.DefaultGenerateConfig()
 	if cfg.MaxTokens != 256 {
 		t.Errorf("default MaxTokens = %d, want 256", cfg.MaxTokens)
 	}
@@ -90,6 +113,30 @@ func TestDefaults(t *testing.T) {
 	}
 }
 
+func TestLoadOptions(t *testing.T) {
+	cfg := inference.ApplyLoadOpts([]inference.LoadOption{
+		inference.WithBackend("metal"),
+		inference.WithContextLen(4096),
+		inference.WithGPULayers(32),
+	})
+	if cfg.Backend != "metal" {
+		t.Errorf("Backend = %q, want %q", cfg.Backend, "metal")
+	}
+	if cfg.ContextLen != 4096 {
+		t.Errorf("ContextLen = %d, want 4096", cfg.ContextLen)
+	}
+	if cfg.GPULayers != 32 {
+		t.Errorf("GPULayers = %d, want 32", cfg.GPULayers)
+	}
+}
+
+func TestLoadOptionsDefaults(t *testing.T) {
+	cfg := inference.ApplyLoadOpts(nil)
+	if cfg.GPULayers != -1 {
+		t.Errorf("default GPULayers = %d, want -1", cfg.GPULayers)
+	}
+}
+
 // TestLoadModel_Generate requires a model on disk. Skipped in CI.
 func TestLoadModel_Generate(t *testing.T) {
 	const modelPath = "/Volumes/Data/lem/safetensors/gemma-3/"
@@ -97,7 +144,7 @@ func TestLoadModel_Generate(t *testing.T) {
 		t.Skip("model not available at", modelPath)
 	}
 
-	m, err := mlx.LoadModel(modelPath)
+	m, err := inference.LoadModel(modelPath)
 	if err != nil {
 		t.Fatalf("LoadModel: %v", err)
 	}
@@ -109,7 +156,7 @@ func TestLoadModel_Generate(t *testing.T) {
 
 	ctx := context.Background()
 	var count int
-	for tok := range m.Generate(ctx, "What is 2+2?", mlx.WithMaxTokens(16)) {
+	for tok := range m.Generate(ctx, "What is 2+2?", inference.WithMaxTokens(16)) {
 		count++
 		t.Logf("[%d] %q", tok.ID, tok.Text)
 	}

options.go (77 changed lines — file deleted)

@@ -1,77 +0,0 @@
-package mlx
-
-// GenerateConfig holds generation parameters.
-type GenerateConfig struct {
-	MaxTokens   int
-	Temperature float32
-	TopK        int
-	TopP        float32
-	StopTokens  []int32
-}
-
-// DefaultGenerateConfig returns sensible defaults.
-func DefaultGenerateConfig() GenerateConfig {
-	return GenerateConfig{
-		MaxTokens:   256,
-		Temperature: 0.0,
-	}
-}
-
-// GenerateOption configures text generation.
-type GenerateOption func(*GenerateConfig)
-
-// WithMaxTokens sets the maximum number of tokens to generate.
-func WithMaxTokens(n int) GenerateOption {
-	return func(c *GenerateConfig) { c.MaxTokens = n }
-}
-
-// WithTemperature sets the sampling temperature. 0 = greedy.
-func WithTemperature(t float32) GenerateOption {
-	return func(c *GenerateConfig) { c.Temperature = t }
-}
-
-// WithTopK sets top-k sampling. 0 = disabled.
-func WithTopK(k int) GenerateOption {
-	return func(c *GenerateConfig) { c.TopK = k }
-}
-
-// WithTopP sets nucleus sampling threshold. 0 = disabled.
-func WithTopP(p float32) GenerateOption {
-	return func(c *GenerateConfig) { c.TopP = p }
-}
-
-// WithStopTokens sets token IDs that stop generation.
-func WithStopTokens(ids ...int32) GenerateOption {
-	return func(c *GenerateConfig) { c.StopTokens = ids }
-}
-
-// LoadConfig holds model loading parameters.
-type LoadConfig struct {
-	Backend string // "metal" (default), "mlx_lm"
-}
-
-// LoadOption configures model loading.
-type LoadOption func(*LoadConfig)
-
-// WithBackend selects a specific inference backend by name.
-func WithBackend(name string) LoadOption {
-	return func(c *LoadConfig) { c.Backend = name }
-}
-
-// ApplyGenerateOpts builds a GenerateConfig from options.
-func ApplyGenerateOpts(opts []GenerateOption) GenerateConfig {
-	cfg := DefaultGenerateConfig()
-	for _, o := range opts {
-		o(&cfg)
-	}
-	return cfg
-}
-
-// ApplyLoadOpts builds a LoadConfig from options.
-func ApplyLoadOpts(opts []LoadOption) LoadConfig {
-	var cfg LoadConfig
-	for _, o := range opts {
-		o(&cfg)
-	}
-	return cfg
-}

register_metal.go

@@ -6,11 +6,12 @@ import (
 	"context"
 	"iter"
 
+	"forge.lthn.ai/core/go-inference"
 	"forge.lthn.ai/core/go-mlx/internal/metal"
 )
 
 func init() {
-	Register(&metalBackend{})
+	inference.Register(&metalBackend{})
 }
 
 // MetalAvailable reports whether native Metal inference is available.
@@ -34,12 +35,13 @@ func GetPeakMemory() uint64 { return metal.GetPeakMemory() }
 // ClearCache clears the Metal memory cache.
 func ClearCache() { metal.ClearCache() }
 
-// metalBackend implements Backend for native Metal inference.
+// metalBackend implements inference.Backend for native Metal inference.
 type metalBackend struct{}
 
-func (b *metalBackend) Name() string { return "metal" }
+func (b *metalBackend) Name() string    { return "metal" }
+func (b *metalBackend) Available() bool { return true }
 
-func (b *metalBackend) LoadModel(path string, opts ...LoadOption) (TextModel, error) {
+func (b *metalBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
 	m, err := metal.LoadAndInit(path)
 	if err != nil {
 		return nil, err
@@ -47,46 +49,48 @@ func (b *metalBackend) LoadModel(path string, opts ...LoadOption) (TextModel, er
 	return &metalAdapter{m: m}, nil
 }
 
-// metalAdapter wraps metal.Model to implement TextModel.
+// metalAdapter wraps metal.Model to implement inference.TextModel.
 type metalAdapter struct {
 	m *metal.Model
 }
 
-func (a *metalAdapter) Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token] {
-	cfg := ApplyGenerateOpts(opts)
+func (a *metalAdapter) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
+	cfg := inference.ApplyGenerateOpts(opts)
 	mcfg := metal.GenerateConfig{
-		MaxTokens:   cfg.MaxTokens,
-		Temperature: cfg.Temperature,
-		TopK:        cfg.TopK,
-		TopP:        cfg.TopP,
-		StopTokens:  cfg.StopTokens,
+		MaxTokens:     cfg.MaxTokens,
+		Temperature:   cfg.Temperature,
+		TopK:          cfg.TopK,
+		TopP:          cfg.TopP,
+		StopTokens:    cfg.StopTokens,
+		RepeatPenalty: cfg.RepeatPenalty,
 	}
-	return func(yield func(Token) bool) {
+	return func(yield func(inference.Token) bool) {
 		for tok := range a.m.Generate(ctx, prompt, mcfg) {
-			if !yield(Token{ID: tok.ID, Text: tok.Text}) {
+			if !yield(inference.Token{ID: tok.ID, Text: tok.Text}) {
 				return
 			}
 		}
 	}
 }
 
-func (a *metalAdapter) Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token] {
-	cfg := ApplyGenerateOpts(opts)
+func (a *metalAdapter) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
+	cfg := inference.ApplyGenerateOpts(opts)
 	mcfg := metal.GenerateConfig{
-		MaxTokens:   cfg.MaxTokens,
-		Temperature: cfg.Temperature,
-		TopK:        cfg.TopK,
-		TopP:        cfg.TopP,
-		StopTokens:  cfg.StopTokens,
+		MaxTokens:     cfg.MaxTokens,
+		Temperature:   cfg.Temperature,
+		TopK:          cfg.TopK,
+		TopP:          cfg.TopP,
+		StopTokens:    cfg.StopTokens,
+		RepeatPenalty: cfg.RepeatPenalty,
 	}
 	// Convert messages
 	mmsgs := make([]metal.ChatMessage, len(messages))
 	for i, msg := range messages {
 		mmsgs[i] = metal.ChatMessage{Role: msg.Role, Content: msg.Content}
 	}
-	return func(yield func(Token) bool) {
+	return func(yield func(inference.Token) bool) {
 		for tok := range a.m.Chat(ctx, mmsgs, mcfg) {
-			if !yield(Token{ID: tok.ID, Text: tok.Text}) {
+			if !yield(inference.Token{ID: tok.ID, Text: tok.Text}) {

textmodel.go (40 changed lines — file deleted)

@@ -1,40 +0,0 @@
-package mlx
-
-import (
-	"context"
-	"iter"
-)
-
-// Token represents a single generated token for streaming.
-type Token struct {
-	ID   int32
-	Text string
-}
-
-// Message represents a chat turn for Chat().
-type Message struct {
-	Role    string // "user", "assistant", "system"
-	Content string
-}
-
-// TextModel generates text from a loaded model.
-type TextModel interface {
-	// Generate streams tokens for the given prompt.
-	// Respects ctx cancellation (HTTP handlers, timeouts, graceful shutdown).
-	Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token]
-
-	// Chat formats messages using the model's native chat template, then generates.
-	// The model owns its template — callers don't need to know Gemma vs Qwen formatting.
-	Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token]
-
-	// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
-	ModelType() string
-
-	// Err returns the error from the last Generate/Chat call, if any.
-	// Distinguishes normal stop (EOS, max tokens) from failures (OOM, C-level error).
-	// Returns nil if generation completed normally.
-	Err() error
-
-	// Close releases all resources (GPU memory, caches, subprocess).
-	Close() error
-}