diff --git a/internal/metal/batch.go b/internal/metal/batch.go index 3c5c04e..09c32ae 100644 --- a/internal/metal/batch.go +++ b/internal/metal/batch.go @@ -3,11 +3,11 @@ package metal import ( + "cmp" "context" "fmt" "math" "slices" - "sort" "time" ) @@ -49,8 +49,8 @@ func (m *Model) Classify(ctx context.Context, prompts []string, cfg GenerateConf for i := range indices { indices[i] = i } - sort.Slice(indices, func(a, b int) bool { - return lengths[indices[a]] > lengths[indices[b]] + slices.SortFunc(indices, func(a, b int) int { + return cmp.Compare(lengths[b], lengths[a]) }) maxLen := lengths[indices[0]] @@ -160,8 +160,8 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat for i := range indices { indices[i] = i } - sort.Slice(indices, func(a, b int) bool { - return lengths[indices[a]] > lengths[indices[b]] + slices.SortFunc(indices, func(a, b int) int { + return cmp.Compare(lengths[b], lengths[a]) }) N := int32(len(prompts)) diff --git a/internal/metal/lora.go b/internal/metal/lora.go index 8d080b2..aab49c4 100644 --- a/internal/metal/lora.go +++ b/internal/metal/lora.go @@ -15,7 +15,7 @@ import ( "math" "os" "path/filepath" - "sort" + "slices" "strings" "unsafe" ) @@ -156,7 +156,7 @@ func (a *LoRAAdapter) SortedNames() []string { for name := range a.Layers { names = append(names, name) } - sort.Strings(names) + slices.Sort(names) return names }