- Add ForwardMasked to InternalModel, Gemma3 and Qwen3 architectures - Thread attention mask through decoder layers and SDPA calls - Use ScaledDotProductAttentionWithMask when explicit mask provided - Create batch.go with padded batching, mask construction, Classify (prefill-only) and BatchGenerate (autoregressive) implementations - Wire Classify/BatchGenerate through metalAdapter to go-inference - Tests: mask unit tests (shape, values, multi-batch), Classify with 4 prompts (152 prompts/s), WithLogits, BatchGenerate with 2 prompts Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
78 lines
2.2 KiB
Go
78 lines
2.2 KiB
Go
//go:build darwin && arm64
|
|
|
|
package metal
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
)
|
|
|
|
// InternalModel is the common interface for all transformer model architectures.
|
|
type InternalModel interface {
|
|
// Forward runs the model forward pass on token IDs with KV caches.
|
|
Forward(tokens *Array, caches []Cache) *Array
|
|
|
|
// ForwardMasked runs the forward pass with an explicit attention mask.
|
|
// mask shape: [B, 1, L, L] — additive mask (0 = attend, -inf = ignore).
|
|
// Used for batched inference with padded sequences.
|
|
ForwardMasked(tokens *Array, mask *Array, caches []Cache) *Array
|
|
|
|
// NewCache creates per-layer KV caches for generation.
|
|
NewCache() []Cache
|
|
|
|
// NumLayers returns the number of transformer layers.
|
|
NumLayers() int
|
|
|
|
// Tokenizer returns the model's tokenizer.
|
|
Tokenizer() *Tokenizer
|
|
|
|
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
|
|
ModelType() string
|
|
|
|
// ApplyLoRA wraps target projection layers with LoRA adapters for training.
|
|
// Returns the adapter which holds references to all LoRA layers.
|
|
ApplyLoRA(cfg LoRAConfig) *LoRAAdapter
|
|
}
|
|
|
|
// QuantizationConfig carries the quantization parameters read from a model's
// config.json.
type QuantizationConfig struct {
	// GroupSize is decoded from the "group_size" JSON key.
	GroupSize int `json:"group_size"`
	// Bits is decoded from the "bits" JSON key.
	Bits int `json:"bits"`
}
|
|
|
|
// resolveWeight looks up a weight with optional "language_model." prefix.
|
|
func resolveWeight(weights map[string]*Array, name string) *Array {
|
|
if w, ok := weights[name]; ok {
|
|
return w
|
|
}
|
|
if w, ok := weights["language_model."+name]; ok {
|
|
return w
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// loadModel auto-detects the model architecture from config.json and loads it.
|
|
func loadModel(modelPath string) (InternalModel, error) {
|
|
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("model: load config: %w", err)
|
|
}
|
|
|
|
var probe struct {
|
|
ModelType string `json:"model_type"`
|
|
}
|
|
if err := json.Unmarshal(data, &probe); err != nil {
|
|
return nil, fmt.Errorf("model: parse model_type: %w", err)
|
|
}
|
|
|
|
switch probe.ModelType {
|
|
case "qwen3", "qwen2", "llama":
|
|
return LoadQwen3(modelPath)
|
|
case "gemma3", "gemma3_text", "gemma2":
|
|
return LoadGemma3(modelPath)
|
|
default:
|
|
return nil, fmt.Errorf("model: unsupported architecture %q", probe.ModelType)
|
|
}
|
|
}
|