From 098f4963643d4c07da3dcb291c83e5dd950c20e7 Mon Sep 17 00:00:00 2001
From: Claude
Date: Mon, 16 Feb 2026 02:22:13 +0000
Subject: [PATCH] fix: correct SDPA mask mode and slice logits to last position

The fused SDPA helpers were passing "none" as the mask mode; use "" in the
maskless/causal path and "array" when an explicit mask is supplied. Forward
returns logits for every position ([B, L, V]), so slice to the last position
([B, V]) before sampling rather than handing the full sequence to the
sampler. Also drop the temporary debugOnce shape logging from the Gemma3
attention path.
---
 pkg/ml/backend_mlx.go   | 16 ++++++++++++++++
 pkg/mlx/fast.go         |  4 ++--
 pkg/mlx/model/gemma3.go | 10 ----------
 3 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/pkg/ml/backend_mlx.go b/pkg/ml/backend_mlx.go
index 8e427fdb..f26c89c8 100644
--- a/pkg/ml/backend_mlx.go
+++ b/pkg/ml/backend_mlx.go
@@ -87,6 +87,8 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
 		}
 
 		logits := b.model.Forward(input, b.caches)
+		// Take last position: [B, L, V] → [B, V]
+		logits = lastPosition(logits)
 		next := sampler.Sample(logits)
 		mlx.Materialize(next)
 
@@ -101,6 +103,19 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
 	}
 	return b.tok.Decode(output), nil
 }
+
+// lastPosition extracts the last sequence position from [B, L, V] logits → [B, V].
+func lastPosition(logits *mlx.Array) *mlx.Array {
+	shape := logits.Shape()
+	if len(shape) == 3 && shape[1] > 1 {
+		L := shape[1]
+		logits = mlx.Slice(logits, []int32{0, L - 1, 0}, []int32{shape[0], L, shape[2]})
+		logits = mlx.Reshape(logits, shape[0], shape[2])
+	} else if len(shape) == 3 && shape[1] == 1 {
+		logits = mlx.Reshape(logits, shape[0], shape[2])
+	}
+	return logits
+}
 // Chat formats messages and generates a response.
 func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) {
 	// Format as Gemma chat
@@ -148,6 +163,7 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
 		}
 
 		logits := b.model.Forward(input, b.caches)
+		logits = lastPosition(logits)
 		next := sampler.Sample(logits)
 		mlx.Materialize(next)
 
diff --git a/pkg/mlx/fast.go b/pkg/mlx/fast.go
index 58e9e5e4..936c64a3 100644
--- a/pkg/mlx/fast.go
+++ b/pkg/mlx/fast.go
@@ -48,7 +48,7 @@ func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, off
 
 // ScaledDotProductAttention computes attention using a fused Metal kernel.
 func ScaledDotProductAttention(query, key, value *Array, scale float32, causal bool) *Array {
-	mode := "none"
+	mode := ""
 	if causal {
 		mode = "causal"
 	}
@@ -67,7 +67,7 @@ func ScaledDotProductAttention(query, key, value *Array, scale float32, causal b
 
 // ScaledDotProductAttentionWithMask computes attention with an explicit mask.
 func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale float32) *Array {
-	cMode := C.CString("none")
+	cMode := C.CString("array")
 	defer C.free(unsafe.Pointer(cMode))
 
 	sinksArr := C.mlx_array_new()
diff --git a/pkg/mlx/model/gemma3.go b/pkg/mlx/model/gemma3.go
index c9ff7c83..5354e28f 100644
--- a/pkg/mlx/model/gemma3.go
+++ b/pkg/mlx/model/gemma3.go
@@ -95,9 +95,6 @@ type MLP struct {
 	DownProj *mlx.Linear
 }
 
-// debugOnce is a temporary flag for shape debugging (remove after fix).
-var debugOnce = true
-
 // compiledGELU is a singleton for the compiled GELU function.
 var compiledGELU *mlx.CompiledFunc
 
@@ -372,13 +369,6 @@ func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
 	k := a.KProj.Forward(x)
 	v := a.VProj.Forward(x)
 
-	if debugOnce {
-		slog.Info("mlx: debug", "x_shape", x.Shape(), "q_shape", q.Shape(), "k_shape", k.Shape(), "v_shape", v.Shape(),
-			"B", B, "L", L, "num_heads", cfg.NumAttentionHeads, "num_kv_heads", cfg.NumKeyValueHeads, "head_dim", cfg.HeadDim,
-			"q_scales", a.QProj.Scales != nil)
-		debugOnce = false
-	}
-
 	// Reshape to [B, num_heads, L, head_dim]
 	q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
 		[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)