feat(cgo): add cgo call wrapper
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
2e910e91ef
commit
14ed08b7f9
3 changed files with 249 additions and 1 deletions
51
call_test.go
Normal file
51
call_test.go
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
package cgo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSizeTAndIntConversions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const exampleLength = 12
|
||||
if got := int(SizeT(exampleLength)); got != exampleLength {
|
||||
t.Fatalf("expected SizeT(%d) to be %d, got %d", exampleLength, exampleLength, got)
|
||||
}
|
||||
|
||||
if got := Int(-1); got != -1 {
|
||||
t.Fatalf("expected Int(-1) to be -1, got %v", got)
|
||||
}
|
||||
|
||||
if got := Int(4096); got != 4096 {
|
||||
t.Fatalf("expected Int(4096) to be 4096, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallWrapsZeroAndNonZeroReturns(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := []byte("agent")
|
||||
buffer := NewBuffer(len(payload))
|
||||
defer buffer.Free()
|
||||
buffer.CopyFrom(payload)
|
||||
|
||||
if err := Call(callSumLengthFunction(), buffer.Ptr(), SizeT(len(payload))); err != nil {
|
||||
t.Fatalf("expected success, got error: %v", err)
|
||||
}
|
||||
|
||||
if err := Call(callFailureFunction()); err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallRejectsUnsupportedInputs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assertPanics(t, "nil function pointer", func() {
|
||||
_ = Call(nil)
|
||||
})
|
||||
|
||||
assertPanics(t, "unsupported argument count", func() {
|
||||
_ = Call(callFailureFunction(), 1, 2, 3, 4)
|
||||
})
|
||||
}
|
||||
38
call_test_support.go
Normal file
38
call_test_support.go
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
package cgo
|
||||
|
||||
/*
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
int call_sum_length(char* data, size_t length) {
|
||||
if (data == NULL || length == 0) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int call_failure(void) {
|
||||
return 13;
|
||||
}
|
||||
|
||||
uintptr_t call_sum_length_ptr(void) {
|
||||
return (uintptr_t)&call_sum_length;
|
||||
}
|
||||
|
||||
uintptr_t call_failure_ptr(void) {
|
||||
return (uintptr_t)&call_failure;
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import "unsafe"
|
||||
|
||||
func callSumLengthFunction() unsafe.Pointer {
|
||||
function := C.call_sum_length_ptr()
|
||||
return *(*unsafe.Pointer)(unsafe.Pointer(&function))
|
||||
}
|
||||
|
||||
func callFailureFunction() unsafe.Pointer {
|
||||
function := C.call_failure_ptr()
|
||||
return *(*unsafe.Pointer)(unsafe.Pointer(&function))
|
||||
}
|
||||
|
|
@ -1,11 +1,170 @@
|
|||
package cgo
|
||||
|
||||
/*
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
typedef int (*cgo_call_int_fn0_t)(void);
|
||||
typedef int (*cgo_call_int_fn1_t)(uintptr_t);
|
||||
typedef int (*cgo_call_int_fn2_t)(uintptr_t, uintptr_t);
|
||||
typedef int (*cgo_call_int_fn3_t)(uintptr_t, uintptr_t, uintptr_t);
|
||||
|
||||
int cgo_call_0(uintptr_t fn) {
|
||||
return ((cgo_call_int_fn0_t)fn)();
|
||||
}
|
||||
|
||||
int cgo_call_1(uintptr_t fn, uintptr_t a0) {
|
||||
return ((cgo_call_int_fn1_t)fn)(a0);
|
||||
}
|
||||
|
||||
int cgo_call_2(uintptr_t fn, uintptr_t a0, uintptr_t a1) {
|
||||
return ((cgo_call_int_fn2_t)fn)(a0, a1);
|
||||
}
|
||||
|
||||
int cgo_call_3(uintptr_t fn, uintptr_t a0, uintptr_t a1, uintptr_t a2) {
|
||||
return ((cgo_call_int_fn3_t)fn)(a0, a1, a2);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import "unsafe"
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// SizeT converts a Go int into a C size_t for cgo calls.
|
||||
//
|
||||
// size := cgo.SizeT(len(data))
|
||||
func SizeT(value int) C.size_t {
|
||||
if value < 0 {
|
||||
panic("cgo.SizeT: negative values are not representable as C.size_t")
|
||||
}
|
||||
|
||||
if value > 0 {
|
||||
sizeBits := int(unsafe.Sizeof(C.size_t(0)) * 8)
|
||||
if sizeBits < strconv.IntSize {
|
||||
maxSize := (uint64(1) << sizeBits) - 1
|
||||
if uint64(value) > maxSize {
|
||||
panic("cgo.SizeT: value exceeds C.size_t range")
|
||||
}
|
||||
}
|
||||
}
|
||||
return C.size_t(value)
|
||||
}
|
||||
|
||||
// Int converts a Go int into a C int for cgo calls.
|
||||
//
|
||||
// rc := cgo.Int(returnCode)
|
||||
func Int(value int) C.int {
|
||||
if value < -2147483648 || value > 2147483647 {
|
||||
panic("cgo.Int: value exceeds C.int range")
|
||||
}
|
||||
return C.int(value)
|
||||
}
|
||||
|
||||
// Call invokes a C function pointer and maps a non-zero return into a Go error.
|
||||
//
|
||||
// err := cgo.Call(unsafe.Pointer(C.some_function), cgo.SizeT(len(data)))
|
||||
// err == nil indicates success.
|
||||
func Call(function unsafe.Pointer, args ...interface{}) error {
|
||||
if function == nil {
|
||||
panic("cgo.Call: function pointer is nil")
|
||||
}
|
||||
|
||||
var result uintptr
|
||||
target := uintptr(function)
|
||||
|
||||
switch len(args) {
|
||||
case 0:
|
||||
result = uintptr(C.cgo_call_0(C.uintptr_t(target)))
|
||||
case 1:
|
||||
a0, ok := toSyscallArg(args[0])
|
||||
if !ok {
|
||||
panic("cgo.Call: unsupported argument type")
|
||||
}
|
||||
result = uintptr(C.cgo_call_1(C.uintptr_t(target), C.uintptr_t(a0)))
|
||||
case 2:
|
||||
a0, ok := toSyscallArg(args[0])
|
||||
if !ok {
|
||||
panic("cgo.Call: unsupported argument type")
|
||||
}
|
||||
a1, ok := toSyscallArg(args[1])
|
||||
if !ok {
|
||||
panic("cgo.Call: unsupported argument type")
|
||||
}
|
||||
result = uintptr(C.cgo_call_2(C.uintptr_t(target), C.uintptr_t(a0), C.uintptr_t(a1)))
|
||||
case 3:
|
||||
a0, ok := toSyscallArg(args[0])
|
||||
if !ok {
|
||||
panic("cgo.Call: unsupported argument type")
|
||||
}
|
||||
a1, ok := toSyscallArg(args[1])
|
||||
if !ok {
|
||||
panic("cgo.Call: unsupported argument type")
|
||||
}
|
||||
a2, ok := toSyscallArg(args[2])
|
||||
if !ok {
|
||||
panic("cgo.Call: unsupported argument type")
|
||||
}
|
||||
result = uintptr(C.cgo_call_3(C.uintptr_t(target), C.uintptr_t(a0), C.uintptr_t(a1), C.uintptr_t(a2)))
|
||||
default:
|
||||
panic("cgo.Call: unsupported argument count: max 3")
|
||||
}
|
||||
|
||||
if result != 0 {
|
||||
return fmt.Errorf("cgo.Call: return code %d", int(result))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func toSyscallArg(value interface{}) (uintptr, bool) {
|
||||
switch typed := value.(type) {
|
||||
case nil:
|
||||
return 0, true
|
||||
case uintptr:
|
||||
return typed, true
|
||||
case unsafe.Pointer:
|
||||
return uintptr(typed), true
|
||||
case C.char:
|
||||
return uintptr(typed), true
|
||||
case C.int:
|
||||
return uintptr(typed), true
|
||||
case C.size_t:
|
||||
return uintptr(typed), true
|
||||
case *C.char:
|
||||
return uintptr(unsafe.Pointer(typed)), true
|
||||
case *byte:
|
||||
return uintptr(unsafe.Pointer(typed)), true
|
||||
case bool:
|
||||
if typed {
|
||||
return 1, true
|
||||
}
|
||||
return 0, true
|
||||
case int:
|
||||
return uintptr(typed), true
|
||||
case int8:
|
||||
return uintptr(typed), true
|
||||
case int16:
|
||||
return uintptr(typed), true
|
||||
case int32:
|
||||
return uintptr(typed), true
|
||||
case int64:
|
||||
return uintptr(typed), true
|
||||
case uint:
|
||||
return uintptr(typed), true
|
||||
case uint8:
|
||||
return uintptr(typed), true
|
||||
case uint16:
|
||||
return uintptr(typed), true
|
||||
case uint32:
|
||||
return uintptr(typed), true
|
||||
case uint64:
|
||||
return uintptr(typed), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// GoString converts a null-terminated C string to a Go string.
|
||||
//
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue