diff --git a/TODO.md b/TODO.md index 8fbbf48..39439be 100644 --- a/TODO.md +++ b/TODO.md @@ -83,11 +83,11 @@ Full codebase review after Phase 4 completion + go-inference integration. Groupe ### Critical — Fix Before Phase 2 -- [ ] **Error handler thread safety** — `last_mlx_error` in metal.go is a bare C `static const char*` with no synchronisation. If MLX ever calls the error handler from a background thread (e.g. async eval completion), this is a data race. Fix: use `_Atomic(const char*)` or a `pthread_mutex_t`. Even if MLX currently serialises error callbacks, this is a time bomb. Low-effort fix, high protection. +- [x] **Error handler thread safety** — ✅ `last_mlx_error` now uses `_Atomic(const char*)` with `atomic_store_explicit` (release) / `atomic_exchange_explicit` (acquire). Thread-safe even if MLX calls the error handler from background threads. -- [ ] **`-mmacosx-version-min=26.0` is wrong** — metal.go line 8 sets `CGO_CFLAGS: -mmacosx-version-min=26.0`. macOS 26 is Tahoe (mid-2025+). This locks out macOS 15 Sequoia users entirely. Should be `15.0` or `14.0` depending on minimum Metal feature level needed. MLX itself targets macOS 13.3+. +- [x] **`-mmacosx-version-min=26.0` is wrong** — ✅ Changed to `13.3` (MLX's own minimum). No longer locks out macOS 14/15 users. -- [ ] **`LoadOption` is ignored in `metalBackend.LoadModel()`** — `register_metal.go:59` accepts `...inference.LoadOption` but never calls `inference.ApplyLoadOpts()`. `WithContextLen()`, `WithGPULayers()`, `WithBackend()` are all silently dropped. At minimum, apply the config and pass `ContextLen` through to cache sizing. +- [x] **`LoadOption` is ignored in `metalBackend.LoadModel()`** — ✅ Now calls `inference.ApplyLoadOpts()`. `ContextLen` passed through to `metal.LoadConfig` → stored on `Model` → replaces unbounded `KVCache` with `RotatingKVCache(contextLen)` in generate loop. `GPULayers=0` logs a warning (Metal always uses full GPU offload). New test: `TestNewCaches_ContextLen`. 
### Important — Should Fix diff --git a/internal/metal/backend.go b/internal/metal/backend.go index e5e598b..79d7d97 100644 --- a/internal/metal/backend.go +++ b/internal/metal/backend.go @@ -4,17 +4,26 @@ package metal import "fmt" +// LoadConfig holds configuration applied during model loading. +type LoadConfig struct { + ContextLen int // Context window size (0 = model default, unbounded KV cache) +} + // LoadAndInit initialises Metal and loads a model from the given path. // Returns a *Model ready for generation. -func LoadAndInit(path string) (*Model, error) { +func LoadAndInit(path string, cfg ...LoadConfig) (*Model, error) { Init() im, err := loadModel(path) if err != nil { return nil, fmt.Errorf("metal: %w", err) } - return &Model{ + m := &Model{ model: im, tokenizer: im.Tokenizer(), modelType: im.ModelType(), - }, nil + } + if len(cfg) > 0 && cfg[0].ContextLen > 0 { + m.contextLen = cfg[0].ContextLen + } + return m, nil } diff --git a/internal/metal/error_test.go b/internal/metal/error_test.go index ba26906..9893491 100644 --- a/internal/metal/error_test.go +++ b/internal/metal/error_test.go @@ -38,6 +38,48 @@ func TestLastError_NoError(t *testing.T) { } } +func TestNewCaches_ContextLen(t *testing.T) { + // When contextLen is set, unbounded KVCaches should become RotatingKVCaches. + m := &Model{ + model: &fakeModel{numLayers: 4}, + } + + // Without contextLen — should get plain KVCaches. + caches := m.newCaches() + for i, c := range caches { + if _, ok := c.(*KVCache); !ok { + t.Errorf("cache[%d] without contextLen: got %T, want *KVCache", i, c) + } + } + + // With contextLen — should get RotatingKVCaches. + m.contextLen = 2048 + caches = m.newCaches() + for i, c := range caches { + if _, ok := c.(*RotatingKVCache); !ok { + t.Errorf("cache[%d] with contextLen=2048: got %T, want *RotatingKVCache", i, c) + } + } +} + +// fakeModel is a minimal InternalModel for testing cache creation. 
+type fakeModel struct { + numLayers int +} + +func (f *fakeModel) Forward(_ *Array, _ []Cache) *Array { return nil } +func (f *fakeModel) NewCache() []Cache { + caches := make([]Cache, f.numLayers) + for i := range caches { + caches[i] = NewKVCache() + } + return caches +} +func (f *fakeModel) NumLayers() int { return f.numLayers } +func (f *fakeModel) Tokenizer() *Tokenizer { return nil } +func (f *fakeModel) ModelType() string { return "fake" } +func (f *fakeModel) ApplyLoRA(_ LoRAConfig) *LoRAAdapter { return nil } + func TestLoadAllSafetensors_MissingFile(t *testing.T) { _, err := LoadAllSafetensors("/nonexistent/path/model.safetensors") if err == nil { diff --git a/internal/metal/generate.go b/internal/metal/generate.go index 77c7a30..c50ad9e 100644 --- a/internal/metal/generate.go +++ b/internal/metal/generate.go @@ -32,10 +32,11 @@ type GenerateConfig struct { // Model wraps a loaded transformer model for text generation. type Model struct { - model InternalModel - tokenizer *Tokenizer - modelType string - lastErr error + model InternalModel + tokenizer *Tokenizer + modelType string + contextLen int // 0 = unbounded (model default) + lastErr error } // ModelType returns the architecture identifier (e.g. "gemma3", "qwen3"). @@ -73,7 +74,7 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) return func(yield func(Token) bool) { tokens := m.tokenizer.Encode(prompt) - caches := m.model.NewCache() + caches := m.newCaches() sampler := newSampler(cfg.Temperature, cfg.TopP, 0, cfg.TopK) // Prefill: process entire prompt @@ -131,6 +132,22 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) } } +// newCaches creates per-layer KV caches. If contextLen is set, all unbounded +// caches are replaced with RotatingKVCache to cap memory usage. 
+func (m *Model) newCaches() []Cache { + caches := m.model.NewCache() + if m.contextLen <= 0 { + return caches + } + for i, c := range caches { + // Only replace unbounded caches — rotating caches already have a limit. + if _, ok := c.(*KVCache); ok { + caches[i] = NewRotatingKVCache(m.contextLen) + } + } + return caches +} + // formatChat applies the model's native chat template. func (m *Model) formatChat(messages []ChatMessage) string { switch m.modelType { diff --git a/internal/metal/metal.go b/internal/metal/metal.go index 818ceef..4361814 100644 --- a/internal/metal/metal.go +++ b/internal/metal/metal.go @@ -5,20 +5,21 @@ package metal /* #cgo CXXFLAGS: -std=c++17 -#cgo CFLAGS: -mmacosx-version-min=26.0 +#cgo CFLAGS: -mmacosx-version-min=13.3 #cgo CPPFLAGS: -I${SRCDIR}/../../dist/include #cgo LDFLAGS: -L${SRCDIR}/../../dist/lib -lmlxc -lmlx #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate #cgo darwin LDFLAGS: -Wl,-rpath,${SRCDIR}/../../dist/lib +#include <stdatomic.h> #include #include #include "mlx/c/mlx.h" -static const char *last_mlx_error = NULL; +static _Atomic(const char *) last_mlx_error = NULL; static void mlx_go_error_handler(const char *msg, void *data) { - last_mlx_error = msg; + atomic_store_explicit(&last_mlx_error, msg, memory_order_release); } static void set_error_handler() { @@ -26,9 +27,7 @@ static void set_error_handler() { static const char* get_and_clear_last_error() { - const char *msg = last_mlx_error; - last_mlx_error = NULL; - return msg; + return atomic_exchange_explicit(&last_mlx_error, NULL, memory_order_acquire); } */ import "C" diff --git a/register_metal.go b/register_metal.go index 754c2aa..7f0ffe6 100644 --- a/register_metal.go +++ b/register_metal.go @@ -5,6 +5,7 @@ package mlx import ( "context" "iter" + "log/slog" "forge.lthn.ai/core/go-inference" "forge.lthn.ai/core/go-mlx/internal/metal" @@ -57,7 +58,13 @@ func (b *metalBackend) Name() string { return "metal" } func (b *metalBackend) Available() bool
{ return true } func (b *metalBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) { - m, err := metal.LoadAndInit(path) + cfg := inference.ApplyLoadOpts(opts) + if cfg.GPULayers == 0 { + slog.Warn("mlx: GPULayers=0 ignored — Metal always uses full GPU offload") + } + m, err := metal.LoadAndInit(path, metal.LoadConfig{ + ContextLen: cfg.ContextLen, + }) if err != nil { return nil, err }