diff --git a/pkg/trix/fuzz_test.go b/pkg/trix/fuzz_test.go new file mode 100644 index 0000000..28565d7 --- /dev/null +++ b/pkg/trix/fuzz_test.go @@ -0,0 +1,34 @@ +package trix + +import ( + "testing" +) + +func FuzzDecode(f *testing.F) { + // Seed with a valid encoded Trix object + validTrix := &Trix{ + Header: map[string]interface{}{"content_type": "text/plain"}, + Payload: []byte("hello world"), + } + validEncoded, _ := Encode(validTrix, "FUZZ") + f.Add(validEncoded) + + // Seed with the corrupted header length from the ugly test + var buf []byte + buf = append(buf, []byte("UGLY")...) + buf = append(buf, byte(Version)) + buf = append(buf, []byte{0, 0, 3, 232}...) // BigEndian representation of 1000 + buf = append(buf, []byte("{}")...) + buf = append(buf, []byte("payload")...) + f.Add(buf) + + // Seed with a short, invalid input + f.Add([]byte("short")) + + f.Fuzz(func(t *testing.T, data []byte) { + // The fuzzer will generate random data here. + // We just need to call our function and make sure it doesn't panic. + // The fuzzer will report any crashes as failures. + _, _ = Decode(data, "FUZZ") + }) +} diff --git a/pkg/trix/testdata/fuzz/FuzzDecode/d02802ef987b399b b/pkg/trix/testdata/fuzz/FuzzDecode/d02802ef987b399b new file mode 100644 index 0000000..729aabe --- /dev/null +++ b/pkg/trix/testdata/fuzz/FuzzDecode/d02802ef987b399b @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("FUZZ\x02in\"\"") diff --git a/pkg/trix/trix.go b/pkg/trix/trix.go index fba0d7e..61c88d7 100644 --- a/pkg/trix/trix.go +++ b/pkg/trix/trix.go @@ -13,7 +13,8 @@ import ( ) const ( - Version = 2 + Version = 2 + MaxHeaderSize = 16 * 1024 * 1024 // 16 MB ) var ( @@ -22,6 +23,7 @@ var ( ErrMagicNumberLength = errors.New("trix: magic number must be 4 bytes long") ErrNilSigil = errors.New("trix: sigil cannot be nil") ErrChecksumMismatch = errors.New("trix: checksum mismatch") + ErrHeaderTooLarge = errors.New("trix: header size exceeds maximum allowed") ) // Trix represents the structure of a .trix file. @@ -115,6 +117,11 @@ func Decode(data []byte, magicNumber string) (*Trix, error) { return nil, err } + // Sanity check the header length to prevent massive allocations. + if headerLength > MaxHeaderSize { + return nil, ErrHeaderTooLarge + } + // Read JSON Header headerBytes := make([]byte, headerLength) if _, err := io.ReadFull(buf, headerBytes); err != nil {