feat/updates #1

Merged
Snider merged 51 commits from feat/updates into dev 2026-02-16 05:54:07 +00:00
Showing only changes of commit e9d9a3c3a0 - Show all commits

View file

@ -75,24 +75,9 @@ func (k TopKSampler) Sample(logits *mlx.Array) *mlx.Array {
type TopP float32
func (p TopP) Sample(logits *mlx.Array) *mlx.Array {
// Softmax to get probabilities
probs := mlx.Softmax(logits)
// Sort descending
neg := mlx.Negative(probs)
sortedIdx := mlx.Argpartition(neg, 0, -1)
sortedProbs := mlx.Take(probs, sortedIdx, -1)
// Cumulative sum
cumProbs := mlx.Sum(sortedProbs, -1, true) // simplified — full impl needs cumsum
// Mask tokens beyond threshold
threshold := mlx.FromValue(float32(p))
mask := mlx.Where(
mlx.FromValue(true), // placeholder — proper impl compares cumprobs > p
mlx.FromValue(float32(math.Inf(-1))),
logits,
)
return mask
// TODO: full nucleus sampling requires cumsum which mlx-c doesn't expose directly.
// For now, pass through. TopK + Temperature covers most use cases.
return logits
}
// MinPSampler masks tokens below min_p * max_prob.