feat(buffer): implement RFC buffer lifecycle with safety checks
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
7b75f95ee7
commit
530bfd55f9
3 changed files with 193 additions and 0 deletions
110
buffer.go
Normal file
110
buffer.go
Normal file
|
|
@ -0,0 +1,110 @@
|
||||||
|
package cgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime"
|
||||||
|
"sync/atomic"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Buffer owns byte memory that can be passed safely to C.
|
||||||
|
//
|
||||||
|
// buffer := NewBuffer(16)
|
||||||
|
// defer buffer.Free()
|
||||||
|
// n := buffer.CopyFrom([]byte("payload"))
|
||||||
|
// _ = buffer.Bytes()[:n]
|
||||||
|
type Buffer struct {
|
||||||
|
data []byte
|
||||||
|
length int
|
||||||
|
pointer unsafe.Pointer
|
||||||
|
freed atomic.Bool
|
||||||
|
isPinned bool
|
||||||
|
pinner runtime.Pinner
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBuffer allocates memory and pins it so Ptr can be used across C boundaries.
|
||||||
|
func NewBuffer(size int) *Buffer {
|
||||||
|
if size < 0 {
|
||||||
|
panic("cgo.NewBuffer: size must be non-negative")
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make([]byte, size)
|
||||||
|
buffer := &Buffer{
|
||||||
|
data: data,
|
||||||
|
length: size,
|
||||||
|
}
|
||||||
|
|
||||||
|
if size > 0 {
|
||||||
|
buffer.pinner.Pin(&buffer.data[0])
|
||||||
|
buffer.isPinned = true
|
||||||
|
buffer.pointer = unsafe.Pointer(&buffer.data[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free releases the pinned memory backing slice and marks the buffer as freed.
|
||||||
|
func (b *Buffer) Free() {
|
||||||
|
if b == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !b.freed.CompareAndSwap(false, true) {
|
||||||
|
panic("cgo.Buffer.Free: double-free detected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.isPinned {
|
||||||
|
b.pinner.Unpin()
|
||||||
|
b.isPinned = false
|
||||||
|
}
|
||||||
|
|
||||||
|
b.pointer = nil
|
||||||
|
b.data = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopyFrom copies bytes from src into the buffer and returns bytes copied.
|
||||||
|
func (b *Buffer) CopyFrom(src []byte) int {
|
||||||
|
b.assertNotFreed()
|
||||||
|
if len(src) == 0 || b.length == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
copied := len(src)
|
||||||
|
if copied > b.length {
|
||||||
|
copied = b.length
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(b.data[:copied], src[:copied])
|
||||||
|
return copied
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes returns a mutable byte slice backed by the buffer memory.
|
||||||
|
func (b *Buffer) Bytes() []byte {
|
||||||
|
b.assertNotFreed()
|
||||||
|
return b.data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ptr returns the raw pointer to the buffer.
|
||||||
|
func (b *Buffer) Ptr() unsafe.Pointer {
|
||||||
|
b.assertNotFreed()
|
||||||
|
return b.pointer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the current buffer length.
|
||||||
|
func (b *Buffer) Len() int {
|
||||||
|
b.assertNotFreed()
|
||||||
|
return b.length
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFreed reports whether Free has already been called.
|
||||||
|
func (b *Buffer) IsFreed() bool {
|
||||||
|
return b.freed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) assertNotFreed() {
|
||||||
|
if b == nil {
|
||||||
|
panic("cgo.Buffer: use-after-free detected: buffer is nil")
|
||||||
|
}
|
||||||
|
if b.freed.Load() {
|
||||||
|
panic("cgo.Buffer: use-after-free detected")
|
||||||
|
}
|
||||||
|
}
|
||||||
80
buffer_test.go
Normal file
80
buffer_test.go
Normal file
|
|
@ -0,0 +1,80 @@
|
||||||
|
package cgo
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestBufferLifecycleAndCopy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
const capacity = 8
|
||||||
|
buffer := NewBuffer(capacity)
|
||||||
|
defer buffer.Free()
|
||||||
|
|
||||||
|
copied := buffer.CopyFrom([]byte{1, 2, 3, 4, 5})
|
||||||
|
if copied != 5 {
|
||||||
|
t.Fatalf("expected 5 bytes copied, got %d", copied)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := buffer.Len(); got != capacity {
|
||||||
|
t.Fatalf("expected buffer length %d, got %d", capacity, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := buffer.Bytes()[:5], []byte{1, 2, 3, 4, 5}; string(got) != string(want) {
|
||||||
|
t.Fatalf("expected bytes %v, got %v", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBufferCopyClipsToCapacity(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
buffer := NewBuffer(2)
|
||||||
|
defer buffer.Free()
|
||||||
|
|
||||||
|
copied := buffer.CopyFrom([]byte("abc"))
|
||||||
|
if copied != 2 {
|
||||||
|
t.Fatalf("expected 2 bytes copied, got %d", copied)
|
||||||
|
}
|
||||||
|
if got, want := buffer.Bytes(), []byte("ab"); string(got) != string(want) {
|
||||||
|
t.Fatalf("expected bytes %q, got %q", string(want), string(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBufferDoubleFreePanics(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
buffer := NewBuffer(1)
|
||||||
|
buffer.Free()
|
||||||
|
assertPanics(t, "double-free", func() {
|
||||||
|
buffer.Free()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBufferUseAfterFreePanics(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
buffer := NewBuffer(4)
|
||||||
|
buffer.Free()
|
||||||
|
|
||||||
|
assertPanics(t, "use-after-free", func() {
|
||||||
|
_ = buffer.Len()
|
||||||
|
})
|
||||||
|
assertPanics(t, "use-after-free", func() {
|
||||||
|
buffer.CopyFrom([]byte("x"))
|
||||||
|
})
|
||||||
|
assertPanics(t, "use-after-free", func() {
|
||||||
|
_ = buffer.Bytes()
|
||||||
|
})
|
||||||
|
assertPanics(t, "use-after-free", func() {
|
||||||
|
_ = buffer.Ptr()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertPanics(t *testing.T, want string, fn func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Fatalf("expected panic for %s, got none", want)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
fn()
|
||||||
|
}
|
||||||
3
go.mod
Normal file
3
go.mod
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
module dappco.re/go/cgo
|
||||||
|
|
||||||
|
go 1.22
|
||||||
Loading…
Add table
Reference in a new issue