feat(metal): bind CumSum, implement TopP and MinP sampling

New ops: CumSum, Sort, Argsort, Greater, MaxAxis — all bound to mlx-c.

TopP (nucleus) sampling now fully implemented: sorts probabilities
descending, computes cumulative sum, masks tokens beyond the threshold,
and scatters the mask back to original positions via argsort.

MinP sampling now fully implemented: computes softmax, finds max
probability, masks tokens below min_p * max_prob.

Both were previously stubs that passed through logits unchanged.

10 new tests (CumSum variants, Sort, Argsort, Greater, MaxAxis,
TopP, MinP). 176 total tests passing.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-19 20:39:44 +00:00
parent df0b300b1a
commit f39126f6bd
4 changed files with 212 additions and 16 deletions

View file

@ -351,3 +351,39 @@ func LogSumExp(a *Array, axis int, keepDims bool) *Array {
C.mlx_logsumexp_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
return out
}
// CumSum computes the running sum of a along the given axis.
// Pass reverse=true to accumulate from the end of the axis, and
// inclusive=true to include each element in its own partial sum.
func CumSum(a *Array, axis int, reverse, inclusive bool) *Array {
	res := New("CUMSUM", a)
	C.mlx_cumsum(&res.ctx, a.ctx, C.int(axis), C._Bool(reverse), C._Bool(inclusive), DefaultStream().ctx)
	return res
}
// Sort returns a copy of the array with values ordered ascending along axis.
func Sort(a *Array, axis int) *Array {
	res := New("SORT", a)
	C.mlx_sort_axis(&res.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
	return res
}
// Argsort returns the permutation of indices that sorts the array along axis.
func Argsort(a *Array, axis int) *Array {
	res := New("ARGSORT", a)
	C.mlx_argsort_axis(&res.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
	return res
}
// Greater compares a and b element-wise and yields a bool array that is
// true wherever a > b.
func Greater(a, b *Array) *Array {
	res := New("GREATER", a, b)
	C.mlx_greater(&res.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return res
}
// MaxAxis reduces the array along axis by taking the maximum.
// With keepDims=true the reduced axis is retained with size 1.
func MaxAxis(a *Array, axis int, keepDims bool) *Array {
	res := New("MAX_AXIS", a)
	C.mlx_max_axis(&res.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
	return res
}

View file

@ -515,6 +515,87 @@ func TestAdd_Broadcasting(t *testing.T) {
// --- Random ---
// --- Cumulative and sorting ops ---
// TestCumSum checks the forward, inclusive cumulative sum.
func TestCumSum(t *testing.T) {
	in := FromValues([]float32{1, 2, 3, 4}, 1, 4)
	out := CumSum(in, -1, false, true)
	Materialize(out)
	floatSliceApprox(t, out.Floats(), []float32{1, 3, 6, 10})
}
// TestCumSum_Exclusive checks the exclusive variant: each output element
// is the sum of everything strictly before it.
func TestCumSum_Exclusive(t *testing.T) {
	in := FromValues([]float32{1, 2, 3, 4}, 1, 4)
	out := CumSum(in, -1, false, false)
	Materialize(out)
	floatSliceApprox(t, out.Floats(), []float32{0, 1, 3, 6})
}
// TestCumSum_Reverse checks accumulation from the end of the axis.
func TestCumSum_Reverse(t *testing.T) {
	in := FromValues([]float32{1, 2, 3, 4}, 1, 4)
	out := CumSum(in, -1, true, true)
	Materialize(out)
	floatSliceApprox(t, out.Floats(), []float32{10, 9, 7, 4})
}
// TestSort checks ascending sort along the last axis.
func TestSort(t *testing.T) {
	in := FromValues([]float32{3, 1, 4, 1, 5}, 1, 5)
	out := Sort(in, -1)
	Materialize(out)
	floatSliceApprox(t, out.Floats(), []float32{1, 1, 3, 4, 5})
}
// TestArgsort verifies the permutation that sorts the input ascending.
func TestArgsort(t *testing.T) {
	in := FromValues([]float32{3, 1, 4, 1, 5}, 1, 5)
	out := Argsort(in, -1)
	Materialize(out)
	// Sorting {3,1,4,1,5} ascending visits indices 1, 3, 0, 2, 4.
	want := []int{1, 3, 0, 2, 4}
	for i, got := range out.Ints() {
		if got != want[i] {
			t.Errorf("Argsort[%d] = %d, want %d", i, got, want[i])
		}
	}
}
// TestGreater checks element-wise strict comparison.
func TestGreater(t *testing.T) {
	lhs := FromValues([]float32{1, 5, 3}, 3)
	rhs := FromValues([]float32{2, 2, 3}, 3)
	// Result is bool dtype; cast to int32 so the data can be extracted.
	out := AsType(Greater(lhs, rhs), DTypeInt32)
	Materialize(out)
	// Expected: 1>2 false, 5>2 true, 3>3 false.
	want := []int32{0, 1, 0}
	for i, got := range out.DataInt32() {
		if got != want[i] {
			t.Errorf("Greater[%d] = %d, want %d", i, got, want[i])
		}
	}
}
// TestMaxAxis checks the per-row maximum of a 2x3 input.
func TestMaxAxis(t *testing.T) {
	in := FromValues([]float32{1, 5, 3, 4, 2, 6}, 2, 3)
	out := MaxAxis(in, -1, false)
	Materialize(out)
	floatSliceApprox(t, out.Floats(), []float32{5, 6})
}
// TestMaxAxis_KeepDims checks that the reduced axis survives as size 1.
func TestMaxAxis_KeepDims(t *testing.T) {
	in := FromValues([]float32{1, 5, 3, 4, 2, 6}, 2, 3)
	out := MaxAxis(in, -1, true)
	Materialize(out)
	if s := out.Shape(); s[0] != 2 || s[1] != 1 {
		t.Errorf("shape = %v, want [2 1]", s)
	}
}
// --- Random ---
func TestRandomCategorical(t *testing.T) {
// Heavily weighted towards index 2
logprobs := FromValues([]float32{-100, -100, 0}, 1, 3)

View file

@ -68,20 +68,65 @@ func (k TopKSampler) Sample(logits *Array) *Array {
return PutAlongAxis(logits, mask, FromValue(float32(math.Inf(-1))), -1)
}
// TopP implements nucleus (top-p) sampling.
// Keeps the smallest set of tokens whose cumulative probability exceeds p.
type TopP float32

// Sample masks every logit outside the nucleus with -inf and returns the
// filtered logits; kept positions pass through unchanged so downstream
// temperature/categorical sampling sees the original values.
func (p TopP) Sample(logits *Array) *Array {
	// Convert logits to probabilities.
	probs := Softmax(logits)

	// Sort descending via argsort of negated probs.
	neg := Negative(probs)
	sortIdx := Argsort(neg, -1)
	sortedProbs := TakeAlongAxis(probs, sortIdx, -1)

	// Cumulative sum of sorted probabilities.
	cumProbs := CumSum(sortedProbs, -1, false, true)

	// Mask in sorted space: a token survives while the cumulative probability
	// *excluding* it is still <= p. Subtracting the current prob from the
	// inclusive cumsum gives that exclusive prefix, which also guarantees the
	// single most probable token is always kept.
	shiftedCum := Subtract(cumProbs, sortedProbs)
	threshold := FromValue(float32(p))
	sortedMask := Where(
		Greater(shiftedCum, threshold),
		FromValue(float32(math.Inf(-1))),
		FromValue(float32(0)),
	)

	// Scatter the mask back to the original token positions.
	mask := PutAlongAxis(
		Zeros(logits.Shape(), DTypeFloat32),
		sortIdx,
		sortedMask,
		-1,
	)

	// Apply the mask: -inf where excluded (mask < 0), original logit where kept.
	return Where(
		Greater(FromValue(float32(0)), mask),
		FromValue(float32(math.Inf(-1))),
		logits,
	)
}
// MinPSampler masks tokens whose probability is below min_p * max_prob.
type MinPSampler float32

// Sample replaces every logit whose softmax probability falls below
// min_p times the row's maximum probability with -inf; surviving logits
// pass through unchanged.
func (p MinPSampler) Sample(logits *Array) *Array {
	// Convert logits to probabilities.
	probs := Softmax(logits)
	// Per-row maximum probability, kept with dims so it broadcasts against probs.
	maxProb := MaxAxis(probs, -1, true)
	// Absolute cutoff: min_p scaled by the strongest token's probability.
	threshold := MulScalar(maxProb, float32(p))
	// Mask tokens strictly below the cutoff.
	return Where(
		Greater(threshold, probs),
		FromValue(float32(math.Inf(-1))),
		logits,
	)
}

View file

@ -89,26 +89,60 @@ func TestNew_Chain(t *testing.T) {
}
}
func TestTopP_PassThrough(t *testing.T) {
// TopP is currently a stub — verify it doesn't break the chain
func TestTopP_DominantLogit(t *testing.T) {
// With one dominant logit, TopP should always pick it
logits := FromValues([]float32{-10, -10, 100, -10}, 1, 4)
s := newSampler(0.5, 0.9, 0, 0) // topP=0.9 (stub), temp=0.5
s := newSampler(0.5, 0.9, 0, 0) // topP=0.9, temp=0.5
token := s.Sample(logits)
Materialize(token)
if token.Int() != 2 {
t.Errorf("topP stub + temp sample = %d, want 2", token.Int())
t.Errorf("topP dominant sample = %d, want 2", token.Int())
}
}
func TestMinP_PassThrough(t *testing.T) {
// MinP is currently a stub — verify it doesn't break the chain
func TestTopP_RestrictsOptions(t *testing.T) {
// Two equal high logits, two low. TopP=0.5 should mostly restrict to top tokens.
logits := FromValues([]float32{10, 10, -100, -100}, 1, 4)
s := newSampler(1.0, 0.5, 0, 0) // topP=0.5, temp=1.0
seen := map[int]bool{}
for range 30 {
token := s.Sample(logits)
Materialize(token)
seen[token.Int()] = true
}
// Should only pick indices 0 or 1 (the two high-probability tokens)
for idx := range seen {
if idx != 0 && idx != 1 {
t.Errorf("topP=0.5 sampled index %d, expected only 0 or 1", idx)
}
}
}
// TestMinP_DominantLogit checks that with one overwhelming logit,
// MinP always selects it.
func TestMinP_DominantLogit(t *testing.T) {
	logits := FromValues([]float32{-10, -10, 100, -10}, 1, 4)
	s := newSampler(0.5, 0, 0.1, 0) // minP=0.1, temp=0.5
	token := s.Sample(logits)
	Materialize(token)
	if token.Int() != 2 {
		t.Errorf("minP dominant sample = %d, want 2", token.Int())
	}
}
// TestMinP_RestrictsOptions checks that minP=0.1 masks the low-probability
// tokens when one logit dominates, so only the dominant index is sampled.
func TestMinP_RestrictsOptions(t *testing.T) {
	logits := FromValues([]float32{-100, 50, -100, -100}, 1, 4)
	s := newSampler(1.0, 0, 0.1, 0) // minP=0.1, temp=1.0
	for range 20 {
		tok := s.Sample(logits)
		Materialize(tok)
		if got := tok.Int(); got != 1 {
			t.Errorf("minP with dominant logit sampled %d, want 1", got)
		}
	}
}