go-ai/mlx/random.go
Claude e84d6ad3c9
feat: extract AI/ML packages from core/go
LEM scoring pipeline, native MLX Metal bindings, Claude SDK wrapper,
RAG with Qdrant/Ollama, unified AI facade, and MCP protocol server.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 15:25:55 +00:00

46 lines
1.1 KiB
Go

//go:build darwin && arm64 && mlx
package mlx
/*
#include "mlx/c/mlx.h"
*/
import "C"
// RandomCategorical draws category indices from the categorical distribution
// described by logprobs, treating the last axis (-1) as the category axis.
//
// No explicit PRNG key is supplied, so MLX falls back to its default random
// state.
func RandomCategorical(logprobs *Array) *Array {
	const lastAxis = C.int(-1)

	result := New("RANDOM_CATEGORICAL", logprobs)

	// An empty mlx_array handle stands in for "no key" (use the default RNG).
	// The handle itself still has to be released once the call returns.
	noKey := C.mlx_array_new()
	defer C.mlx_array_free(noKey)

	stream := DefaultStream()
	C.mlx_random_categorical(&result.ctx, logprobs.ctx, lastAxis, noKey, stream.ctx)

	return result
}
// RandomUniform generates uniform random values in [low, high) with the given
// shape and dtype.
//
// An empty or nil shape is forwarded to MLX as a zero-length shape (a scalar
// request) rather than indexing into an empty slice. No explicit PRNG key is
// supplied, so MLX falls back to its default random state.
func RandomUniform(low, high float32, shape []int32, dtype DType) *Array {
	out := New("RANDOM_UNIFORM")

	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}

	// &cShape[0] panics on a zero-length slice, so pass NULL together with a
	// length of 0 instead; mlx-c reads exactly shape_num elements from it.
	var shapePtr *C.int
	if len(cShape) > 0 {
		shapePtr = &cShape[0]
	}

	// NOTE(review): unlike RandomCategorical, lo/hi are not passed to New as
	// inputs — confirm that is intended.
	lo := FromValue(low)
	hi := FromValue(high)

	// An empty mlx_array handle means "no key": use the default RNG.
	key := C.mlx_array_new()
	defer C.mlx_array_free(key)

	C.mlx_random_uniform(
		&out.ctx,
		lo.ctx, hi.ctx,
		shapePtr, C.size_t(len(cShape)),
		C.mlx_dtype(dtype),
		key,
		DefaultStream().ctx,
	)
	return out
}