feat: add Metal memory budget monitoring after each request
Tracks model size at load time and checks Metal active memory after each generation. If usage exceeds 3× model size, forces double GC and cache clear as a safety net.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parent c5689c3e83
commit 045f8fc110

1 changed file with 33 additions and 12 deletions
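The check described above reduces to a single comparison against the memory level recorded at load time. A minimal standalone restatement of that threshold logic (the overBudget name and the example figures are illustrative, not part of the change below):

package main

import "fmt"

// overBudget reports whether active Metal memory has grown past three
// times the model's footprint recorded at load time, the same threshold
// the checkMemory method in the diff below treats as the danger zone.
func overBudget(activeBytes, modelBytes uint64) bool {
	return activeBytes > modelBytes*3
}

func main() {
	const gb = uint64(1) << 30
	fmt.Println(overBudget(10*gb, 4*gb)) // false: 10 GB is under the 12 GB budget
	fmt.Println(overBudget(13*gb, 4*gb)) // true: 13 GB exceeds the 12 GB budget
}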
@@ -18,11 +18,12 @@ import (
 
 // MLXBackend implements Backend for native Metal inference via mlx-c.
 type MLXBackend struct {
-	model   *model.GemmaModel
-	tok     *tokenizer.Tokenizer
-	caches  []cache.Cache
-	sampler sample.Sampler
-	mu      sync.Mutex
+	model      *model.GemmaModel
+	tok        *tokenizer.Tokenizer
+	caches     []cache.Cache
+	sampler    sample.Sampler
+	mu         sync.Mutex
+	modelBytes uint64 // model size at load time, for memory budget
 }
 
 // NewMLXBackend loads a model from a safetensors directory and creates
@@ -43,16 +44,18 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) {
 	mlx.SetCacheLimit(16 * 1024 * 1024 * 1024)  // 16 GB allocator cache
 	mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap
 
+	modelMB := mlx.GetActiveMemory() / 1024 / 1024
 	slog.Info("mlx: model loaded",
 		"layers", m.NumLayers(),
-		"memory_mb", mlx.GetActiveMemory()/1024/1024,
+		"memory_mb", modelMB,
 	)
 
 	return &MLXBackend{
-		model:   m,
-		tok:     m.Tokenizer(),
-		caches:  m.NewCache(),
-		sampler: sample.New(0.1, 0, 0, 0), // default low temp
+		model:      m,
+		tok:        m.Tokenizer(),
+		caches:     m.NewCache(),
+		sampler:    sample.New(0.1, 0, 0, 0), // default low temp
+		modelBytes: mlx.GetActiveMemory(),
 	}, nil
 }
 
@@ -114,9 +117,10 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
 		}
 	}
 
-	// Full cleanup between requests
+	// Cleanup between requests
 	runtime.GC()
 	mlx.ClearCache()
+	b.checkMemory()
 	return b.tok.Decode(output), nil
 }
 
@@ -200,12 +204,29 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
 		}
 	}
 
-	// Full cleanup between requests
+	// Cleanup between requests
 	runtime.GC()
 	mlx.ClearCache()
+	b.checkMemory()
 	return b.tok.Decode(output), nil
 }
 
+// checkMemory logs Metal memory usage and forces cleanup if it exceeds budget.
+func (b *MLXBackend) checkMemory() {
+	active := mlx.GetActiveMemory()
+	budget := b.modelBytes * 3 // 3× model size = danger zone
+	if active > budget {
+		slog.Warn("mlx: memory over budget, forcing cleanup",
+			"active_mb", active/1024/1024,
+			"model_mb", b.modelBytes/1024/1024,
+			"peak_mb", mlx.GetPeakMemory()/1024/1024,
+		)
+		runtime.GC()
+		runtime.GC() // double GC to run finalizers
+		mlx.ClearCache()
+	}
+}
+
 // Name returns the backend identifier.
 func (b *MLXBackend) Name() string { return "mlx" }
 
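The double runtime.GC() in checkMemory lines up with how Go finalizers behave: an unreachable object that carries a finalizer is only queued for finalization on the first collection cycle, and its memory can be reclaimed on a later cycle after the finalizer has run. A minimal sketch of that pattern, using a hypothetical gpuBuffer type as a stand-in for a natively backed allocation (not part of this repository):

package main

import (
	"fmt"
	"runtime"
	"time"
)

// gpuBuffer is a stand-in for a Go wrapper around a native allocation.
type gpuBuffer struct{ name string }

func main() {
	buf := &gpuBuffer{name: "scratch"}
	runtime.SetFinalizer(buf, func(b *gpuBuffer) {
		// A real finalizer would release the underlying native memory here.
		fmt.Println("finalizer released", b.name)
	})
	buf = nil // drop the last reference

	runtime.GC()                      // first cycle: object found unreachable, finalizer queued
	time.Sleep(10 * time.Millisecond) // finalizers run asynchronously on their own goroutine
	runtime.GC()                      // second cycle: the finalized object's memory can now be freed
}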