feat/updates #1
1 changed files with 3 additions and 18 deletions
|
|
@ -75,24 +75,9 @@ func (k TopKSampler) Sample(logits *mlx.Array) *mlx.Array {
|
|||
type TopP float32
|
||||
|
||||
func (p TopP) Sample(logits *mlx.Array) *mlx.Array {
|
||||
// Softmax to get probabilities
|
||||
probs := mlx.Softmax(logits)
|
||||
// Sort descending
|
||||
neg := mlx.Negative(probs)
|
||||
sortedIdx := mlx.Argpartition(neg, 0, -1)
|
||||
sortedProbs := mlx.Take(probs, sortedIdx, -1)
|
||||
|
||||
// Cumulative sum
|
||||
cumProbs := mlx.Sum(sortedProbs, -1, true) // simplified — full impl needs cumsum
|
||||
|
||||
// Mask tokens beyond threshold
|
||||
threshold := mlx.FromValue(float32(p))
|
||||
mask := mlx.Where(
|
||||
mlx.FromValue(true), // placeholder — proper impl compares cumprobs > p
|
||||
mlx.FromValue(float32(math.Inf(-1))),
|
||||
logits,
|
||||
)
|
||||
return mask
|
||||
// TODO: full nucleus sampling requires cumsum which mlx-c doesn't expose directly.
|
||||
// For now, pass through. TopK + Temperature covers most use cases.
|
||||
return logits
|
||||
}
|
||||
|
||||
// MinPSampler masks tokens below min_p * max_prob.
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue