diff --git a/internal/metal/batch.go b/internal/metal/batch.go index 3c5c04e..103003f 100644 --- a/internal/metal/batch.go +++ b/internal/metal/batch.go @@ -3,11 +3,11 @@ package metal import ( + "cmp" "context" "fmt" "math" "slices" - "sort" "time" ) @@ -49,8 +49,8 @@ func (m *Model) Classify(ctx context.Context, prompts []string, cfg GenerateConf for i := range indices { indices[i] = i } - sort.Slice(indices, func(a, b int) bool { - return lengths[indices[a]] > lengths[indices[b]] + slices.SortFunc(indices, func(a, b int) int { + return cmp.Compare(lengths[b], lengths[a]) }) maxLen := lengths[indices[0]] @@ -160,8 +160,8 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat for i := range indices { indices[i] = i } - sort.Slice(indices, func(a, b int) bool { - return lengths[indices[a]] > lengths[indices[b]] + slices.SortFunc(indices, func(a, b int) int { + return cmp.Compare(lengths[b], lengths[a]) }) N := int32(len(prompts)) @@ -208,7 +208,7 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat maxTokens = 256 } - for step := 0; step < maxTokens; step++ { + for step := range maxTokens { select { case <-ctx.Done(): // Return partial results on cancellation. diff --git a/internal/metal/generate.go b/internal/metal/generate.go index 342d095..066cc64 100644 --- a/internal/metal/generate.go +++ b/internal/metal/generate.go @@ -4,6 +4,7 @@ package metal import ( "context" + "errors" "fmt" "iter" "slices" @@ -185,7 +186,7 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) Free(logits) }() - for i := 0; i < cfg.MaxTokens; i++ { + for i := range cfg.MaxTokens { select { case <-ctx.Done(): m.lastErr = ctx.Err() @@ -256,7 +257,7 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) func (m *Model) InspectAttention(ctx context.Context, prompt string) (*AttentionResult, error) { tokens := m.tokenizer.Encode(prompt) if len(tokens) == 0 { - return nil, fmt.Errorf("empty prompt after tokenisation") + return nil, errors.New("empty prompt after tokenisation") } caches := m.newCaches() @@ -304,7 +305,7 @@ func (m *Model) InspectAttention(ctx context.Context, prompt string) (*Attention keys[i] = make([][]float32, numHeads) stride := validLen * headDim - for h := 0; h < numHeads; h++ { + for h := range numHeads { start := h * stride end := start + stride if end > len(flat) { diff --git a/internal/metal/grad.go b/internal/metal/grad.go index 485a432..b08ea9f 100644 --- a/internal/metal/grad.go +++ b/internal/metal/grad.go @@ -39,7 +39,7 @@ func goGradFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload nInputs := int(C.mlx_vector_array_size(inputs)) goInputs := make([]*Array, nInputs) - for i := 0; i < nInputs; i++ { + for i := range nInputs { a := newArray("GRAD_INPUT") C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i)) goInputs[i] = a @@ -324,7 +324,7 @@ func MSELoss(predictions, targets *Array) *Array { func vectorToArrays(vec C.mlx_vector_array) []*Array { n := int(C.mlx_vector_array_size(vec)) out := make([]*Array, n) - for i := 0; i < n; i++ { + for i := range n { a := newArray("VEC_OUT") C.mlx_vector_array_get(&a.ctx, vec, C.size_t(i)) out[i] = a diff --git a/internal/metal/lora.go b/internal/metal/lora.go index 8d080b2..2bcb661 100644 --- a/internal/metal/lora.go +++ b/internal/metal/lora.go @@ -12,10 +12,11 @@ import ( "encoding/json" "fmt" "log/slog" + "maps" "math" "os" "path/filepath" - "sort" + "slices" "strings" "unsafe" ) @@ -152,12 +153,7 @@ func (a *LoRAAdapter) TotalParams() int { // SortedNames returns layer names in deterministic sorted order. func (a *LoRAAdapter) SortedNames() []string { - names := make([]string, 0, len(a.Layers)) - for name := range a.Layers { - names = append(names, name) - } - sort.Strings(names) - return names + return slices.Sorted(maps.Keys(a.Layers)) } // AllTrainableParams returns all trainable arrays (A and B from every layer), diff --git a/internal/metal/slice.go b/internal/metal/slice.go index 4e8b4af..8a6928d 100644 --- a/internal/metal/slice.go +++ b/internal/metal/slice.go @@ -31,7 +31,7 @@ func SliceAxis(a *Array, axis int, start, end int32) *Array { ndim := a.NumDims() starts := make([]int32, ndim) ends := make([]int32, ndim) - for i := 0; i < ndim; i++ { + for i := range ndim { starts[i] = 0 ends[i] = int32(a.Dim(i)) } diff --git a/internal/metal/tokenizer.go b/internal/metal/tokenizer.go index 6f2c1f4..ea798dc 100644 --- a/internal/metal/tokenizer.go +++ b/internal/metal/tokenizer.go @@ -191,7 +191,7 @@ func (t *Tokenizer) bpeMerge(symbols []string) []string { // Find the pair with the lowest merge rank. bestRank := -1 bestIdx := -1 - for i := 0; i < len(symbols)-1; i++ { + for i := range len(symbols) - 1 { key := symbols[i] + " " + symbols[i+1] if rank, ok := t.mergeRanks[key]; ok { if bestRank < 0 || rank < bestRank {