diff --git a/call_test.go b/call_test.go index 52c86c2..b02d228 100644 --- a/call_test.go +++ b/call_test.go @@ -177,6 +177,20 @@ func TestCallSupportsCUintptrArgument(t *testing.T) { } } +func TestCallSupportsArbitraryPointerArgument(t *testing.T) { + t.Parallel() + + payload := []byte("agent") + if err := Call(callAnyPointerArgumentFunction(), &payload[0]); err != nil { + t.Fatalf("expected success for pointer argument, got %v", err) + } + + var pointer *byte + if err := Call(callAnyPointerArgumentFunction(), pointer); err == nil { + t.Fatal("expected error for nil pointer argument") + } +} + func TestCallSupportsByteSliceArgument(t *testing.T) { t.Parallel() diff --git a/call_test_support.go b/call_test_support.go index cc46645..264f053 100644 --- a/call_test_support.go +++ b/call_test_support.go @@ -90,6 +90,10 @@ int call_uintptr_argument(uintptr_t value) { return value == 42 ? 0 : 13; } +int call_any_pointer_argument(uintptr_t value) { + return value == 0 ? 13 : 0; +} + uintptr_t call_sum_length_ptr(void) { return (uintptr_t)&call_sum_length; } @@ -146,6 +150,10 @@ uintptr_t call_uintptr_argument_ptr(void) { return (uintptr_t)&call_uintptr_argument; } +uintptr_t call_any_pointer_argument_ptr(void) { + return (uintptr_t)&call_any_pointer_argument; +} + */ import "C" @@ -221,6 +229,11 @@ func callCUintptrArgumentFunction() unsafe.Pointer { return *(*unsafe.Pointer)(unsafe.Pointer(&function)) } +func callAnyPointerArgumentFunction() unsafe.Pointer { + function := C.call_any_pointer_argument_ptr() + return *(*unsafe.Pointer)(unsafe.Pointer(&function)) +} + func callWithErrnoZero() (int, error) { return WithErrno(func() C.int { return 0 diff --git a/string_conversion.go b/string_conversion.go index a082861..3c624a1 100644 --- a/string_conversion.go +++ b/string_conversion.go @@ -811,6 +811,11 @@ func toSyscallArg(value interface{}) (uintptr, bool) { default: reflected := reflect.ValueOf(value) switch reflected.Kind() { + case reflect.Pointer, reflect.UnsafePointer: + if reflected.IsNil() { + return 0, true + } + return reflected.Pointer(), true case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return uintptr(reflected.Int()), true case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: