diff --git a/internal/gguf/gguf.go b/internal/gguf/gguf.go new file mode 100644 index 0000000..0c908c8 --- /dev/null +++ b/internal/gguf/gguf.go @@ -0,0 +1,332 @@ +// Package gguf provides a GGUF binary metadata parser for reading model headers. +// +// GGUF (GGML Universal File) is the file format used by llama.cpp and other +// GGML-based inference engines. This package reads the metadata key-value pairs +// from the file header without loading tensor data, enabling fast model discovery. +// +// Supports GGUF v2 (uint32 counts) and v3 (uint64 counts). +package gguf + +import ( + "bufio" + "encoding/binary" + "fmt" + "io" + "math" + "os" + "strings" +) + +// ggufMagic is the GGUF file magic number: "GGUF" in little-endian. +const ggufMagic = 0x46554747 + +// GGUF value type codes. +const ( + typeUint8 uint32 = 0 + typeInt8 uint32 = 1 + typeUint16 uint32 = 2 + typeInt16 uint32 = 3 + typeUint32 uint32 = 4 + typeInt32 uint32 = 5 + typeFloat32 uint32 = 6 + typeBool uint32 = 7 + typeString uint32 = 8 + typeArray uint32 = 9 + typeUint64 uint32 = 10 + typeInt64 uint32 = 11 + typeFloat64 uint32 = 12 +) + +// Metadata holds the interesting fields extracted from a GGUF file header. +type Metadata struct { + Architecture string // "gemma3", "llama", "qwen2" + Name string // human-readable model name + SizeLabel string // "1B", "8B", etc. + ContextLength uint32 // native context window + BlockCount uint32 // transformer layers + FileType uint32 // GGML quantisation file type + FileSize int64 // file size on disk in bytes +} + +// fileTypeNames maps GGML quantisation file type numbers to human-readable names. +var fileTypeNames = map[uint32]string{ + 0: "F32", + 1: "F16", + 2: "Q4_0", + 3: "Q4_1", + 7: "Q8_0", + 8: "Q5_0", + 9: "Q5_1", + 10: "Q2_K", + 11: "Q3_K_S", + 12: "Q3_K_M", + 13: "Q3_K_L", + 14: "Q4_K_S", + 15: "Q4_K_M", + 16: "Q5_K_S", + 17: "Q5_K_M", + 18: "Q6_K", +} + +// FileTypeName returns a human-readable name for a GGML quantisation file type. +// Unknown types return "type_N" where N is the numeric value. +func FileTypeName(ft uint32) string { + if name, ok := fileTypeNames[ft]; ok { + return name + } + return fmt.Sprintf("type_%d", ft) +} + +// ReadMetadata reads the GGUF header from the file at path and returns the +// extracted metadata. Only metadata KV pairs are read; tensor data is not loaded. +func ReadMetadata(path string) (Metadata, error) { + f, err := os.Open(path) + if err != nil { + return Metadata{}, err + } + defer f.Close() + + info, err := f.Stat() + if err != nil { + return Metadata{}, err + } + + r := bufio.NewReader(f) + + // Read and validate magic number. + var magic uint32 + if err := binary.Read(r, binary.LittleEndian, &magic); err != nil { + return Metadata{}, fmt.Errorf("reading magic: %w", err) + } + if magic != ggufMagic { + return Metadata{}, fmt.Errorf("invalid magic: 0x%08X (expected 0x%08X)", magic, ggufMagic) + } + + // Read version. + var version uint32 + if err := binary.Read(r, binary.LittleEndian, &version); err != nil { + return Metadata{}, fmt.Errorf("reading version: %w", err) + } + if version < 2 || version > 3 { + return Metadata{}, fmt.Errorf("unsupported GGUF version: %d", version) + } + + // Read tensor count and KV count. v3 uses uint64, v2 uses uint32. + var tensorCount, kvCount uint64 + if version == 3 { + if err := binary.Read(r, binary.LittleEndian, &tensorCount); err != nil { + return Metadata{}, fmt.Errorf("reading tensor count: %w", err) + } + if err := binary.Read(r, binary.LittleEndian, &kvCount); err != nil { + return Metadata{}, fmt.Errorf("reading kv count: %w", err) + } + } else { + var tc, kc uint32 + if err := binary.Read(r, binary.LittleEndian, &tc); err != nil { + return Metadata{}, fmt.Errorf("reading tensor count: %w", err) + } + if err := binary.Read(r, binary.LittleEndian, &kc); err != nil { + return Metadata{}, fmt.Errorf("reading kv count: %w", err) + } + tensorCount = uint64(tc) + kvCount = uint64(kc) + } + _ = tensorCount // we only read metadata KVs + + // Read all KV pairs. We store interesting keys and skip the rest. + // Architecture-specific keys (e.g. llama.context_length) may appear before + // the general.architecture key, so we collect all candidates and resolve after. + var meta Metadata + meta.FileSize = info.Size() + + // candidateContextLength and candidateBlockCount store values keyed by + // their full key name (e.g. "llama.context_length") so we can match them + // against the architecture once it is known. + candidateContextLength := make(map[string]uint32) + candidateBlockCount := make(map[string]uint32) + + for i := uint64(0); i < kvCount; i++ { + key, err := readString(r) + if err != nil { + return Metadata{}, fmt.Errorf("reading key %d: %w", i, err) + } + + var valType uint32 + if err := binary.Read(r, binary.LittleEndian, &valType); err != nil { + return Metadata{}, fmt.Errorf("reading value type for key %q: %w", key, err) + } + + // Check whether this is an interesting key before reading the value. + switch { + case key == "general.architecture": + v, err := readTypedValue(r, valType) + if err != nil { + return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) + } + if s, ok := v.(string); ok { + meta.Architecture = s + } + + case key == "general.name": + v, err := readTypedValue(r, valType) + if err != nil { + return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) + } + if s, ok := v.(string); ok { + meta.Name = s + } + + case key == "general.file_type": + v, err := readTypedValue(r, valType) + if err != nil { + return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) + } + if u, ok := v.(uint32); ok { + meta.FileType = u + } + + case key == "general.size_label": + v, err := readTypedValue(r, valType) + if err != nil { + return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) + } + if s, ok := v.(string); ok { + meta.SizeLabel = s + } + + case strings.HasSuffix(key, ".context_length"): + v, err := readTypedValue(r, valType) + if err != nil { + return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) + } + if u, ok := v.(uint32); ok { + candidateContextLength[key] = u + } + + case strings.HasSuffix(key, ".block_count"): + v, err := readTypedValue(r, valType) + if err != nil { + return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) + } + if u, ok := v.(uint32); ok { + candidateBlockCount[key] = u + } + + default: + // Skip uninteresting value. + if err := skipValue(r, valType); err != nil { + return Metadata{}, fmt.Errorf("skipping value for key %q: %w", key, err) + } + } + } + + // Resolve architecture-specific keys. + if meta.Architecture != "" { + prefix := meta.Architecture + "." + if v, ok := candidateContextLength[prefix+"context_length"]; ok { + meta.ContextLength = v + } + if v, ok := candidateBlockCount[prefix+"block_count"]; ok { + meta.BlockCount = v + } + } + + return meta, nil +} + +// maxStringLength is a sanity limit for GGUF string values. No metadata string +// should ever approach 1 MiB; this prevents memory exhaustion from malformed files. +const maxStringLength = 1 << 20 + +// readString reads a GGUF string: uint64 length followed by that many bytes. +func readString(r io.Reader) (string, error) { + var length uint64 + if err := binary.Read(r, binary.LittleEndian, &length); err != nil { + return "", err + } + if length > maxStringLength { + return "", fmt.Errorf("string length %d exceeds maximum %d", length, maxStringLength) + } + buf := make([]byte, length) + if _, err := io.ReadFull(r, buf); err != nil { + return "", err + } + return string(buf), nil +} + +// readTypedValue reads a value of the given GGUF type and returns it as a Go +// value. String, uint32, and uint64 types return typed values (uint64 is +// downcast to uint32 when it fits). All others are read and discarded. +func readTypedValue(r io.Reader, valType uint32) (any, error) { + switch valType { + case typeString: + return readString(r) + case typeUint32: + var v uint32 + err := binary.Read(r, binary.LittleEndian, &v) + return v, err + case typeUint64: + var v uint64 + if err := binary.Read(r, binary.LittleEndian, &v); err != nil { + return nil, err + } + if v <= math.MaxUint32 { + return uint32(v), nil + } + return v, nil + default: + // Read and discard the value, returning nil. + err := skipValue(r, valType) + return nil, err + } +} + +// skipValue reads and discards a GGUF value of the given type from r. +func skipValue(r io.Reader, valType uint32) error { + switch valType { + case typeUint8, typeInt8, typeBool: + _, err := readN(r, 1) + return err + case typeUint16, typeInt16: + _, err := readN(r, 2) + return err + case typeUint32, typeInt32, typeFloat32: + _, err := readN(r, 4) + return err + case typeUint64, typeInt64, typeFloat64: + _, err := readN(r, 8) + return err + case typeString: + var length uint64 + if err := binary.Read(r, binary.LittleEndian, &length); err != nil { + return err + } + if length > maxStringLength { + return fmt.Errorf("string length %d exceeds maximum %d", length, maxStringLength) + } + _, err := readN(r, int64(length)) + return err + case typeArray: + var elemType uint32 + if err := binary.Read(r, binary.LittleEndian, &elemType); err != nil { + return err + } + var count uint64 + if err := binary.Read(r, binary.LittleEndian, &count); err != nil { + return err + } + for i := uint64(0); i < count; i++ { + if err := skipValue(r, elemType); err != nil { + return err + } + } + return nil + default: + return fmt.Errorf("unknown GGUF value type: %d", valType) + } +} + +// readN reads and discards exactly n bytes from r. +func readN(r io.Reader, n int64) (int64, error) { + return io.CopyN(io.Discard, r, n) +} diff --git a/internal/gguf/gguf_test.go b/internal/gguf/gguf_test.go new file mode 100644 index 0000000..fe298f3 --- /dev/null +++ b/internal/gguf/gguf_test.go @@ -0,0 +1,155 @@ +package gguf + +import ( + "encoding/binary" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// writeTestGGUFOrdered creates a synthetic GGUF v3 file with KV pairs in the +// exact order specified. Each element is a [2]any{key string, value any}. +func writeTestGGUFOrdered(t *testing.T, kvs [][2]any) string { + t.Helper() + + dir := t.TempDir() + path := filepath.Join(dir, "test.gguf") + + f, err := os.Create(path) + require.NoError(t, err) + defer f.Close() + + // Magic + require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(0x46554747))) + // Version 3 + require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(3))) + // Tensor count (uint64): 0 + require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(0))) + // KV count (uint64) + require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(len(kvs)))) + + for _, kv := range kvs { + key := kv[0].(string) + writeKV(t, f, key, kv[1]) + } + + return path +} + +func writeKV(t *testing.T, f *os.File, key string, val any) { + t.Helper() + + // Key: uint64 length + bytes + require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(len(key)))) + _, err := f.Write([]byte(key)) + require.NoError(t, err) + + switch v := val.(type) { + case string: + // Type: 8 (string) + require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(8))) + // String value: uint64 length + bytes + require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(len(v)))) + _, err := f.Write([]byte(v)) + require.NoError(t, err) + case uint32: + // Type: 4 (uint32) + require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(4))) + require.NoError(t, binary.Write(f, binary.LittleEndian, v)) + default: + t.Fatalf("writeKV: unsupported value type %T", val) + } +} + +func TestReadMetadata_Gemma3(t *testing.T) { + path := writeTestGGUFOrdered(t, [][2]any{ + {"general.architecture", "gemma3"}, + {"general.name", "Test Gemma3 1B"}, + {"general.file_type", uint32(17)}, + {"general.size_label", "1B"}, + {"gemma3.context_length", uint32(32768)}, + {"gemma3.block_count", uint32(26)}, + }) + + m, err := ReadMetadata(path) + require.NoError(t, err) + + assert.Equal(t, "gemma3", m.Architecture) + assert.Equal(t, "Test Gemma3 1B", m.Name) + assert.Equal(t, uint32(17), m.FileType) + assert.Equal(t, "1B", m.SizeLabel) + assert.Equal(t, uint32(32768), m.ContextLength) + assert.Equal(t, uint32(26), m.BlockCount) + assert.Greater(t, m.FileSize, int64(0)) +} + +func TestReadMetadata_Llama(t *testing.T) { + path := writeTestGGUFOrdered(t, [][2]any{ + {"general.architecture", "llama"}, + {"general.name", "Test Llama 8B"}, + {"general.file_type", uint32(15)}, + {"general.size_label", "8B"}, + {"llama.context_length", uint32(131072)}, + {"llama.block_count", uint32(32)}, + }) + + m, err := ReadMetadata(path) + require.NoError(t, err) + + assert.Equal(t, "llama", m.Architecture) + assert.Equal(t, "Test Llama 8B", m.Name) + assert.Equal(t, uint32(15), m.FileType) + assert.Equal(t, "8B", m.SizeLabel) + assert.Equal(t, uint32(131072), m.ContextLength) + assert.Equal(t, uint32(32), m.BlockCount) + assert.Greater(t, m.FileSize, int64(0)) +} + +func TestReadMetadata_ArchAfterContextLength(t *testing.T) { + // Architecture key comes AFTER the arch-specific keys. + // The parser must handle deferred resolution of arch-prefixed keys. + path := writeTestGGUFOrdered(t, [][2]any{ + {"general.name", "Out-of-Order Model"}, + {"general.file_type", uint32(15)}, + {"general.size_label", "8B"}, + {"llama.context_length", uint32(4096)}, + {"llama.block_count", uint32(32)}, + {"general.architecture", "llama"}, + }) + + m, err := ReadMetadata(path) + require.NoError(t, err) + + assert.Equal(t, "llama", m.Architecture) + assert.Equal(t, "Out-of-Order Model", m.Name) + assert.Equal(t, uint32(4096), m.ContextLength) + assert.Equal(t, uint32(32), m.BlockCount) +} + +func TestReadMetadata_InvalidMagic(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "notgguf.bin") + + err := os.WriteFile(path, []byte("this is not a GGUF file at all"), 0644) + require.NoError(t, err) + + _, err = ReadMetadata(path) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid magic") +} + +func TestReadMetadata_FileNotFound(t *testing.T) { + _, err := ReadMetadata("/nonexistent/path/model.gguf") + require.Error(t, err) +} + +func TestFileTypeName(t *testing.T) { + assert.Equal(t, "Q4_K_M", FileTypeName(15)) + assert.Equal(t, "Q5_K_M", FileTypeName(17)) + assert.Equal(t, "Q8_0", FileTypeName(7)) + assert.Equal(t, "F16", FileTypeName(1)) + assert.Equal(t, "type_999", FileTypeName(999)) +}