fix(metal): address 3 critical code review items
1. Error handler thread safety: `last_mlx_error` now uses `_Atomic(const char*)` with `atomic_store_explicit`/`atomic_exchange_explicit` (release/acquire).
2. macOS version minimum: `-mmacosx-version-min` changed from 26.0 to 13.3 (MLX's own minimum), no longer locking out macOS 14/15 users.
3. `LoadOption` applied in `metalBackend.LoadModel()`: calls `ApplyLoadOpts()` and passes `ContextLen` through to `Model`, which replaces the unbounded `KVCache` with `RotatingKVCache` when set. `GPULayers=0` logs a warning.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
bd5668967c
commit
c96f9bd006
6 changed files with 92 additions and 18 deletions
6
TODO.md
6
TODO.md
|
|
@ -83,11 +83,11 @@ Full codebase review after Phase 4 completion + go-inference integration. Groupe
|
|||
|
||||
### Critical — Fix Before Phase 2
|
||||
|
||||
- [ ] **Error handler thread safety** — `last_mlx_error` in metal.go is a bare C `static const char*` with no synchronisation. If MLX ever calls the error handler from a background thread (e.g. async eval completion), this is a data race. Fix: use `_Atomic(const char*)` or a `pthread_mutex_t`. Even if MLX currently serialises error callbacks, this is a time bomb. Low-effort fix, high protection.
|
||||
- [x] **Error handler thread safety** — ✅ `last_mlx_error` now uses `_Atomic(const char*)` with `atomic_store_explicit` (release) / `atomic_exchange_explicit` (acquire). Thread-safe even if MLX calls the error handler from background threads.
|
||||
|
||||
- [ ] **`-mmacosx-version-min=26.0` is wrong** — metal.go line 8 sets `CGO_CFLAGS: -mmacosx-version-min=26.0`. macOS 26 is Tahoe (mid-2025+). This locks out macOS 15 Sequoia users entirely. Should be `15.0` or `14.0` depending on minimum Metal feature level needed. MLX itself targets macOS 13.3+.
|
||||
- [x] **`-mmacosx-version-min=26.0` is wrong** — ✅ Changed to `13.3` (MLX's own minimum). No longer locks out macOS 14/15 users.
|
||||
|
||||
- [ ] **`LoadOption` is ignored in `metalBackend.LoadModel()`** — `register_metal.go:59` accepts `...inference.LoadOption` but never calls `inference.ApplyLoadOpts()`. `WithContextLen()`, `WithGPULayers()`, `WithBackend()` are all silently dropped. At minimum, apply the config and pass `ContextLen` through to cache sizing.
|
||||
- [x] **`LoadOption` is ignored in `metalBackend.LoadModel()`** — ✅ Now calls `inference.ApplyLoadOpts()`. `ContextLen` passed through to `metal.LoadConfig` → stored on `Model` → replaces unbounded `KVCache` with `RotatingKVCache(contextLen)` in generate loop. `GPULayers=0` logs a warning (Metal always uses full GPU offload). New test: `TestNewCaches_ContextLen`.
|
||||
|
||||
### Important — Should Fix
|
||||
|
||||
|
|
|
|||
|
|
@ -4,17 +4,26 @@ package metal
|
|||
|
||||
import "fmt"
|
||||
|
||||
// LoadConfig carries options that take effect while a model is being loaded.
type LoadConfig struct {
	// ContextLen is the context window size. Zero means the model default,
	// which leaves the KV cache unbounded.
	ContextLen int
}
|
||||
|
||||
// LoadAndInit initialises Metal and loads a model from the given path.
|
||||
// Returns a *Model ready for generation.
|
||||
func LoadAndInit(path string) (*Model, error) {
|
||||
func LoadAndInit(path string, cfg ...LoadConfig) (*Model, error) {
|
||||
Init()
|
||||
im, err := loadModel(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("metal: %w", err)
|
||||
}
|
||||
return &Model{
|
||||
m := &Model{
|
||||
model: im,
|
||||
tokenizer: im.Tokenizer(),
|
||||
modelType: im.ModelType(),
|
||||
}, nil
|
||||
}
|
||||
if len(cfg) > 0 && cfg[0].ContextLen > 0 {
|
||||
m.contextLen = cfg[0].ContextLen
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,6 +38,48 @@ func TestLastError_NoError(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNewCaches_ContextLen(t *testing.T) {
|
||||
// When contextLen is set, unbounded KVCaches should become RotatingKVCaches.
|
||||
m := &Model{
|
||||
model: &fakeModel{numLayers: 4},
|
||||
}
|
||||
|
||||
// Without contextLen — should get plain KVCaches.
|
||||
caches := m.newCaches()
|
||||
for i, c := range caches {
|
||||
if _, ok := c.(*KVCache); !ok {
|
||||
t.Errorf("cache[%d] without contextLen: got %T, want *KVCache", i, c)
|
||||
}
|
||||
}
|
||||
|
||||
// With contextLen — should get RotatingKVCaches.
|
||||
m.contextLen = 2048
|
||||
caches = m.newCaches()
|
||||
for i, c := range caches {
|
||||
if _, ok := c.(*RotatingKVCache); !ok {
|
||||
t.Errorf("cache[%d] with contextLen=2048: got %T, want *RotatingKVCache", i, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fakeModel is a minimal InternalModel for testing cache creation.
|
||||
type fakeModel struct {
|
||||
numLayers int
|
||||
}
|
||||
|
||||
func (f *fakeModel) Forward(_ *Array, _ []Cache) *Array { return nil }
|
||||
func (f *fakeModel) NewCache() []Cache {
|
||||
caches := make([]Cache, f.numLayers)
|
||||
for i := range caches {
|
||||
caches[i] = NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
func (f *fakeModel) NumLayers() int { return f.numLayers }
|
||||
func (f *fakeModel) Tokenizer() *Tokenizer { return nil }
|
||||
func (f *fakeModel) ModelType() string { return "fake" }
|
||||
func (f *fakeModel) ApplyLoRA(_ LoRAConfig) *LoRAAdapter { return nil }
|
||||
|
||||
func TestLoadAllSafetensors_MissingFile(t *testing.T) {
|
||||
_, err := LoadAllSafetensors("/nonexistent/path/model.safetensors")
|
||||
if err == nil {
|
||||
|
|
|
|||
|
|
@ -32,10 +32,11 @@ type GenerateConfig struct {
|
|||
|
||||
// Model wraps a loaded transformer model for text generation.
|
||||
type Model struct {
|
||||
model InternalModel
|
||||
tokenizer *Tokenizer
|
||||
modelType string
|
||||
lastErr error
|
||||
model InternalModel
|
||||
tokenizer *Tokenizer
|
||||
modelType string
|
||||
contextLen int // 0 = unbounded (model default)
|
||||
lastErr error
|
||||
}
|
||||
|
||||
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
|
||||
|
|
@ -73,7 +74,7 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
|
|||
|
||||
return func(yield func(Token) bool) {
|
||||
tokens := m.tokenizer.Encode(prompt)
|
||||
caches := m.model.NewCache()
|
||||
caches := m.newCaches()
|
||||
sampler := newSampler(cfg.Temperature, cfg.TopP, 0, cfg.TopK)
|
||||
|
||||
// Prefill: process entire prompt
|
||||
|
|
@ -131,6 +132,22 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
|
|||
}
|
||||
}
|
||||
|
||||
// newCaches creates per-layer KV caches. If contextLen is set, all unbounded
|
||||
// caches are replaced with RotatingKVCache to cap memory usage.
|
||||
func (m *Model) newCaches() []Cache {
|
||||
caches := m.model.NewCache()
|
||||
if m.contextLen <= 0 {
|
||||
return caches
|
||||
}
|
||||
for i, c := range caches {
|
||||
// Only replace unbounded caches — rotating caches already have a limit.
|
||||
if _, ok := c.(*KVCache); ok {
|
||||
caches[i] = NewRotatingKVCache(m.contextLen)
|
||||
}
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// formatChat applies the model's native chat template.
|
||||
func (m *Model) formatChat(messages []ChatMessage) string {
|
||||
switch m.modelType {
|
||||
|
|
|
|||
|
|
@ -5,20 +5,21 @@ package metal
|
|||
|
||||
/*
|
||||
#cgo CXXFLAGS: -std=c++17
|
||||
#cgo CFLAGS: -mmacosx-version-min=26.0
|
||||
#cgo CFLAGS: -mmacosx-version-min=13.3
|
||||
#cgo CPPFLAGS: -I${SRCDIR}/../../dist/include
|
||||
#cgo LDFLAGS: -L${SRCDIR}/../../dist/lib -lmlxc -lmlx
|
||||
#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
|
||||
#cgo darwin LDFLAGS: -Wl,-rpath,${SRCDIR}/../../dist/lib
|
||||
|
||||
#include <stdatomic.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include "mlx/c/mlx.h"
|
||||
|
||||
/* Last error reported by MLX, published atomically so the handler may run
 * on any thread. Read-and-clear happens in get_and_clear_last_error(). */
static _Atomic(const char *) last_mlx_error = NULL;

static void mlx_go_error_handler(const char *msg, void *data) {
	/* NOTE(review): this stores the pointer, not a copy — assumes MLX keeps
	 * msg alive until the Go side reads it; confirm against MLX's error
	 * handler contract. Release ordering pairs with the acquire exchange in
	 * get_and_clear_last_error(). */
	atomic_store_explicit(&last_mlx_error, msg, memory_order_release);
}
|
||||
|
||||
static void set_error_handler() {
|
||||
|
|
@ -26,9 +27,7 @@ static void set_error_handler() {
|
|||
}
|
||||
|
||||
static const char* get_and_clear_last_error() {
|
||||
const char *msg = last_mlx_error;
|
||||
last_mlx_error = NULL;
|
||||
return msg;
|
||||
return atomic_exchange_explicit(&last_mlx_error, NULL, memory_order_acquire);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ package mlx
|
|||
import (
|
||||
"context"
|
||||
"iter"
|
||||
"log/slog"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
"forge.lthn.ai/core/go-mlx/internal/metal"
|
||||
|
|
@ -57,7 +58,13 @@ func (b *metalBackend) Name() string { return "metal" }
|
|||
func (b *metalBackend) Available() bool { return true }
|
||||
|
||||
func (b *metalBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
|
||||
m, err := metal.LoadAndInit(path)
|
||||
cfg := inference.ApplyLoadOpts(opts)
|
||||
if cfg.GPULayers == 0 {
|
||||
slog.Warn("mlx: GPULayers=0 ignored — Metal always uses full GPU offload")
|
||||
}
|
||||
m, err := metal.LoadAndInit(path, metal.LoadConfig{
|
||||
ContextLen: cfg.ContextLen,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue