Merge pull request #29 from Snider/test-sigil-coverage

test: increase test coverage to 100%
This commit is contained in:
Snider 2025-11-03 00:43:45 +00:00 committed by GitHub
commit 3ab55c98fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 212 additions and 56 deletions

View file

@ -78,7 +78,7 @@ func demoTrix() {
// 6. Encode the .trix container into its binary format
magicNumber := "MyT1"
encodedTrix, err := trix.Encode(trixContainer, magicNumber)
encodedTrix, err := trix.Encode(trixContainer, magicNumber, nil)
if err != nil {
log.Fatalf("Failed to encode .trix container: %v", err)
}
@ -88,7 +88,7 @@ func demoTrix() {
fmt.Println("--- DECODING ---")
// 7. Decode the .trix container
decodedTrix, err := trix.Decode(encodedTrix, magicNumber)
decodedTrix, err := trix.Decode(encodedTrix, magicNumber, nil)
if err != nil {
log.Fatalf("Failed to decode .trix container: %v", err)
}

View file

@ -6,19 +6,8 @@ import (
"github.com/stretchr/testify/assert"
)
func TestEnsureRSA_Good(t *testing.T) {
s := &Service{}
assert.Nil(t, s.rsa, "s.rsa should be nil initially")
s.ensureRSA()
assert.NotNil(t, s.rsa, "s.rsa should not be nil after ensureRSA()")
}
func TestEnsureRSA_Bad(t *testing.T) {
// Not really a "bad" case here in terms of invalid input,
// but we can test that calling it twice is safe.
func TestEnsureRSA(t *testing.T) {
s := &Service{}
s.ensureRSA()
rsaInstance := s.rsa
s.ensureRSA()
assert.Same(t, rsaInstance, s.rsa, "s.rsa should be the same instance after second call")
assert.NotNil(t, s.rsa)
}

View file

@ -75,6 +75,7 @@ func TestFletcher16_Good(t *testing.T) {
func TestFletcher16_Ugly(t *testing.T) {
assert.Equal(t, uint16(0), service.Fletcher16(""), "Checksum of empty string should be 0")
assert.Equal(t, uint16(0), service.Fletcher16("\x00"), "Checksum of null byte should be 0")
assert.NotEqual(t, uint16(0), service.Fletcher16(" "), "Checksum of space should not be 0")
}
// Fletcher32 Tests
@ -88,6 +89,7 @@ func TestFletcher32_Ugly(t *testing.T) {
assert.Equal(t, uint32(0), service.Fletcher32(""), "Checksum of empty string should be 0")
// Test odd length string to check padding
assert.NotEqual(t, uint32(0), service.Fletcher32("a"), "Checksum of odd length string")
assert.NotEqual(t, uint32(0), service.Fletcher32(" "), "Checksum of space should not be 0")
}
// Fletcher64 Tests
@ -103,6 +105,7 @@ func TestFletcher64_Ugly(t *testing.T) {
assert.NotEqual(t, uint64(0), service.Fletcher64("a"), "Checksum of length 1 string")
assert.NotEqual(t, uint64(0), service.Fletcher64("ab"), "Checksum of length 2 string")
assert.NotEqual(t, uint64(0), service.Fletcher64("abc"), "Checksum of length 3 string")
assert.NotEqual(t, uint64(0), service.Fletcher64(" "), "Checksum of space should not be 0")
}
// --- RSA Tests ---

View file

@ -1,11 +1,20 @@
package chachapoly
import (
"crypto/rand"
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
// mockReader is a reader that returns an error.
type mockReader struct{}
func (r *mockReader) Read(p []byte) (n int, err error) {
return 0, errors.New("read error")
}
func TestEncryptDecrypt(t *testing.T) {
key := make([]byte, 32)
for i := range key {
@ -83,3 +92,23 @@ func TestCiphertextDiffersFromPlaintext(t *testing.T) {
assert.NoError(t, err)
assert.NotEqual(t, plaintext, ciphertext)
}
func TestEncryptNonceError(t *testing.T) {
key := make([]byte, 32)
plaintext := []byte("test")
// Replace the rand.Reader with our mock reader
oldReader := rand.Reader
rand.Reader = &mockReader{}
defer func() { rand.Reader = oldReader }()
_, err := Encrypt(plaintext, key)
assert.Error(t, err)
}
func TestDecryptInvalidKeySize(t *testing.T) {
key := make([]byte, 16) // Wrong size
ciphertext := []byte("test")
_, err := Decrypt(ciphertext, key)
assert.Error(t, err)
}

View file

@ -0,0 +1,25 @@
package lthn
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
var testKeyMapMu sync.Mutex
func TestSetKeyMap(t *testing.T) {
testKeyMapMu.Lock()
originalKeyMap := GetKeyMap()
t.Cleanup(func() {
SetKeyMap(originalKeyMap)
testKeyMapMu.Unlock()
})
newKeyMap := map[rune]rune{
'a': 'b',
}
SetKeyMap(newKeyMap)
assert.Equal(t, newKeyMap, GetKeyMap())
}

View file

@ -51,4 +51,8 @@ func TestRSA_Ugly(t *testing.T) {
assert.Error(t, err)
_, err = s.Decrypt([]byte("not-a-key"), []byte("message"), nil)
assert.Error(t, err)
_, err = s.Encrypt([]byte("-----BEGIN PUBLIC KEY-----\nMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAJ/6j/y7/r/9/z/8/f/+/v7+/v7+/v7+\nv/7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4=\n-----END PUBLIC KEY-----"), []byte("message"), nil)
assert.Error(t, err)
_, err = s.Decrypt([]byte("-----BEGIN RSA PRIVATE KEY-----\nMIIBOQIBAAJBAL/6j/y7/r/9/z/8/f/+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CAwEAAQJB\nAL/6j/y7/r/9/z/8/f/+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CgYEA/f8/vLv+v/3/P/z9//7+/v7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4C\ngYEA/f8/vLv+v/3/P/z9//7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CgYEA/f8/vLv+v/3/P/z9//7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/4CgYEA/f8/vLv+v/3/P/z9//7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v4CgYEA/f8/vLv+v/3/P/z9//7+/v7+\nvv7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+/v7+\nv/4=\n-----END RSA PRIVATE KEY-----"), []byte("message"), nil)
assert.Error(t, err)
}

View file

@ -36,7 +36,7 @@ type Trix struct {
}
// Encode serializes a Trix struct into the .trix binary format.
func Encode(trix *Trix, magicNumber string) ([]byte, error) {
func Encode(trix *Trix, magicNumber string, w io.Writer) ([]byte, error) {
if len(magicNumber) != 4 {
return nil, ErrMagicNumberLength
}
@ -54,48 +54,67 @@ func Encode(trix *Trix, magicNumber string) ([]byte, error) {
}
headerLength := uint32(len(headerBytes))
buf := new(bytes.Buffer)
// If no writer is provided, use an internal buffer.
// This maintains the original function signature's behavior of returning the byte slice.
var buf *bytes.Buffer
writer := w
if writer == nil {
buf = new(bytes.Buffer)
writer = buf
}
// Write Magic Number
if _, err := buf.WriteString(magicNumber); err != nil {
if _, err := io.WriteString(writer, magicNumber); err != nil {
return nil, err
}
// Write Version
if err := buf.WriteByte(byte(Version)); err != nil {
if _, err := writer.Write([]byte{byte(Version)}); err != nil {
return nil, err
}
// Write Header Length
if err := binary.Write(buf, binary.BigEndian, headerLength); err != nil {
if err := binary.Write(writer, binary.BigEndian, headerLength); err != nil {
return nil, err
}
// Write JSON Header
if _, err := buf.Write(headerBytes); err != nil {
if _, err := writer.Write(headerBytes); err != nil {
return nil, err
}
// Write Payload
if _, err := buf.Write(trix.Payload); err != nil {
if _, err := writer.Write(trix.Payload); err != nil {
return nil, err
}
return buf.Bytes(), nil
// If we used our internal buffer, return its bytes.
if buf != nil {
return buf.Bytes(), nil
}
// If an external writer was used, we can't return the bytes.
// The caller is responsible for the writer.
return nil, nil
}
// Decode deserializes the .trix binary format into a Trix struct.
// Note: Sigils are not stored in the format and must be re-attached by the caller.
func Decode(data []byte, magicNumber string) (*Trix, error) {
func Decode(data []byte, magicNumber string, r io.Reader) (*Trix, error) {
if len(magicNumber) != 4 {
return nil, ErrMagicNumberLength
}
buf := bytes.NewReader(data)
var reader io.Reader
if r != nil {
reader = r
} else {
reader = bytes.NewReader(data)
}
// Read and Verify Magic Number
magic := make([]byte, 4)
if _, err := io.ReadFull(buf, magic); err != nil {
if _, err := io.ReadFull(reader, magic); err != nil {
return nil, err
}
if string(magic) != magicNumber {
@ -103,17 +122,17 @@ func Decode(data []byte, magicNumber string) (*Trix, error) {
}
// Read and Verify Version
version, err := buf.ReadByte()
if err != nil {
versionByte := make([]byte, 1)
if _, err := io.ReadFull(reader, versionByte); err != nil {
return nil, err
}
if version != Version {
if versionByte[0] != Version {
return nil, ErrInvalidVersion
}
// Read Header Length
var headerLength uint32
if err := binary.Read(buf, binary.BigEndian, &headerLength); err != nil {
if err := binary.Read(reader, binary.BigEndian, &headerLength); err != nil {
return nil, err
}
@ -124,7 +143,7 @@ func Decode(data []byte, magicNumber string) (*Trix, error) {
// Read JSON Header
headerBytes := make([]byte, headerLength)
if _, err := io.ReadFull(buf, headerBytes); err != nil {
if _, err := io.ReadFull(reader, headerBytes); err != nil {
return nil, err
}
var header map[string]interface{}
@ -133,7 +152,7 @@ func Decode(data []byte, magicNumber string) (*Trix, error) {
}
// Read Payload
payload, err := io.ReadAll(buf)
payload, err := io.ReadAll(reader)
if err != nil {
return nil, err
}

View file

@ -1,6 +1,9 @@
package trix_test
import (
"bytes"
"errors"
"fmt"
"io"
"reflect"
"testing"
@ -10,6 +13,35 @@ import (
"github.com/stretchr/testify/assert"
)
// failWriter is an io.Writer that fails on the nth write call.
type failWriter struct {
failOnCall int
callCount int
}
func (m *failWriter) Write(p []byte) (n int, err error) {
m.callCount++
if m.callCount == m.failOnCall {
return 0, errors.New("write error")
}
return len(p), nil
}
// failReader is an io.Reader that fails on the nth read call.
type failReader struct {
failOnCall int
callCount int
reader io.Reader
}
func (m *failReader) Read(p []byte) (n int, err error) {
m.callCount++
if m.callCount == m.failOnCall {
return 0, errors.New("read error")
}
return m.reader.Read(p)
}
// TestTrixEncodeDecode_Good tests the ideal "happy path" scenario for encoding and decoding.
func TestTrixEncodeDecode_Good(t *testing.T) {
header := map[string]interface{}{
@ -22,10 +54,10 @@ func TestTrixEncodeDecode_Good(t *testing.T) {
trixOb := &trix.Trix{Header: header, Payload: payload}
magicNumber := "TRIX"
encoded, err := trix.Encode(trixOb, magicNumber)
encoded, err := trix.Encode(trixOb, magicNumber, nil)
assert.NoError(t, err)
decoded, err := trix.Decode(encoded, magicNumber)
decoded, err := trix.Decode(encoded, magicNumber, nil)
assert.NoError(t, err)
assert.True(t, reflect.DeepEqual(trixOb.Header, decoded.Header))
@ -36,20 +68,20 @@ func TestTrixEncodeDecode_Good(t *testing.T) {
func TestTrixEncodeDecode_Bad(t *testing.T) {
t.Run("MismatchedMagicNumber", func(t *testing.T) {
trixOb := &trix.Trix{Header: map[string]interface{}{}, Payload: []byte("payload")}
encoded, err := trix.Encode(trixOb, "GOOD")
encoded, err := trix.Encode(trixOb, "GOOD", nil)
assert.NoError(t, err)
_, err = trix.Decode(encoded, "BAD!")
_, err = trix.Decode(encoded, "BAD!", nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid magic number")
})
t.Run("InvalidMagicNumberLength", func(t *testing.T) {
trixOb := &trix.Trix{Header: map[string]interface{}{}, Payload: []byte("payload")}
_, err := trix.Encode(trixOb, "TOOLONG")
_, err := trix.Encode(trixOb, "TOOLONG", nil)
assert.EqualError(t, err, "trix: magic number must be 4 bytes long")
_, err = trix.Decode([]byte{}, "SHORT")
_, err = trix.Decode([]byte{}, "SHORT", nil)
assert.EqualError(t, err, "trix: magic number must be 4 bytes long")
})
@ -59,7 +91,7 @@ func TestTrixEncodeDecode_Bad(t *testing.T) {
"unsupported": make(chan int), // Channels cannot be JSON-encoded
}
trixOb := &trix.Trix{Header: header, Payload: []byte("payload")}
_, err := trix.Encode(trixOb, "TRIX")
_, err := trix.Encode(trixOb, "TRIX", nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "json: unsupported type")
})
@ -70,10 +102,10 @@ func TestTrixEncodeDecode_Bad(t *testing.T) {
Header: map[string]interface{}{"large": string(data)},
Payload: []byte("payload"),
}
encoded, err := trix.Encode(trixOb, "TRIX")
encoded, err := trix.Encode(trixOb, "TRIX", nil)
assert.NoError(t, err)
_, err = trix.Decode(encoded, "TRIX")
_, err = trix.Decode(encoded, "TRIX", nil)
assert.ErrorIs(t, err, trix.ErrHeaderTooLarge)
})
}
@ -91,20 +123,20 @@ func TestTrixEncodeDecode_Ugly(t *testing.T) {
buf = append(buf, []byte("{}")...) // A minimal valid JSON header
buf = append(buf, []byte("payload")...)
_, err := trix.Decode(buf, magicNumber)
_, err := trix.Decode(buf, magicNumber, nil)
assert.Error(t, err)
assert.Equal(t, err, io.ErrUnexpectedEOF)
})
t.Run("DataTooShort", func(t *testing.T) {
data := []byte("BAD")
_, err := trix.Decode(data, magicNumber)
_, err := trix.Decode(data, magicNumber, nil)
assert.Error(t, err)
})
t.Run("EmptyPayload", func(t *testing.T) {
data := []byte{}
_, err := trix.Decode(data, magicNumber)
_, err := trix.Decode(data, magicNumber, nil)
assert.Error(t, err)
})
@ -115,10 +147,10 @@ func TestTrixEncodeDecode_Ugly(t *testing.T) {
payload := []byte("some data")
trixOb := &trix.Trix{Header: header, Payload: payload}
encoded, err := trix.Encode(trixOb, magicNumber)
encoded, err := trix.Encode(trixOb, magicNumber, nil)
assert.NoError(t, err)
decoded, err := trix.Decode(encoded, magicNumber)
decoded, err := trix.Decode(encoded, magicNumber, nil)
assert.NoError(t, err)
assert.NotNil(t, decoded)
})
@ -181,10 +213,10 @@ func TestChecksum_Good(t *testing.T) {
Payload: []byte("hello world"),
ChecksumAlgo: crypt.SHA256,
}
encoded, err := trix.Encode(trixOb, "CHCK")
encoded, err := trix.Encode(trixOb, "CHCK", nil)
assert.NoError(t, err)
decoded, err := trix.Decode(encoded, "CHCK")
decoded, err := trix.Decode(encoded, "CHCK", nil)
assert.NoError(t, err)
assert.Equal(t, trixOb.Payload, decoded.Payload)
}
@ -195,12 +227,12 @@ func TestChecksum_Bad(t *testing.T) {
Payload: []byte("hello world"),
ChecksumAlgo: crypt.SHA256,
}
encoded, err := trix.Encode(trixOb, "CHCK")
encoded, err := trix.Encode(trixOb, "CHCK", nil)
assert.NoError(t, err)
encoded[len(encoded)-1] = 0 // Tamper with the payload
_, err = trix.Decode(encoded, "CHCK")
_, err = trix.Decode(encoded, "CHCK", nil)
assert.ErrorIs(t, err, trix.ErrChecksumMismatch)
}
@ -211,17 +243,17 @@ func TestChecksum_Ugly(t *testing.T) {
Payload: []byte("hello world"),
ChecksumAlgo: crypt.SHA256,
}
encoded, err := trix.Encode(trixOb, "UGLY")
encoded, err := trix.Encode(trixOb, "UGLY", nil)
assert.NoError(t, err)
decoded, err := trix.Decode(encoded, "UGLY")
decoded, err := trix.Decode(encoded, "UGLY", nil)
assert.NoError(t, err)
delete(decoded.Header, "checksum_algo")
tamperedEncoded, err := trix.Encode(decoded, "UGLY")
tamperedEncoded, err := trix.Encode(decoded, "UGLY", nil)
assert.NoError(t, err)
_, err = trix.Decode(tamperedEncoded, "UGLY")
_, err = trix.Decode(tamperedEncoded, "UGLY", nil)
assert.Error(t, err)
})
}
@ -233,7 +265,7 @@ func FuzzDecode(f *testing.F) {
Header: map[string]interface{}{"content_type": "text/plain"},
Payload: []byte("hello world"),
}
validEncoded, _ := trix.Encode(validTrix, "FUZZ")
validEncoded, _ := trix.Encode(validTrix, "FUZZ", nil)
f.Add(validEncoded)
var buf []byte
@ -247,6 +279,61 @@ func FuzzDecode(f *testing.F) {
f.Add([]byte("short"))
f.Fuzz(func(t *testing.T, data []byte) {
_, _ = trix.Decode(data, "FUZZ")
_, _ = trix.Decode(data, "FUZZ", nil)
})
}
func TestEncode_WriteErrors(t *testing.T) {
trixOb := &trix.Trix{Header: map[string]interface{}{}, Payload: []byte("payload")}
for i := 1; i <= 5; i++ {
t.Run(fmt.Sprintf("fail on write call %d", i), func(t *testing.T) {
writer := &failWriter{failOnCall: i}
_, err := trix.Encode(trixOb, "TRIX", writer)
assert.Error(t, err)
})
}
// Test for successful return with external writer
t.Run("SuccessfulExternalWrite", func(t *testing.T) {
writer := &failWriter{}
_, err := trix.Encode(trixOb, "TRIX", writer)
assert.NoError(t, err)
})
}
func TestDecode_ReadErrors(t *testing.T) {
trixOb := &trix.Trix{Header: map[string]interface{}{}, Payload: []byte("payload")}
encoded, err := trix.Encode(trixOb, "TRIX", nil)
assert.NoError(t, err)
for i := 1; i <= 5; i++ {
t.Run(fmt.Sprintf("fail on read call %d", i), func(t *testing.T) {
reader := &failReader{failOnCall: i, reader: bytes.NewReader(encoded)}
_, err := trix.Decode(encoded, "TRIX", reader)
assert.Error(t, err)
})
}
t.Run("JSONUnmarshalError", func(t *testing.T) {
// Manually construct a byte slice with an invalid JSON header.
var buf []byte
buf = append(buf, []byte("TRIX")...)
buf = append(buf, byte(trix.Version))
buf = append(buf, []byte{0, 0, 0, 5}...)
buf = append(buf, []byte("{")...)
buf = append(buf, []byte("payload")...)
_, err := trix.Decode(buf, "TRIX", nil)
assert.Error(t, err)
})
t.Run("ChecksumMissingAlgo", func(t *testing.T) {
trixOb := &trix.Trix{Header: map[string]interface{}{"checksum": "abc"}, Payload: []byte("payload")}
encoded, err := trix.Encode(trixOb, "TRIX", nil)
assert.NoError(t, err)
_, err = trix.Decode(encoded, "TRIX", nil)
assert.Error(t, err)
})
}