go/pkg/mlx/nn.go
Claude bc28aad526 feat: add native MLX backend for Apple Silicon inference (pkg/mlx)
CGo wrapper for mlx-c providing zero-Python Metal GPU inference.
Includes Gemma 3 model architecture, BPE tokenizer, KV cache,
composable sampling, and OpenAI-compatible serve command.

Build-tagged (darwin && arm64 && mlx) with stubs for cross-platform.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 05:53:52 +00:00

59 lines
1.6 KiB
Go

//go:build darwin && arm64 && mlx
package mlx
// Linear is a dense (fully-connected) layer. Its Forward method evaluates
// x @ Transpose(Weight) and then adds Bias when one is present.
//
// NOTE(review): the `weight:"..."` struct tags are presumably read by the
// package's weight loader to bind tensors by name — confirm against the loader.
type Linear struct {
	// Weight is the projection matrix; Forward transposes it before the matmul.
	Weight *Array `weight:"weight"`
	// Bias is optional. Forward skips it when it is nil or not Valid().
	Bias *Array `weight:"bias"`
}
// NewLinear builds a Linear layer around the given weight and bias arrays.
// The bias is optional: pass nil to get a layer whose Forward performs only
// the matrix multiplication.
func NewLinear(weight, bias *Array) *Linear {
	layer := new(Linear)
	layer.Weight = weight
	layer.Bias = bias
	return layer
}
// Forward applies the layer to x: it multiplies x by the transposed weight
// and, if a usable bias is attached, adds it to the product.
func (l *Linear) Forward(x *Array) *Array {
	weightT := Transpose(l.Weight)
	projected := Matmul(x, weightT)

	// No bias (nil or an invalid handle): the projection is the final result.
	if l.Bias == nil || !l.Bias.Valid() {
		return projected
	}
	return Add(projected, l.Bias)
}
// Embedding is a token-embedding lookup table. Its Forward method gathers
// entries of Weight along axis 0 for the requested indices.
type Embedding struct {
	// Weight holds the table that Forward indexes along its first axis.
	Weight *Array `weight:"weight"`
}
// Forward returns the embeddings selected by indices, gathered from the
// table along its first axis via Take.
func (e *Embedding) Forward(indices *Array) *Array {
	// Entries are selected along the leading axis of the table.
	const lookupAxis = 0

	table := e.Weight
	return Take(table, indices, lookupAxis)
}
// RMSNormModule is an RMS normalization layer. It stores only the weight
// array and delegates the computation to the package-level RMSNorm function.
type RMSNormModule struct {
	// Weight is passed to RMSNorm as its weight argument.
	Weight *Array `weight:"weight"`
}
// Forward normalizes x with RMSNorm using the module's weight and the
// caller-supplied epsilon eps.
func (r *RMSNormModule) Forward(x *Array, eps float32) *Array {
	scale := r.Weight
	normalized := RMSNorm(x, scale, eps)
	return normalized
}
// RepeatKV duplicates key/value heads so they can be paired with a larger
// number of query heads in grouped-query attention.
//
// Input shape:  [B, num_kv_heads, L, D]
// Output shape: [B, num_kv_heads * factor, L, D]
//
// A factor of 1 or less is a no-op and returns x itself. The first four
// entries of x.Shape() are read unconditionally, so x must have rank >= 4.
func RepeatKV(x *Array, factor int32) *Array {
	if factor <= 1 {
		return x
	}

	dims := x.Shape()
	batch, kvHeads, seqLen, headDim := dims[0], dims[1], dims[2], dims[3]

	// Insert a singleton axis right after the head axis and broadcast it to
	// `factor`: [B, H, L, D] -> [B, H, 1, L, D] -> [B, H, factor, L, D].
	tiled := BroadcastTo(
		ExpandDims(x, 2),
		[]int32{batch, kvHeads, factor, seqLen, headDim},
	)

	// Merge the new axis into the head axis: [B, H*factor, L, D].
	return Reshape(tiled, batch, kvHeads*factor, seqLen, headDim)
}