fix: correct SDPA mask mode and slice logits to last position

This commit is contained in:
Claude 2026-02-16 02:22:13 +00:00 committed by Snider
parent 09da05d799
commit 098f496364
3 changed files with 18 additions and 12 deletions

View file

@ -87,6 +87,8 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
}
logits := b.model.Forward(input, b.caches)
// Take last position: [B, L, V] → [B, V]
logits = lastPosition(logits)
next := sampler.Sample(logits)
mlx.Materialize(next)
@ -101,6 +103,19 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
return b.tok.Decode(output), nil
}
// lastPosition extracts the last sequence position from [B, L, V] logits → [B, V].
func lastPosition(logits *mlx.Array) *mlx.Array {
shape := logits.Shape()
if len(shape) == 3 && shape[1] > 1 {
L := shape[1]
logits = mlx.Slice(logits, []int32{0, L - 1, 0}, []int32{shape[0], L, shape[2]})
logits = mlx.Reshape(logits, shape[0], shape[2])
} else if len(shape) == 3 && shape[1] == 1 {
logits = mlx.Reshape(logits, shape[0], shape[2])
}
return logits
}
// Chat formats messages and generates a response.
func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) {
// Format as Gemma chat
@ -148,6 +163,7 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
}
logits := b.model.Forward(input, b.caches)
logits = lastPosition(logits)
next := sampler.Sample(logits)
mlx.Materialize(next)

View file

@ -48,7 +48,7 @@ func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, off
// ScaledDotProductAttention computes attention using a fused Metal kernel.
func ScaledDotProductAttention(query, key, value *Array, scale float32, causal bool) *Array {
mode := "none"
mode := ""
if causal {
mode = "causal"
}
@ -67,7 +67,7 @@ func ScaledDotProductAttention(query, key, value *Array, scale float32, causal b
// ScaledDotProductAttentionWithMask computes attention with an explicit mask.
func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale float32) *Array {
cMode := C.CString("none")
cMode := C.CString("array")
defer C.free(unsafe.Pointer(cMode))
sinksArr := C.mlx_array_new()

View file

@ -95,9 +95,6 @@ type MLP struct {
DownProj *mlx.Linear
}
// debugOnce is a temporary flag for shape debugging (remove after fix).
var debugOnce = true
// compiledGELU is a singleton for the compiled GELU function.
var compiledGELU *mlx.CompiledFunc
@ -372,13 +369,6 @@ func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
if debugOnce {
slog.Info("mlx: debug", "x_shape", x.Shape(), "q_shape", q.Shape(), "k_shape", k.Shape(), "v_shape", v.Shape(),
"B", B, "L", L, "num_heads", cfg.NumAttentionHeads, "num_kv_heads", cfg.NumKeyValueHeads, "head_dim", cfg.HeadDim,
"q_scales", a.QProj.Scales != nil)
debugOnce = false
}
// Reshape to [B, num_heads, L, head_dim]
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)