go/pkg/mlx/sample/sample.go
Claude dfd370cff6
refactor: strip to pure package library
- Fix remaining 187 pkg/ files referencing core/cli → core/go
- Move SDK library code from internal/cmd/sdk/ → pkg/sdk/ (new package)
- Create pkg/rag/helpers.go with convenience functions from internal/cmd/rag/
- Fix pkg/mcp/tools_rag.go to use pkg/rag instead of internal/cmd/rag
- Fix pkg/build/buildcmd/cmd_sdk.go and pkg/release/sdk.go to use pkg/sdk
- Remove all non-library content: main.go, internal/, cmd/, docker/,
  scripts/, tasks/, tools/, .core/, .forgejo/, .woodpecker/, Taskfile.yml
- Run go mod tidy to trim unused dependencies

core/go is now a pure Go package suite (library only).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 14:22:04 +00:00

90 lines
2.4 KiB
Go

//go:build darwin && arm64 && mlx
// Package sample provides composable token sampling strategies.
package sample
import (
"math"
"forge.lthn.ai/core/go/pkg/mlx"
)
// Sampler transforms logits into a sampled token index.
//
// Within this package the interface plays two roles: intermediate stages
// (Temperature, TopKSampler, TopP, MinPSampler) return adjusted logits,
// while terminal samplers (greedy, chain) return the chosen token.
type Sampler interface {
// Sample consumes the given logits and returns either adjusted logits
// or the selected token, depending on the implementation.
Sample(logits *mlx.Array) *mlx.Array
}
// New creates a composable sampler chain from the given parameters.
//
// A temperature of zero or below selects greedy (argmax) decoding and all
// other parameters are ignored. A negative temperature would otherwise
// multiply the logits by a negative factor and invert the distribution, so
// it is treated the same as zero.
//
// For a positive temperature the stages are applied in this order:
// TopP -> MinP -> TopK -> Temperature -> categorical sample.
// Each filter is only added when its parameter enables it:
//   - topP is used when 0 < topP < 1
//   - minP is used when minP > 0
//   - topK is used when topK > 0
func New(temp, topP, minP float32, topK int) Sampler {
	// Non-positive temperatures cannot be used as a scaling divisor in a
	// meaningful way; fall back to deterministic argmax decoding.
	if temp <= 0 {
		return greedy{}
	}

	// At most three optional filters plus the mandatory temperature stage.
	samplers := make([]Sampler, 0, 4)
	if topP > 0 && topP < 1 {
		samplers = append(samplers, TopP(topP))
	}
	if minP > 0 {
		samplers = append(samplers, MinPSampler(minP))
	}
	if topK > 0 {
		samplers = append(samplers, TopKSampler(topK))
	}
	samplers = append(samplers, Temperature(temp))
	return chain(samplers)
}
// chain is an ordered pipeline of samplers. Every stage is run in turn and
// the final result is handed to a categorical draw.
type chain []Sampler

// Sample feeds logits through each stage of the pipeline, passing the output
// of one stage as the input of the next, and then draws the result with
// mlx.RandomCategorical.
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
	current := logits
	for i := range c {
		current = c[i].Sample(current)
	}
	// The processed logits are sampled categorically to pick the token.
	return mlx.RandomCategorical(current)
}
// greedy is the deterministic sampler: it always picks the argmax token.
type greedy struct{}

// Sample returns the result of mlx.Argmax over axis -1 of logits.
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
	best := mlx.Argmax(logits, -1, false)
	return best
}
// Temperature is a sampler stage that multiplies the logits by 1/temp.
type Temperature float32

// Sample returns logits scaled by the reciprocal of the temperature.
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
	scale := 1.0 / float32(t)
	return mlx.MulScalar(logits, scale)
}
// TopKSampler masks all but the top-k logits.
type TopKSampler int

// Sample sets every logit outside the k largest (along the last axis) to
// negative infinity and returns the result.
//
// When k is not positive, or k is at least the size of the last axis, there
// is nothing to mask and the logits are returned unchanged. Without this
// guard the partition index k-1 and the slice start k would fall outside the
// valid range of the last axis.
func (k TopKSampler) Sample(logits *mlx.Array) *mlx.Array {
	vocab := int32(logits.Dim(-1))
	keep := int32(k)
	if keep <= 0 || keep >= vocab {
		return logits
	}

	// Partition the negated logits so that the indices of the k largest
	// original values occupy the first k positions along the last axis.
	order := mlx.Argpartition(mlx.Negative(logits), int(k)-1, -1)
	// Everything from position k onward is outside the top-k.
	drop := mlx.SliceAxis(order, -1, keep, vocab)
	negInf := mlx.FromValue(float32(math.Inf(-1)))
	return mlx.PutAlongAxis(logits, drop, negInf, -1)
}
// TopP is the nucleus-sampling stage (cumulative probability threshold).
// It is currently a placeholder: the threshold value is not used.
type TopP float32

// Sample returns logits unchanged.
func (TopP) Sample(logits *mlx.Array) *mlx.Array {
	// TODO: full nucleus sampling requires cumsum, which mlx-c doesn't
	// expose directly. Until then this stage is a pass-through; TopK plus
	// Temperature covers most use cases.
	return logits
}
// MinPSampler is the min-p stage, intended to mask tokens whose probability
// is below min_p * max_prob. It is currently a placeholder: the threshold
// value is not used.
type MinPSampler float32

// Sample returns logits unchanged.
func (MinPSampler) Sample(logits *mlx.Array) *mlx.Array {
	// A full implementation needs to find the maximum probability and mask
	// everything below the threshold; for now this stage is a pass-through.
	return logits
}