From 14ed08b7f92fff374a671d7682f7ca40d033f3cc Mon Sep 17 00:00:00 2001 From: Virgil Date: Fri, 3 Apr 2026 19:15:01 +0000 Subject: [PATCH] feat(cgo): add cgo call wrapper Co-Authored-By: Virgil --- call_test.go | 51 ++++++++++++++ call_test_support.go | 38 ++++++++++ string_conversion.go | 161 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 249 insertions(+), 1 deletion(-) create mode 100644 call_test.go create mode 100644 call_test_support.go diff --git a/call_test.go b/call_test.go new file mode 100644 index 0000000..28d386f --- /dev/null +++ b/call_test.go @@ -0,0 +1,51 @@ +package cgo + +import ( + "testing" +) + +func TestSizeTAndIntConversions(t *testing.T) { + t.Parallel() + + const exampleLength = 12 + if got := int(SizeT(exampleLength)); got != exampleLength { + t.Fatalf("expected SizeT(%d) to be %d, got %d", exampleLength, exampleLength, got) + } + + if got := Int(-1); got != -1 { + t.Fatalf("expected Int(-1) to be -1, got %v", got) + } + + if got := Int(4096); got != 4096 { + t.Fatalf("expected Int(4096) to be 4096, got %v", got) + } +} + +func TestCallWrapsZeroAndNonZeroReturns(t *testing.T) { + t.Parallel() + + payload := []byte("agent") + buffer := NewBuffer(len(payload)) + defer buffer.Free() + buffer.CopyFrom(payload) + + if err := Call(callSumLengthFunction(), buffer.Ptr(), SizeT(len(payload))); err != nil { + t.Fatalf("expected success, got error: %v", err) + } + + if err := Call(callFailureFunction()); err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestCallRejectsUnsupportedInputs(t *testing.T) { + t.Parallel() + + assertPanics(t, "nil function pointer", func() { + _ = Call(nil) + }) + + assertPanics(t, "unsupported argument count", func() { + _ = Call(callFailureFunction(), 1, 2, 3, 4) + }) +} diff --git a/call_test_support.go b/call_test_support.go new file mode 100644 index 0000000..648df20 --- /dev/null +++ b/call_test_support.go @@ -0,0 +1,38 @@ +package cgo + +/* +#include +#include + +int call_sum_length(char* data, size_t length) { + if (data == NULL || length == 0) { + return 1; + } + return 0; +} + +int call_failure(void) { + return 13; +} + +uintptr_t call_sum_length_ptr(void) { + return (uintptr_t)&call_sum_length; +} + +uintptr_t call_failure_ptr(void) { + return (uintptr_t)&call_failure; +} +*/ +import "C" + +import "unsafe" + +func callSumLengthFunction() unsafe.Pointer { + function := C.call_sum_length_ptr() + return *(*unsafe.Pointer)(unsafe.Pointer(&function)) +} + +func callFailureFunction() unsafe.Pointer { + function := C.call_failure_ptr() + return *(*unsafe.Pointer)(unsafe.Pointer(&function)) +} diff --git a/string_conversion.go b/string_conversion.go index d6761f8..90c87f6 100644 --- a/string_conversion.go +++ b/string_conversion.go @@ -1,11 +1,170 @@ package cgo /* +#include #include + +typedef int (*cgo_call_int_fn0_t)(void); +typedef int (*cgo_call_int_fn1_t)(uintptr_t); +typedef int (*cgo_call_int_fn2_t)(uintptr_t, uintptr_t); +typedef int (*cgo_call_int_fn3_t)(uintptr_t, uintptr_t, uintptr_t); + +int cgo_call_0(uintptr_t fn) { + return ((cgo_call_int_fn0_t)fn)(); +} + +int cgo_call_1(uintptr_t fn, uintptr_t a0) { + return ((cgo_call_int_fn1_t)fn)(a0); +} + +int cgo_call_2(uintptr_t fn, uintptr_t a0, uintptr_t a1) { + return ((cgo_call_int_fn2_t)fn)(a0, a1); +} + +int cgo_call_3(uintptr_t fn, uintptr_t a0, uintptr_t a1, uintptr_t a2) { + return ((cgo_call_int_fn3_t)fn)(a0, a1, a2); +} */ import "C" -import "unsafe" +import ( + "fmt" + "strconv" + "unsafe" +) + +// SizeT converts a Go int into a C size_t for cgo calls. +// +// size := cgo.SizeT(len(data)) +func SizeT(value int) C.size_t { + if value < 0 { + panic("cgo.SizeT: negative values are not representable as C.size_t") + } + + if value > 0 { + sizeBits := int(unsafe.Sizeof(C.size_t(0)) * 8) + if sizeBits < strconv.IntSize { + maxSize := (uint64(1) << sizeBits) - 1 + if uint64(value) > maxSize { + panic("cgo.SizeT: value exceeds C.size_t range") + } + } + } + return C.size_t(value) +} + +// Int converts a Go int into a C int for cgo calls. +// +// rc := cgo.Int(returnCode) +func Int(value int) C.int { + if value < -2147483648 || value > 2147483647 { + panic("cgo.Int: value exceeds C.int range") + } + return C.int(value) +} + +// Call invokes a C function pointer and maps a non-zero return into a Go error. +// +// err := cgo.Call(unsafe.Pointer(C.some_function), cgo.SizeT(len(data))) +// err == nil indicates success. +func Call(function unsafe.Pointer, args ...interface{}) error { + if function == nil { + panic("cgo.Call: function pointer is nil") + } + + var result uintptr + target := uintptr(function) + + switch len(args) { + case 0: + result = uintptr(C.cgo_call_0(C.uintptr_t(target))) + case 1: + a0, ok := toSyscallArg(args[0]) + if !ok { + panic("cgo.Call: unsupported argument type") + } + result = uintptr(C.cgo_call_1(C.uintptr_t(target), C.uintptr_t(a0))) + case 2: + a0, ok := toSyscallArg(args[0]) + if !ok { + panic("cgo.Call: unsupported argument type") + } + a1, ok := toSyscallArg(args[1]) + if !ok { + panic("cgo.Call: unsupported argument type") + } + result = uintptr(C.cgo_call_2(C.uintptr_t(target), C.uintptr_t(a0), C.uintptr_t(a1))) + case 3: + a0, ok := toSyscallArg(args[0]) + if !ok { + panic("cgo.Call: unsupported argument type") + } + a1, ok := toSyscallArg(args[1]) + if !ok { + panic("cgo.Call: unsupported argument type") + } + a2, ok := toSyscallArg(args[2]) + if !ok { + panic("cgo.Call: unsupported argument type") + } + result = uintptr(C.cgo_call_3(C.uintptr_t(target), C.uintptr_t(a0), C.uintptr_t(a1), C.uintptr_t(a2))) + default: + panic("cgo.Call: unsupported argument count: max 3") + } + + if result != 0 { + return fmt.Errorf("cgo.Call: return code %d", int(result)) + } + return nil +} + +func toSyscallArg(value interface{}) (uintptr, bool) { + switch typed := value.(type) { + case nil: + return 0, true + case uintptr: + return typed, true + case unsafe.Pointer: + return uintptr(typed), true + case C.char: + return uintptr(typed), true + case C.int: + return uintptr(typed), true + case C.size_t: + return uintptr(typed), true + case *C.char: + return uintptr(unsafe.Pointer(typed)), true + case *byte: + return uintptr(unsafe.Pointer(typed)), true + case bool: + if typed { + return 1, true + } + return 0, true + case int: + return uintptr(typed), true + case int8: + return uintptr(typed), true + case int16: + return uintptr(typed), true + case int32: + return uintptr(typed), true + case int64: + return uintptr(typed), true + case uint: + return uintptr(typed), true + case uint8: + return uintptr(typed), true + case uint16: + return uintptr(typed), true + case uint32: + return uintptr(typed), true + case uint64: + return uintptr(typed), true + default: + return 0, false + } +} // GoString converts a null-terminated C string to a Go string. //