Merge pull request 'chore: Go 1.26 modernization' (#2) from chore/go-1.26-modernization into main
This commit is contained in:
commit
c1baeb9254
6 changed files with 17 additions and 20 deletions
|
|
@ -3,11 +3,11 @@
|
|||
package metal
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
|
@ -49,8 +49,8 @@ func (m *Model) Classify(ctx context.Context, prompts []string, cfg GenerateConf
|
|||
for i := range indices {
|
||||
indices[i] = i
|
||||
}
|
||||
sort.Slice(indices, func(a, b int) bool {
|
||||
return lengths[indices[a]] > lengths[indices[b]]
|
||||
slices.SortFunc(indices, func(a, b int) int {
|
||||
return cmp.Compare(lengths[b], lengths[a])
|
||||
})
|
||||
|
||||
maxLen := lengths[indices[0]]
|
||||
|
|
@ -160,8 +160,8 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat
|
|||
for i := range indices {
|
||||
indices[i] = i
|
||||
}
|
||||
sort.Slice(indices, func(a, b int) bool {
|
||||
return lengths[indices[a]] > lengths[indices[b]]
|
||||
slices.SortFunc(indices, func(a, b int) int {
|
||||
return cmp.Compare(lengths[b], lengths[a])
|
||||
})
|
||||
|
||||
N := int32(len(prompts))
|
||||
|
|
@ -208,7 +208,7 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat
|
|||
maxTokens = 256
|
||||
}
|
||||
|
||||
for step := 0; step < maxTokens; step++ {
|
||||
for step := range maxTokens {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Return partial results on cancellation.
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ package metal
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"iter"
|
||||
"slices"
|
||||
|
|
@ -185,7 +186,7 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
|
|||
Free(logits)
|
||||
}()
|
||||
|
||||
for i := 0; i < cfg.MaxTokens; i++ {
|
||||
for i := range cfg.MaxTokens {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
m.lastErr = ctx.Err()
|
||||
|
|
@ -256,7 +257,7 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
|
|||
func (m *Model) InspectAttention(ctx context.Context, prompt string) (*AttentionResult, error) {
|
||||
tokens := m.tokenizer.Encode(prompt)
|
||||
if len(tokens) == 0 {
|
||||
return nil, fmt.Errorf("empty prompt after tokenisation")
|
||||
return nil, errors.New("empty prompt after tokenisation")
|
||||
}
|
||||
|
||||
caches := m.newCaches()
|
||||
|
|
@ -304,7 +305,7 @@ func (m *Model) InspectAttention(ctx context.Context, prompt string) (*Attention
|
|||
|
||||
keys[i] = make([][]float32, numHeads)
|
||||
stride := validLen * headDim
|
||||
for h := 0; h < numHeads; h++ {
|
||||
for h := range numHeads {
|
||||
start := h * stride
|
||||
end := start + stride
|
||||
if end > len(flat) {
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ func goGradFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload
|
|||
|
||||
nInputs := int(C.mlx_vector_array_size(inputs))
|
||||
goInputs := make([]*Array, nInputs)
|
||||
for i := 0; i < nInputs; i++ {
|
||||
for i := range nInputs {
|
||||
a := newArray("GRAD_INPUT")
|
||||
C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
|
||||
goInputs[i] = a
|
||||
|
|
@ -324,7 +324,7 @@ func MSELoss(predictions, targets *Array) *Array {
|
|||
func vectorToArrays(vec C.mlx_vector_array) []*Array {
|
||||
n := int(C.mlx_vector_array_size(vec))
|
||||
out := make([]*Array, n)
|
||||
for i := 0; i < n; i++ {
|
||||
for i := range n {
|
||||
a := newArray("VEC_OUT")
|
||||
C.mlx_vector_array_get(&a.ctx, vec, C.size_t(i))
|
||||
out[i] = a
|
||||
|
|
|
|||
|
|
@ -12,10 +12,11 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"slices"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
|
@ -152,12 +153,7 @@ func (a *LoRAAdapter) TotalParams() int {
|
|||
|
||||
// SortedNames returns layer names in deterministic sorted order.
|
||||
func (a *LoRAAdapter) SortedNames() []string {
|
||||
names := make([]string, 0, len(a.Layers))
|
||||
for name := range a.Layers {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
return slices.Sorted(maps.Keys(a.Layers))
|
||||
}
|
||||
|
||||
// AllTrainableParams returns all trainable arrays (A and B from every layer),
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ func SliceAxis(a *Array, axis int, start, end int32) *Array {
|
|||
ndim := a.NumDims()
|
||||
starts := make([]int32, ndim)
|
||||
ends := make([]int32, ndim)
|
||||
for i := 0; i < ndim; i++ {
|
||||
for i := range ndim {
|
||||
starts[i] = 0
|
||||
ends[i] = int32(a.Dim(i))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -191,7 +191,7 @@ func (t *Tokenizer) bpeMerge(symbols []string) []string {
|
|||
// Find the pair with the lowest merge rank.
|
||||
bestRank := -1
|
||||
bestIdx := -1
|
||||
for i := 0; i < len(symbols)-1; i++ {
|
||||
for i := range len(symbols) - 1 {
|
||||
key := symbols[i] + " " + symbols[i+1]
|
||||
if rank, ok := t.mergeRanks[key]; ok {
|
||||
if bestRank < 0 || rank < bestRank {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue