diff --git a/pkg/mlx/sample/sample.go b/pkg/mlx/sample/sample.go index 641c99b..d267f7a 100644 --- a/pkg/mlx/sample/sample.go +++ b/pkg/mlx/sample/sample.go @@ -75,24 +75,9 @@ func (k TopKSampler) Sample(logits *mlx.Array) *mlx.Array { type TopP float32 func (p TopP) Sample(logits *mlx.Array) *mlx.Array { - // Softmax to get probabilities - probs := mlx.Softmax(logits) - // Sort descending - neg := mlx.Negative(probs) - sortedIdx := mlx.Argpartition(neg, 0, -1) - sortedProbs := mlx.Take(probs, sortedIdx, -1) - - // Cumulative sum - cumProbs := mlx.Sum(sortedProbs, -1, true) // simplified — full impl needs cumsum - - // Mask tokens beyond threshold - threshold := mlx.FromValue(float32(p)) - mask := mlx.Where( - mlx.FromValue(true), // placeholder — proper impl compares cumprobs > p - mlx.FromValue(float32(math.Inf(-1))), - logits, - ) - return mask + // TODO: full nucleus sampling requires cumsum which mlx-c doesn't expose directly. + // For now, pass through. TopK + Temperature covers most use cases. + return logits } // MinPSampler masks tokens below min_p * max_prob.